Adding better AI response streaming logic

This commit is contained in:
Tyler
2026-05-17 15:11:01 -04:00
parent 59c46222b5
commit 332834a126
+279 -126
View File
@@ -1,4 +1,4 @@
import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures, threading
from urllib.parse import urlparse from urllib.parse import urlparse
from searx import network from searx import network
try: try:
@@ -14,6 +14,14 @@ from markupsafe import Markup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
import valkey as _valkey_mod
_VALKEY_AVAILABLE = True
except ImportError:
_VALKEY_AVAILABLE = False
_valkey_mod = None
logger.warning("AI Answers: valkey package not found. Streaming via Valkey unavailable.")
TOKEN_EXPIRY_SEC = 3600 TOKEN_EXPIRY_SEC = 3600
STREAM_CHUNK_SIZE = 512 STREAM_CHUNK_SIZE = 512
STREAM_TIMEOUT_SEC = 60 STREAM_TIMEOUT_SEC = 60
@@ -47,6 +55,134 @@ def _get_streaming_connection(url: str, verify_ssl: bool = True):
return conn, path return conn, path
# Lazily created connection pool shared by every Valkey client in this module.
_VALKEY_POOL = None


def _get_valkey_pool():
    """Return the module-wide Valkey connection pool, creating it on first use.

    The host is read from the ``VALKEY_HOST`` environment variable (default
    ``searxng-valkey``) and the port from ``VALKEY_PORT`` (default ``6379``).
    The pool targets database 0 and is created with ``decode_responses=True``.

    Raises:
        RuntimeError: if the optional ``valkey`` package could not be imported.
        ValueError: if ``VALKEY_PORT`` is set to something that is not an integer.
    """
    global _VALKEY_POOL
    if _VALKEY_POOL is None:
        # An explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would turn a missing package into an opaque
        # AttributeError on None. The message matches the one _get_valkey() uses.
        if _valkey_mod is None:
            raise RuntimeError("valkey package not installed")
        _VALKEY_POOL = _valkey_mod.ConnectionPool(
            host=os.getenv('VALKEY_HOST', 'searxng-valkey'),
            port=int(os.getenv('VALKEY_PORT', '6379')),
            db=0,
            decode_responses=True,
        )
    return _VALKEY_POOL
def _get_valkey():
    """Build a Valkey client backed by the shared connection pool.

    Raises:
        RuntimeError: if the ``valkey`` package is not available.
    """
    package_ready = _VALKEY_AVAILABLE and _valkey_mod is not None
    if package_ready:
        return _valkey_mod.Valkey(connection_pool=_get_valkey_pool())
    raise RuntimeError("valkey package not installed")
_THINK_OPEN_TAG = '<think>'
_THINK_CLOSE_TAG = '</think>'
# Lifetime (seconds) of the per-job chunk list and status key.
_JOB_KEY_TTL_SEC = 120
# Maximum number of requests issued while following redirects.
_MAX_STREAM_REQUESTS = 3


def _held_back_len(text: str, tag: str) -> int:
    """Return the length of the longest proper prefix of *tag* that *text* ends with."""
    for size in range(min(len(tag) - 1, len(text)), 0, -1):
        if text.endswith(tag[:size]):
            return size
    return 0


def _split_think_blocks(pending: str, in_think: bool):
    """Separate visible text from ``<think>...</think>`` content.

    Returns ``(visible, remainder, in_think)``: ``visible`` is the text that is
    safe to publish now, ``remainder`` is the tail that must wait for more
    input because it may be the beginning of a tag, and ``in_think`` is the
    updated "currently inside a think block" flag.
    """
    visible = []
    while True:
        if in_think:
            close_at = pending.find(_THINK_CLOSE_TAG)
            if close_at == -1:
                # Hidden text is never emitted; keep only a tail that could be
                # the start of the closing tag so the buffer stays bounded.
                keep = _held_back_len(pending, _THINK_CLOSE_TAG)
                pending = pending[len(pending) - keep:]
                break
            pending = pending[close_at + len(_THINK_CLOSE_TAG):]
            in_think = False
        else:
            open_at = pending.find(_THINK_OPEN_TAG)
            if open_at == -1:
                # A tag may be split across tokens ("<thi" + "nk>"): hold back
                # a trailing partial tag instead of publishing it.
                keep = _held_back_len(pending, _THINK_OPEN_TAG)
                cut = len(pending) - keep
                visible.append(pending[:cut])
                pending = pending[cut:]
                break
            visible.append(pending[:open_at])
            pending = pending[open_at + len(_THINK_OPEN_TAG):]
            in_think = True
    return ''.join(visible), pending, in_think


