fix: use http.client for LLM streaming (bypass 64KB buffer)
This commit is contained in:
+57
-20
@@ -1,6 +1,10 @@
|
||||
import json, os, logging, base64, time, hashlib, codecs, re
|
||||
import json, os, logging, base64, time, hashlib, codecs, re, http.client, ssl
|
||||
from urllib.parse import urlparse
|
||||
from searx import network
|
||||
try:
|
||||
from searx.network import get_network
|
||||
except ImportError:
|
||||
get_network = None # Graceful fallback for test/demo environments
|
||||
from flask import Response, request, abort, jsonify
|
||||
from searx.plugins import Plugin, PluginInfo
|
||||
from searx.result_types import EngineResults
|
||||
@@ -10,6 +14,30 @@ from markupsafe import Markup
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOKEN_EXPIRY_SEC = 3600
|
||||
STREAM_CHUNK_SIZE = 128
|
||||
STREAM_TIMEOUT_SEC = 60
|
||||
|
||||
def _get_streaming_connection(url: str):
|
||||
parsed = urlparse(url)
|
||||
host = parsed.hostname
|
||||
port = parsed.port or (443 if parsed.scheme == 'https' else 80)
|
||||
path = parsed.path + ('?' + parsed.query if parsed.query else '')
|
||||
|
||||
verify_ssl = True
|
||||
if get_network is not None:
|
||||
try:
|
||||
net = get_network()
|
||||
verify_ssl = getattr(net, 'verify', True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if parsed.scheme == 'https':
|
||||
ctx = ssl.create_default_context() if verify_ssl else ssl._create_unverified_context()
|
||||
conn = http.client.HTTPSConnection(host, port, timeout=STREAM_TIMEOUT_SEC, context=ctx)
|
||||
else:
|
||||
conn = http.client.HTTPConnection(host, port, timeout=STREAM_TIMEOUT_SEC)
|
||||
|
||||
return conn, path
|
||||
|
||||
|
||||
|
||||
@@ -731,22 +759,23 @@ class SXNGPlugin(Plugin):
|
||||
else:
|
||||
url = f"{self.endpoint_url}?key={self.api_key}"
|
||||
|
||||
conn = None
|
||||
try:
|
||||
payload = {"contents": [{"parts": [{"text": prompt}]}], "generationConfig": {"maxOutputTokens": self.max_tokens, "temperature": self.temperature, "stopSequences": ["</answer>"]}}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
res, chunk_gen = network.stream('POST', url, json=payload, headers=headers, timeout=60)
|
||||
conn, path = _get_streaming_connection(url)
|
||||
payload = json.dumps({"contents": [{"parts": [{"text": prompt}]}], "generationConfig": {"maxOutputTokens": self.max_tokens, "temperature": self.temperature, "stopSequences": ["</answer>"]}})
|
||||
conn.request("POST", path, body=payload, headers={"Content-Type": "application/json"})
|
||||
res = conn.getresponse()
|
||||
|
||||
if res.status_code != 200:
|
||||
for _ in chunk_gen: pass # Drain to prevent resource leak
|
||||
logger.error(f"{PLUGIN_NAME}: Gemini API {res.status_code}")
|
||||
if res.status != 200:
|
||||
logger.error(f"{PLUGIN_NAME}: Gemini API {res.status}")
|
||||
return
|
||||
|
||||
decoder = json.JSONDecoder()
|
||||
buffer = ""
|
||||
utf8_decoder = codecs.getincrementaldecoder("utf-8")(errors='replace')
|
||||
for chunk in chunk_gen:
|
||||
if not chunk: continue
|
||||
buffer += utf8_decoder.decode(chunk, final=False)
|
||||
while True:
|
||||
chunk = res.read(STREAM_CHUNK_SIZE)
|
||||
if not chunk: break
|
||||
buffer += chunk.decode('utf-8', errors='replace')
|
||||
while buffer:
|
||||
buffer = buffer.lstrip()
|
||||
if not buffer: break
|
||||
@@ -765,17 +794,21 @@ class SXNGPlugin(Plugin):
|
||||
except json.JSONDecodeError: break
|
||||
except Exception as e:
|
||||
logger.error(f"{PLUGIN_NAME}: Gemini stream error: {e}")
|
||||
finally:
|
||||
if conn: conn.close()
|
||||
|
||||
def stream_openai_compatible():
|
||||
conn = None
|
||||
try:
|
||||
payload = {
|
||||
conn, path = _get_streaming_connection(self.endpoint_url)
|
||||
payload = json.dumps({
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"stream": True,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"stop": ["</answer>"]
|
||||
}
|
||||
})
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"HTTP-Referer": "https://github.com/searxng/searxng",
|
||||
@@ -785,17 +818,18 @@ class SXNGPlugin(Plugin):
|
||||
headers['api-key'] = self.api_key
|
||||
else:
|
||||
headers['Authorization'] = f"Bearer {self.api_key}"
|
||||
res, chunk_gen = network.stream('POST', self.endpoint_url, json=payload, headers=headers, timeout=60)
|
||||
conn.request("POST", path, body=payload, headers=headers)
|
||||
res = conn.getresponse()
|
||||
|
||||
if res.status_code != 200:
|
||||
for _ in chunk_gen: pass
|
||||
logger.error(f"{PLUGIN_NAME}: {self.provider} API {res.status_code}")
|
||||
if res.status != 200:
|
||||
logger.error(f"{PLUGIN_NAME}: {self.provider} API {res.status}")
|
||||
return
|
||||
|
||||
decoder = json.JSONDecoder()
|
||||
buffer = b""
|
||||
for chunk in chunk_gen:
|
||||
if not chunk: continue
|
||||
while True:
|
||||
chunk = res.read(STREAM_CHUNK_SIZE)
|
||||
if not chunk: break
|
||||
buffer += chunk
|
||||
while b"\n" in buffer:
|
||||
line_bytes, buffer = buffer.split(b"\n", 1)
|
||||
@@ -811,13 +845,16 @@ class SXNGPlugin(Plugin):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{PLUGIN_NAME}: {self.provider} stream error: {e}")
|
||||
finally:
|
||||
if conn: conn.close()
|
||||
|
||||
generator = stream_gemini if self.is_gemini else stream_openai_compatible
|
||||
return Response(generator(), mimetype='text/event-stream', headers={
|
||||
'X-Accel-Buffering': 'no',
|
||||
'Cache-Control': 'no-cache, no-store',
|
||||
'Connection': 'keep-alive',
|
||||
'Content-Encoding': 'identity'
|
||||
'Transfer-Encoding': 'chunked',
|
||||
'Content-Encoding': 'identity',
|
||||
})
|
||||
return True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user