def _get_streaming_connection(url: str):
    """Build an unconnected ``http.client`` connection for streaming from *url*.

    No socket is opened here; the connection is established lazily by the
    caller's first ``conn.request(...)``.  The caller owns the returned
    connection and is responsible for closing it.

    Args:
        url: Absolute endpoint URL.  ``https`` gets a TLS connection; any
            other scheme is treated as plain HTTP (default port 80).

    Returns:
        A ``(conn, path)`` tuple: the connection object and the request
        target (path plus optional ``?query``) to pass to ``conn.request``.

    Raises:
        ValueError: If *url* has no host component (or an invalid port,
            raised by ``urlparse``).
    """
    parsed = urlparse(url)
    host = parsed.hostname
    if not host:
        # Deliberately do not echo the URL: it may embed an API key in its
        # query string and callers log the exception text.
        raise ValueError("streaming endpoint URL has no host component")

    is_https = parsed.scheme == 'https'
    port = parsed.port or (443 if is_https else 80)

    # http.client writes the target verbatim on the request line, so an
    # empty path ("https://host" or "https://host?x=1") must become "/".
    path = parsed.path or '/'
    if parsed.query:
        path = f"{path}?{parsed.query}"

    # Best effort: mirror the ``verify`` attribute of the object returned by
    # get_network() when that helper is importable; otherwise verify
    # certificates.  Any truthy value is treated as "verify with defaults".
    verify_ssl = True
    if get_network is not None:
        try:
            verify_ssl = getattr(get_network(), 'verify', True)
        except Exception as exc:
            # Keep the safe default, but leave a trace of why lookup failed.
            logging.getLogger(__name__).debug(
                "could not read network verify setting, verifying TLS: %s", exc
            )

    if not is_https:
        return http.client.HTTPConnection(host, port, timeout=STREAM_TIMEOUT_SEC), path

    ctx = ssl.create_default_context()
    if not verify_ssl:
        # Public-API equivalent of an unverified context; check_hostname
        # must be cleared before verify_mode can be set to CERT_NONE.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
    conn = http.client.HTTPSConnection(host, port, timeout=STREAM_TIMEOUT_SEC, context=ctx)
    return conn, path
prompt}]}], "generationConfig": {"maxOutputTokens": self.max_tokens, "temperature": self.temperature, "stopSequences": [""]}}) + conn.request("POST", path, body=payload, headers={"Content-Type": "application/json"}) + res = conn.getresponse() - if res.status_code != 200: - for _ in chunk_gen: pass # Drain to prevent resource leak - logger.error(f"{PLUGIN_NAME}: Gemini API {res.status_code}") + if res.status != 200: + logger.error(f"{PLUGIN_NAME}: Gemini API {res.status}") return decoder = json.JSONDecoder() buffer = "" - utf8_decoder = codecs.getincrementaldecoder("utf-8")(errors='replace') - for chunk in chunk_gen: - if not chunk: continue - buffer += utf8_decoder.decode(chunk, final=False) + while True: + chunk = res.read(STREAM_CHUNK_SIZE) + if not chunk: break + buffer += chunk.decode('utf-8', errors='replace') while buffer: buffer = buffer.lstrip() if not buffer: break @@ -765,17 +794,21 @@ class SXNGPlugin(Plugin): except json.JSONDecodeError: break except Exception as e: logger.error(f"{PLUGIN_NAME}: Gemini stream error: {e}") + finally: + if conn: conn.close() def stream_openai_compatible(): + conn = None try: - payload = { + conn, path = _get_streaming_connection(self.endpoint_url) + payload = json.dumps({ "model": self.model, "messages": [{"role": "user", "content": prompt}], "stream": True, "max_tokens": self.max_tokens, "temperature": self.temperature, "stop": [""] - } + }) headers = { "Content-Type": "application/json", "HTTP-Referer": "https://github.com/searxng/searxng", @@ -785,17 +818,18 @@ class SXNGPlugin(Plugin): headers['api-key'] = self.api_key else: headers['Authorization'] = f"Bearer {self.api_key}" - res, chunk_gen = network.stream('POST', self.endpoint_url, json=payload, headers=headers, timeout=60) + conn.request("POST", path, body=payload, headers=headers) + res = conn.getresponse() - if res.status_code != 200: - for _ in chunk_gen: pass - logger.error(f"{PLUGIN_NAME}: {self.provider} API {res.status_code}") + if res.status != 200: + 
logger.error(f"{PLUGIN_NAME}: {self.provider} API {res.status}") return decoder = json.JSONDecoder() buffer = b"" - for chunk in chunk_gen: - if not chunk: continue + while True: + chunk = res.read(STREAM_CHUNK_SIZE) + if not chunk: break buffer += chunk while b"\n" in buffer: line_bytes, buffer = buffer.split(b"\n", 1) @@ -811,13 +845,16 @@ class SXNGPlugin(Plugin): pass except Exception as e: logger.error(f"{PLUGIN_NAME}: {self.provider} stream error: {e}") + finally: + if conn: conn.close() generator = stream_gemini if self.is_gemini else stream_openai_compatible return Response(generator(), mimetype='text/event-stream', headers={ 'X-Accel-Buffering': 'no', 'Cache-Control': 'no-cache, no-store', 'Connection': 'keep-alive', - 'Content-Encoding': 'identity' + 'Transfer-Encoding': 'chunked', + 'Content-Encoding': 'identity', }) return True