def _push_chunk(vk, chunks_key: str, text: str) -> None:
    """Append *text* to the job's chunk list and refresh the list's expiry."""
    vk.rpush(chunks_key, text)
    vk.expire(chunks_key, _JOB_KEY_TTL_SEC)


def stream_to_valkey(job_id: str, payload: str, headers: dict, endpoint_url: str, model: str):
    """Stream a chat completion from the endpoint into a Valkey list.

    POSTs *payload* to *endpoint_url* (following up to three redirects), reads
    the response line by line as ``data: {json}`` events, and appends every
    ``choices[0].delta.content`` token to ``ai:job:{job_id}:chunks`` with
    ``<think>...</think>`` sections removed. On success the sentinel
    ``__DONE__`` is appended and ``ai:job:{job_id}:status`` is set to ``done``.
    On any failure ``__ERROR__<message>`` is appended and the status is set to
    ``error``. All keys expire after 120 seconds.

    Args:
        job_id: Identifier used to build the Valkey key names.
        payload: JSON request body, sent UTF-8 encoded.
        headers: HTTP headers for the request.
        endpoint_url: URL of the completion endpoint.
        model: Model name; accepted for the caller's benefit but not read here.

    Returns:
        None. Exceptions are logged and reported through Valkey, never raised.
    """
    # Function-scope import so this block does not depend on the file's
    # top-level import list.
    from urllib.parse import urljoin

    chunks_key = f"ai:job:{job_id}:chunks"
    status_key = f"ai:job:{job_id}:status"
    conn = None
    try:
        vk = _get_valkey()
        url = endpoint_url
        res = None
        for _ in range(_MAX_STREAM_REQUESTS):
            conn, path = _get_streaming_connection(url)
            conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
            res = conn.getresponse()
            if res.status not in (301, 302, 307, 308):
                break
            location = res.getheader('Location', '')
            res.read()
            conn.close()
            conn = None
            if not location:
                raise RuntimeError(f"Redirect {res.status} with no Location")
            # urljoin resolves absolute, host-relative and path-relative
            # Location values against the URL that issued the redirect.
            url = urljoin(url, location)
        else:
            raise RuntimeError("Too many redirects to Ollama endpoint")

        if res.status != 200:
            body = res.read(1024).decode('utf-8', errors='replace')
            raise RuntimeError(f"Ollama error {res.status}: {body[:200]}")

        in_think = False
        pending = ''
        while True:
            raw_line = res.readline()
            if not raw_line:
                break
            line = raw_line.decode('utf-8', errors='replace').rstrip('\r\n')
            if not line.startswith('data: '):
                continue
            data_str = line[6:]
            if data_str == '[DONE]':
                break
            try:
                obj = json.loads(data_str)
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                continue
            if not isinstance(obj, dict):
                continue
            choices = obj.get('choices') or []
            if not choices:
                continue
            # Tolerate an explicit null delta rather than aborting the stream.
            delta = choices[0].get('delta') or {}
            token = delta.get('content') or ''
            if not token:
                continue

            visible, pending, in_think = _split_think_blocks(pending + token, in_think)
            if visible:
                _push_chunk(vk, chunks_key, visible)

        # Text held back as a possible tag start turned out to be plain text.
        if not in_think and pending:
            _push_chunk(vk, chunks_key, pending)

        _push_chunk(vk, chunks_key, '__DONE__')
        vk.set(status_key, 'done', ex=_JOB_KEY_TTL_SEC)
    except Exception as e:
        logger.error(f"AI Answers: stream_to_valkey error for job {job_id}: {e}", exc_info=True)
        try:
            # A fresh client: the failure may have happened before ``vk`` was bound.
            vk2 = _get_valkey()
            _push_chunk(vk2, chunks_key, f"__ERROR__{e}")
            vk2.set(status_key, 'error', ex=_JOB_KEY_TTL_SEC)
        except Exception:
            # Best effort only, but leave a trace so a job that ends without
            # an error marker can be explained from the logs.
            logger.warning(
                "AI Answers: could not record error state for job %s",
                job_id,
                exc_info=True,
            )
    finally:
        if conn:
            try:
                conn.close()
            except Exception:
                # Closing is best effort; the connection is discarded either way.
                pass
PLUGIN_NAME = "AI Answers" PLUGIN_NAME = "AI Answers"
DEFAULT_TABS = "general,science,it,news" DEFAULT_TABS = "general,science,it,news"
@@ -577,9 +713,9 @@ FRONTEND_JS_TEMPLATE = r"""
box.style.display = 'block'; box.style.display = 'block';
const controller = new AbortController(); const controller = new AbortController();
let timeoutId = setTimeout(() => controller.abort(), 60000); let timeoutId = setTimeout(() => controller.abort(), 90000);
const finalQ = __STREAM_Q__; const finalQ = __STREAM_Q__;
const _selMdl = (document.getElementById('sxng-model-select') || {value: ''}).value; const _selMdl = (document.getElementById('sxng-model-select') || {value: ''}).value;
const bodyObj = { q: finalQ, lang: lang_init, context: ctx, tk: tk_init, model: _selMdl__STREAM_BODY__ }; const bodyObj = { q: finalQ, lang: lang_init, context: ctx, tk: tk_init, model: _selMdl__STREAM_BODY__ };
const res = await fetch(script_root + '/ai-stream', { const res = await fetch(script_root + '/ai-stream', {
@@ -610,34 +746,17 @@ FRONTEND_JS_TEMPLATE = r"""
return; return;
} }
const fullText = (respJson.text || '').trim(); const jobId = respJson.job_id;
if (!jobId) {
if (!fullText) {
const cursorErr = data.querySelector('.sxng-cursor'); const cursorErr = data.querySelector('.sxng-cursor');
if (cursorErr) cursorErr.remove(); if (cursorErr) cursorErr.remove();
const errSpan = document.createElement('span'); const errSpan = document.createElement('span');
errSpan.style.color = '#bf616a'; errSpan.style.color = '#bf616a';
errSpan.textContent = 'No response received. Check API configuration and server logs.'; errSpan.textContent = 'No job ID returned. Check server logs.';
data.appendChild(errSpan); data.appendChild(errSpan);
return; return;
} }
let mainText = fullText;
const thinkMatch = mainText.match(/^<think>([\s\S]*?)<\/think>\s*/);
if (thinkMatch) {
const cursorTh = data.querySelector('.sxng-cursor');
const details = document.createElement('details');
details.className = 'sxng-reasoning';
details.innerHTML = '<summary>Thought Process</summary>';
const thoughtDiv = document.createElement('div');
thoughtDiv.className = 'sxng-thought-content';
thoughtDiv.textContent = thinkMatch[1];
details.appendChild(thoughtDiv);
if (cursorTh) cursorTh.before(details);
else data.appendChild(details);
mainText = mainText.substring(thinkMatch[0].length);
}
let cursor = data.querySelector('.sxng-cursor'); let cursor = data.querySelector('.sxng-cursor');
if (!cursor) { if (!cursor) {
cursor = document.createElement('span'); cursor = document.createElement('span');
@@ -646,9 +765,10 @@ FRONTEND_JS_TEMPLATE = r"""
} }
let buffer = ''; let buffer = '';
let fullText = '';
const flushBuffer = (force = false) => { const flushBuffer = (force = false) => {
if (!buffer) return; if (!buffer) return;
if (force) { if (force) {
const fragment = renderCitations(buffer, urls); const fragment = renderCitations(buffer, urls);
if (cursor) cursor.before(fragment); if (cursor) cursor.before(fragment);
@@ -659,9 +779,9 @@ FRONTEND_JS_TEMPLATE = r"""
while (true) { while (true) {
const match = buffer.match(/(\[\d+(?:,\s*\d+)*\])/); const match = buffer.match(/(\[\d+(?:,\s*\d+)*\])/);
if (!match) break; if (!match) break;
const preText = buffer.substring(0, match.index); const preText = buffer.substring(0, match.index);
if (preText) { if (preText) {
const s = document.createElement('span'); const s = document.createElement('span');
@@ -695,7 +815,7 @@ FRONTEND_JS_TEMPLATE = r"""
cursor.before(s); cursor.before(s);
} }
buffer = buffer.substring(openIdx); buffer = buffer.substring(openIdx);
if (buffer.length > 50) { if (buffer.length > 50) {
const s = document.createElement('span'); const s = document.createElement('span');
s.className = 'sxng-chunk'; s.className = 'sxng-chunk';
@@ -706,23 +826,61 @@ FRONTEND_JS_TEMPLATE = r"""
} }
}; };
let twPos = 0; let offset = 0;
const twBatch = 4; const maxPolls = 600;
await new Promise(resolve => { let polls = 0;
function twTick() {
if (twPos >= mainText.length) { while (polls < maxPolls) {
flushBuffer(true); polls++;
resolve(); await new Promise(r => setTimeout(r, 150));
return;
} let statusRes;
const end = Math.min(twPos + twBatch, mainText.length); try {
buffer += mainText.substring(twPos, end); statusRes = await fetch(
twPos = end; `${script_root}/ai-status/${jobId}?tk=${encodeURIComponent(tk_init)}&offset=${offset}`,
flushBuffer(false); { signal: controller.signal }
setTimeout(twTick, 8); );
} catch (fetchErr) {
if (fetchErr.name === 'AbortError') throw fetchErr;
continue;
} }
twTick();
}); if (statusRes.status === 404) {
const cursorE = data.querySelector('.sxng-cursor');
if (cursorE) cursorE.remove();
const expiredSpan = document.createElement('span');
expiredSpan.style.color = '#bf616a';
expiredSpan.textContent = 'Response expired. Please search again.';
data.appendChild(expiredSpan);
return;
}
if (!statusRes.ok) continue;
const statusData = await statusRes.json();
if (statusData.error) {
const cursorE = data.querySelector('.sxng-cursor');
if (cursorE) cursorE.remove();
const errSpan2 = document.createElement('span');
errSpan2.style.color = '#bf616a';
errSpan2.textContent = '⚠️ ' + statusData.error;
data.appendChild(errSpan2);
return;
}
for (const chunk of (statusData.chunks || [])) {
fullText += chunk;
buffer += chunk;
flushBuffer(false);
}
offset += (statusData.chunks || []).length;
if (statusData.done) {
flushBuffer(true);
break;
}
}
if (cursor) cursor.remove(); if (cursor) cursor.remove();
@@ -738,17 +896,16 @@ FRONTEND_JS_TEMPLATE = r"""
} }
} }
renderCitationFooter(mainText, urls, data); renderCitationFooter(fullText, urls, data);
const collectedResponse = mainText; const collectedResponse = fullText;
__INTERACTIVE_JS_COMPLETE__ __INTERACTIVE_JS_COMPLETE__
if (collectedResponse) { if (collectedResponse) {
conversation.turns.push({role: 'assistant', content: collectedResponse.trim(), ts: Date.now()}); conversation.turns.push({role: 'assistant', content: collectedResponse.trim(), ts: Date.now()});
} }
// Save state if this was an initial generation or a regeneration
if (arguments.length === 0 && typeof updateState === 'function') { if (arguments.length === 0 && typeof updateState === 'function') {
updateState(); updateState();
} }
@@ -1094,86 +1251,82 @@ class SXNGPlugin(Plugin):
{numbered_instructions} {numbered_instructions}
</CORE_DIRECTIVES>""" </CORE_DIRECTIVES>"""
def call_ollama(): job_id = hashlib.sha256(f"{time.time()}{q}".encode()).hexdigest()[:16]
conn = None
try: payload_dict = {
payload_dict = { "model": effective_model,
"model": effective_model, "messages": [
"messages": [ {"role": "system", "content": SYSTEM},
{"role": "system", "content": SYSTEM}, {"role": "user", "content": prompt},
{"role": "user", "content": prompt}, {"role": "assistant", "content": ""},
{"role": "assistant", "content": ""}, ],
], "stream": True,
"stream": False, "max_tokens": self.max_tokens,
"max_tokens": self.max_tokens, "temperature": self.temperature,
"temperature": self.temperature, }
} stream_payload = json.dumps(payload_dict)
payload = json.dumps(payload_dict) stream_headers = {
headers = { "Content-Type": "application/json",
"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}",
"Authorization": f"Bearer {self.api_key}", }
}
url = self.endpoint_url try:
res = None # type: ignore[assignment] vk = _get_valkey()
for _ in range(3): vk.set(f"ai:job:{job_id}:status", "running", ex=120)
conn, path = _get_streaming_connection(url) except Exception as e:
conn.request("POST", path, body=payload.encode('utf-8'), headers=headers) logger.error(f"{PLUGIN_NAME}: Valkey unavailable: {e}", exc_info=True)
res = conn.getresponse() return jsonify({"error": "Streaming service unavailable (Valkey connection failed)."}), 503
if res.status in (301, 302, 307, 308):
location = res.getheader('Location', '') t = threading.Thread(
res.read() target=stream_to_valkey,
conn.close() args=(job_id, stream_payload, stream_headers, self.endpoint_url, effective_model),
conn = None daemon=True,
if not location: )
return '', f"Redirect {res.status} with no Location header" t.start()
url = location if location.startswith('http') else f"{urlparse(url).scheme}://{urlparse(url).netloc}{location}"
logger.info(f"{PLUGIN_NAME}: Following redirect to {url}") return jsonify({"job_id": job_id})
continue
break @app.route('/ai-status/<job_id>', methods=['GET'])
if res.status != 200: def ai_status(job_id):
body = res.read(1024).decode('utf-8', errors='replace') token = request.args.get('tk', '')
logger.error(f"{PLUGIN_NAME}: Ollama {res.status}: {body}") try:
return '', f"Ollama error {res.status}" ts, sig = token.rsplit('.', 1)
obj = json.loads(res.read().decode('utf-8', errors='replace')) expected = hashlib.sha256(f"{ts}{self.secret}".encode()).hexdigest()
if 'error' in obj: if sig != expected or (time.time() - float(ts)) > TOKEN_EXPIRY_SEC:
err = obj['error'] abort(403)
msg = err.get('message', str(err)) if isinstance(err, dict) else str(err) except (ValueError, KeyError, AttributeError):
return '', msg abort(403)
choices = obj.get('choices', [])
if not choices: offset = max(0, int(request.args.get('offset', 0)))
return '', "No choices in Ollama response." chunks_key = f"ai:job:{job_id}:chunks"
message = choices[0].get('message', {}) status_key = f"ai:job:{job_id}:status"
content = message.get('content') or ''
reasoning = message.get('reasoning') or message.get('reasoning_content') or '' try:
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip() vk = _get_valkey()
if not content and reasoning: status = vk.get(status_key)
logger.warning(f"{PLUGIN_NAME}: content empty, extracting from reasoning field") if status is None:
lines = reasoning.splitlines() return jsonify({"error": "Job not found or expired"}), 404
header_re = re.compile(r'^\s*\*?\*?[A-Z][^:]{0,40}:\*?\*?\s*$') raw_chunks = vk.lrange(chunks_key, offset, -1)
last_header_idx = -1 except Exception as e:
for i, line in enumerate(lines): logger.error(f"{PLUGIN_NAME}: Valkey error in /ai-status: {e}", exc_info=True)
if header_re.match(line): return jsonify({"error": "Stream service temporarily unavailable"}), 503
last_header_idx = i
if last_header_idx >= 0 and last_header_idx < len(lines) - 1: done = False
content = '\n'.join(lines[last_header_idx + 1:]).strip() error = None
if not content: chunks = []
paragraphs = [p.strip() for p in reasoning.split('\n\n') if p.strip()] for chunk in raw_chunks:
content = '\n\n'.join(paragraphs[-2:]) if len(paragraphs) >= 2 else paragraphs[-1] if paragraphs else '' if chunk == '__DONE__':
if reasoning and content: done = True
full = (f"<think>\n{reasoning}\n</think>\n\n" if reasoning else "") + content break
else: elif chunk.startswith('__ERROR__'):
full = content error = chunk[9:]
full = re.sub(r'<think>.*?</think>', '', full, flags=re.DOTALL).strip() done = True
return full, None break
except Exception as e: else:
logger.error(f"{PLUGIN_NAME}: Ollama call error: {e}", exc_info=True) chunks.append(chunk)
return '', f"Connection Error: {e}"
finally: return jsonify({"chunks": chunks, "done": done, "error": error})
if conn:
conn.close()
text, error = call_ollama()
return jsonify({"text": text, "error": error})
return True return True
def _fetch_page_text(self, url: str, timeout: int = 5) -> str: def _fetch_page_text(self, url: str, timeout: int = 5) -> str: