From 332834a1267688a1cef286c14d70a0093e28d052 Mon Sep 17 00:00:00 2001
From: Tyler <68524461+TySP-Dev@users.noreply.github.com>
Date: Sun, 17 May 2026 15:11:01 -0400
Subject: [PATCH] Adding better AI response streaming logic
---
ollama_answers.py | 405 +++++++++++++++++++++++++++++++---------------
1 file changed, 279 insertions(+), 126 deletions(-)
diff --git a/ollama_answers.py b/ollama_answers.py
index 9467e54..bf0f57a 100644
--- a/ollama_answers.py
+++ b/ollama_answers.py
@@ -1,4 +1,4 @@
-import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures
+import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures, threading
from urllib.parse import urlparse
from searx import network
try:
@@ -14,6 +14,14 @@ from markupsafe import Markup
logger = logging.getLogger(__name__)
+# Valkey is an optional dependency. Record whether it imported so the rest of
+# the module can fail with a clear error instead of a NameError.
+try:
+    import valkey as _valkey_mod
+except ImportError:
+    _valkey_mod = None
+    _VALKEY_AVAILABLE = False
+    logger.warning("AI Answers: valkey package not found. Streaming via Valkey unavailable.")
+else:
+    _VALKEY_AVAILABLE = True
+
TOKEN_EXPIRY_SEC = 3600
STREAM_CHUNK_SIZE = 512
STREAM_TIMEOUT_SEC = 60
@@ -47,6 +55,134 @@ def _get_streaming_connection(url: str, verify_ssl: bool = True):
return conn, path
+# Lazily created connection pool shared by every Valkey client in this module.
+_VALKEY_POOL = None
+
+
+def _get_valkey_pool():
+    """Return the module-wide Valkey connection pool, creating it on first use.
+
+    Host and port are read from the VALKEY_HOST / VALKEY_PORT environment
+    variables (defaults: 'searxng-valkey' and 6379). Responses are decoded to
+    str so callers can compare values against plain string sentinels.
+
+    Raises:
+        RuntimeError: if the valkey package could not be imported.
+        ValueError: if VALKEY_PORT is set to a non-integer value.
+    """
+    global _VALKEY_POOL
+    if _VALKEY_POOL is None:
+        # Explicit check instead of `assert`: asserts vanish under `python -O`,
+        # which would turn this into an opaque AttributeError on None.
+        if _valkey_mod is None:
+            raise RuntimeError("valkey package not installed")
+        _VALKEY_POOL = _valkey_mod.ConnectionPool(
+            host=os.getenv('VALKEY_HOST', 'searxng-valkey'),
+            port=int(os.getenv('VALKEY_PORT', '6379')),
+            db=0,
+            decode_responses=True,
+        )
+    return _VALKEY_POOL
+
+
+def _get_valkey():
+    """Return a Valkey client backed by the shared connection pool.
+
+    Raises:
+        RuntimeError: if the valkey package is not importable.
+    """
+    valkey_ready = _VALKEY_AVAILABLE and _valkey_mod is not None
+    if valkey_ready:
+        return _valkey_mod.Valkey(connection_pool=_get_valkey_pool())
+    raise RuntimeError("valkey package not installed")
+
+
+# Lifetime (seconds) of the per-job keys written by stream_to_valkey.
+_JOB_TTL_SEC = 120
+
+# Reasoning-block delimiters emitted by some models.
+# NOTE(review): spelled in two pieces on purpose. The literals had been
+# stripped from this patch (leaving find('')), presumably by tooling that
+# removes angle-bracket tags. The 7 / 8 character slice offsets in the
+# original code match these two tags; confirm against the real file.
+_THINK_OPEN = '<' + 'think>'
+_THINK_CLOSE = '<' + '/think>'
+
+
+def _held_back_len(text):
+    """Return how many trailing chars of *text* may be the start of an opening think tag."""
+    for size in range(min(len(text), len(_THINK_OPEN) - 1), 0, -1):
+        if text.endswith(_THINK_OPEN[:size]):
+            return size
+    return 0
+
+
+def _filter_think(pending, in_think):
+    """Strip think-tag blocks from *pending*.
+
+    Returns a tuple (visible, leftover, in_think): the text that is safe to
+    publish now, the text that must wait for more input, and whether the
+    parser is currently inside a think block.
+    """
+    visible = []
+    while True:
+        if in_think:
+            close_at = pending.find(_THINK_CLOSE)
+            if close_at == -1:
+                # Reasoning text is discarded anyway; keep only the tail that
+                # could still be the beginning of the closing tag.
+                pending = pending[-(len(_THINK_CLOSE) - 1):]
+                break
+            pending = pending[close_at + len(_THINK_CLOSE):]
+            in_think = False
+            continue
+
+        open_at = pending.find(_THINK_OPEN)
+        if open_at != -1:
+            visible.append(pending[:open_at])
+            pending = pending[open_at + len(_THINK_OPEN):]
+            in_think = True
+            continue
+
+        # No complete tag. Hold back a trailing partial tag so that a tag
+        # split across two streamed tokens is not leaked to the client.
+        keep = _held_back_len(pending)
+        cut = len(pending) - keep
+        visible.append(pending[:cut])
+        pending = pending[cut:]
+        break
+    return ''.join(visible), pending, in_think
+
+
+def stream_to_valkey(job_id: str, payload: str, headers: dict, endpoint_url: str, model: str):
+    """POST *payload* to the chat endpoint and relay the streamed answer into Valkey.
+
+    Visible text is appended to the list ``ai:job:<job_id>:chunks``. The list
+    is terminated with ``__DONE__`` on success or ``__ERROR__<message>`` on
+    failure, and ``ai:job:<job_id>:status`` is set to ``done`` / ``error``.
+    All keys expire after _JOB_TTL_SEC seconds. Content inside think-tag
+    blocks is dropped.
+
+    Args:
+        job_id: identifier used to build the Valkey key names.
+        payload: JSON request body, sent UTF-8 encoded.
+        headers: HTTP headers for the request.
+        endpoint_url: URL to POST to; up to three redirects are followed.
+        model: unused here; kept so existing callers keep working.
+
+    Never raises: every failure is logged and reported through Valkey.
+    """
+    chunks_key = f"ai:job:{job_id}:chunks"
+    status_key = f"ai:job:{job_id}:status"
+    conn = None
+    try:
+        vk = _get_valkey()
+
+        def push(text):
+            # Refresh the TTL on every write so an active stream never expires.
+            vk.rpush(chunks_key, text)
+            vk.expire(chunks_key, _JOB_TTL_SEC)
+
+        url = endpoint_url
+        res = None
+        for _ in range(3):
+            conn, path = _get_streaming_connection(url)
+            conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
+            res = conn.getresponse()
+            if res.status not in (301, 302, 307, 308):
+                break
+            location = res.getheader('Location', '')
+            res.read()
+            conn.close()
+            conn = None
+            if not location:
+                raise RuntimeError(f"Redirect {res.status} with no Location")
+            if location.startswith('http'):
+                url = location
+            else:
+                base = urlparse(url)
+                url = f"{base.scheme}://{base.netloc}{location}"
+        else:
+            # Loop ended without `break`: every attempt was a redirect.
+            raise RuntimeError("Too many redirects to Ollama endpoint")
+
+        if res.status != 200:
+            body = res.read(1024).decode('utf-8', errors='replace')
+            raise RuntimeError(f"Ollama error {res.status}: {body[:200]}")
+
+        in_think = False
+        pending = ''
+
+        # The response is a server-sent-event stream: one "data: <json>" per line.
+        for raw_line in iter(res.readline, b''):
+            line = raw_line.decode('utf-8', errors='replace').rstrip('\r\n')
+            if not line.startswith('data: '):
+                continue
+            data_str = line[6:]
+            if data_str == '[DONE]':
+                break
+            try:
+                obj = json.loads(data_str)
+            except ValueError:  # includes json.JSONDecodeError
+                continue
+            if not isinstance(obj, dict):
+                continue
+            choices = obj.get('choices') or []
+            if not choices:
+                continue
+            delta = choices[0].get('delta') or {}
+            token = delta.get('content') or ''
+            if not token:
+                continue
+
+            pending += token
+            visible, pending, in_think = _filter_think(pending, in_think)
+            if visible:
+                push(visible)
+
+        # Flush text that was held back as a possible partial tag.
+        if not in_think and pending:
+            push(pending)
+
+        push('__DONE__')
+        vk.set(status_key, 'done', ex=_JOB_TTL_SEC)
+
+    except Exception as e:
+        logger.error(f"AI Answers: stream_to_valkey error for job {job_id}: {e}", exc_info=True)
+        try:
+            vk2 = _get_valkey()
+            vk2.rpush(chunks_key, f"__ERROR__{e}")
+            vk2.expire(chunks_key, _JOB_TTL_SEC)
+            vk2.set(status_key, 'error', ex=_JOB_TTL_SEC)
+        except Exception as report_err:
+            # Best effort only: the store itself may be what failed.
+            logger.warning(f"AI Answers: could not record error state for job {job_id}: {report_err}")
+    finally:
+        if conn:
+            try:
+                conn.close()
+            except Exception:
+                # Closing is best effort; the job outcome is already recorded.
+                pass
+
+
PLUGIN_NAME = "AI Answers"
DEFAULT_TABS = "general,science,it,news"
@@ -577,9 +713,9 @@ FRONTEND_JS_TEMPLATE = r"""
box.style.display = 'block';
const controller = new AbortController();
- let timeoutId = setTimeout(() => controller.abort(), 60000);
+ let timeoutId = setTimeout(() => controller.abort(), 90000);
const finalQ = __STREAM_Q__;
-
+
const _selMdl = (document.getElementById('sxng-model-select') || {value: ''}).value;
const bodyObj = { q: finalQ, lang: lang_init, context: ctx, tk: tk_init, model: _selMdl__STREAM_BODY__ };
const res = await fetch(script_root + '/ai-stream', {
@@ -610,34 +746,17 @@ FRONTEND_JS_TEMPLATE = r"""
return;
}
- const fullText = (respJson.text || '').trim();
-
- if (!fullText) {
+ const jobId = respJson.job_id;
+ if (!jobId) {
const cursorErr = data.querySelector('.sxng-cursor');
if (cursorErr) cursorErr.remove();
const errSpan = document.createElement('span');
errSpan.style.color = '#bf616a';
- errSpan.textContent = 'No response received. Check API configuration and server logs.';
+ errSpan.textContent = 'No job ID returned. Check server logs.';
data.appendChild(errSpan);
return;
}
- let mainText = fullText;
- const thinkMatch = mainText.match(/^([\s\S]*?)<\/think>\s*/);
- if (thinkMatch) {
- const cursorTh = data.querySelector('.sxng-cursor');
- const details = document.createElement('details');
- details.className = 'sxng-reasoning';
- details.innerHTML = 'Thought Process';
- const thoughtDiv = document.createElement('div');
- thoughtDiv.className = 'sxng-thought-content';
- thoughtDiv.textContent = thinkMatch[1];
- details.appendChild(thoughtDiv);
- if (cursorTh) cursorTh.before(details);
- else data.appendChild(details);
- mainText = mainText.substring(thinkMatch[0].length);
- }
-
let cursor = data.querySelector('.sxng-cursor');
if (!cursor) {
cursor = document.createElement('span');
@@ -646,9 +765,10 @@ FRONTEND_JS_TEMPLATE = r"""
}
let buffer = '';
+ let fullText = '';
const flushBuffer = (force = false) => {
if (!buffer) return;
-
+
if (force) {
const fragment = renderCitations(buffer, urls);
if (cursor) cursor.before(fragment);
@@ -659,9 +779,9 @@ FRONTEND_JS_TEMPLATE = r"""
while (true) {
const match = buffer.match(/(\[\d+(?:,\s*\d+)*\])/);
-
+
if (!match) break;
-
+
const preText = buffer.substring(0, match.index);
if (preText) {
const s = document.createElement('span');
@@ -695,7 +815,7 @@ FRONTEND_JS_TEMPLATE = r"""
cursor.before(s);
}
buffer = buffer.substring(openIdx);
-
+
if (buffer.length > 50) {
const s = document.createElement('span');
s.className = 'sxng-chunk';
@@ -706,23 +826,61 @@ FRONTEND_JS_TEMPLATE = r"""
}
};
- let twPos = 0;
- const twBatch = 4;
- await new Promise(resolve => {
- function twTick() {
- if (twPos >= mainText.length) {
- flushBuffer(true);
- resolve();
- return;
- }
- const end = Math.min(twPos + twBatch, mainText.length);
- buffer += mainText.substring(twPos, end);
- twPos = end;
- flushBuffer(false);
- setTimeout(twTick, 8);
+ let offset = 0;
+ const maxPolls = 600;
+ let polls = 0;
+
+ while (polls < maxPolls) {
+ polls++;
+ await new Promise(r => setTimeout(r, 150));
+
+ let statusRes;
+ try {
+ statusRes = await fetch(
+ `${script_root}/ai-status/${jobId}?tk=${encodeURIComponent(tk_init)}&offset=${offset}`,
+ { signal: controller.signal }
+ );
+ } catch (fetchErr) {
+ if (fetchErr.name === 'AbortError') throw fetchErr;
+ continue;
}
- twTick();
- });
+
+ if (statusRes.status === 404) {
+ const cursorE = data.querySelector('.sxng-cursor');
+ if (cursorE) cursorE.remove();
+ const expiredSpan = document.createElement('span');
+ expiredSpan.style.color = '#bf616a';
+ expiredSpan.textContent = 'Response expired. Please search again.';
+ data.appendChild(expiredSpan);
+ return;
+ }
+
+ if (!statusRes.ok) continue;
+
+ const statusData = await statusRes.json();
+
+ if (statusData.error) {
+ const cursorE = data.querySelector('.sxng-cursor');
+ if (cursorE) cursorE.remove();
+ const errSpan2 = document.createElement('span');
+ errSpan2.style.color = '#bf616a';
+ errSpan2.textContent = '⚠️ ' + statusData.error;
+ data.appendChild(errSpan2);
+ return;
+ }
+
+ for (const chunk of (statusData.chunks || [])) {
+ fullText += chunk;
+ buffer += chunk;
+ flushBuffer(false);
+ }
+ offset += (statusData.chunks || []).length;
+
+ if (statusData.done) {
+ flushBuffer(true);
+ break;
+ }
+ }
if (cursor) cursor.remove();
@@ -738,17 +896,16 @@ FRONTEND_JS_TEMPLATE = r"""
}
}
- renderCitationFooter(mainText, urls, data);
+ renderCitationFooter(fullText, urls, data);
- const collectedResponse = mainText;
+ const collectedResponse = fullText;
__INTERACTIVE_JS_COMPLETE__
if (collectedResponse) {
conversation.turns.push({role: 'assistant', content: collectedResponse.trim(), ts: Date.now()});
}
-
- // Save state if this was an initial generation or a regeneration
+
if (arguments.length === 0 && typeof updateState === 'function') {
updateState();
}
@@ -1094,86 +1251,82 @@ class SXNGPlugin(Plugin):
{numbered_instructions}
"""
- def call_ollama():
- conn = None
- try:
- payload_dict = {
- "model": effective_model,
- "messages": [
- {"role": "system", "content": SYSTEM},
- {"role": "user", "content": prompt},
- {"role": "assistant", "content": ""},
- ],
- "stream": False,
- "max_tokens": self.max_tokens,
- "temperature": self.temperature,
- }
- payload = json.dumps(payload_dict)
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {self.api_key}",
- }
- url = self.endpoint_url
- res = None # type: ignore[assignment]
- for _ in range(3):
- conn, path = _get_streaming_connection(url)
- conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
- res = conn.getresponse()
- if res.status in (301, 302, 307, 308):
- location = res.getheader('Location', '')
- res.read()
- conn.close()
- conn = None
- if not location:
- return '', f"Redirect {res.status} with no Location header"
- url = location if location.startswith('http') else f"{urlparse(url).scheme}://{urlparse(url).netloc}{location}"
- logger.info(f"{PLUGIN_NAME}: Following redirect to {url}")
- continue
- break
- if res.status != 200:
- body = res.read(1024).decode('utf-8', errors='replace')
- logger.error(f"{PLUGIN_NAME}: Ollama {res.status}: {body}")
- return '', f"Ollama error {res.status}"
- obj = json.loads(res.read().decode('utf-8', errors='replace'))
- if 'error' in obj:
- err = obj['error']
- msg = err.get('message', str(err)) if isinstance(err, dict) else str(err)
- return '', msg
- choices = obj.get('choices', [])
- if not choices:
- return '', "No choices in Ollama response."
- message = choices[0].get('message', {})
- content = message.get('content') or ''
- reasoning = message.get('reasoning') or message.get('reasoning_content') or ''
- content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip()
- if not content and reasoning:
- logger.warning(f"{PLUGIN_NAME}: content empty, extracting from reasoning field")
- lines = reasoning.splitlines()
- header_re = re.compile(r'^\s*\*?\*?[A-Z][^:]{0,40}:\*?\*?\s*$')
- last_header_idx = -1
- for i, line in enumerate(lines):
- if header_re.match(line):
- last_header_idx = i
- if last_header_idx >= 0 and last_header_idx < len(lines) - 1:
- content = '\n'.join(lines[last_header_idx + 1:]).strip()
- if not content:
- paragraphs = [p.strip() for p in reasoning.split('\n\n') if p.strip()]
- content = '\n\n'.join(paragraphs[-2:]) if len(paragraphs) >= 2 else paragraphs[-1] if paragraphs else ''
- if reasoning and content:
- full = (f"\n{reasoning}\n\n\n" if reasoning else "") + content
- else:
- full = content
- full = re.sub(r'.*?', '', full, flags=re.DOTALL).strip()
- return full, None
- except Exception as e:
- logger.error(f"{PLUGIN_NAME}: Ollama call error: {e}", exc_info=True)
- return '', f"Connection Error: {e}"
- finally:
- if conn:
- conn.close()
+ job_id = hashlib.sha256(f"{time.time()}{q}".encode()).hexdigest()[:16]
+
+ payload_dict = {
+ "model": effective_model,
+ "messages": [
+ {"role": "system", "content": SYSTEM},
+ {"role": "user", "content": prompt},
+ {"role": "assistant", "content": ""},
+ ],
+ "stream": True,
+ "max_tokens": self.max_tokens,
+ "temperature": self.temperature,
+ }
+ stream_payload = json.dumps(payload_dict)
+ stream_headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {self.api_key}",
+ }
+
+ try:
+ vk = _get_valkey()
+ vk.set(f"ai:job:{job_id}:status", "running", ex=120)
+ except Exception as e:
+ logger.error(f"{PLUGIN_NAME}: Valkey unavailable: {e}", exc_info=True)
+ return jsonify({"error": "Streaming service unavailable (Valkey connection failed)."}), 503
+
+ t = threading.Thread(
+ target=stream_to_valkey,
+ args=(job_id, stream_payload, stream_headers, self.endpoint_url, effective_model),
+ daemon=True,
+ )
+ t.start()
+
+ return jsonify({"job_id": job_id})
+
+ @app.route('/ai-status/<job_id>', methods=['GET'])
+ def ai_status(job_id):
+     """Return the chunks of a streaming job starting at a given offset.
+
+     Query parameters:
+         tk: signed token of the form "<timestamp>.<sha256 hex>".
+         offset: index of the first chunk to return (default 0).
+
+     Responds with JSON {"chunks": [...], "done": bool, "error": str|None}.
+     Status codes: 400 for a malformed offset, 403 for a bad or expired
+     token, 404 when the job status key is missing, 503 when Valkey fails.
+     """
+     import hmac  # local import: only this handler needs it
+
+     token = request.args.get('tk', '')
+     try:
+         ts, sig = token.rsplit('.', 1)
+         expected = hashlib.sha256(f"{ts}{self.secret}".encode()).hexdigest()
+         # Constant-time comparison. Compare bytes, because compare_digest
+         # rejects non-ASCII str input and `sig` is client supplied.
+         sig_ok = hmac.compare_digest(sig.encode('utf-8'), expected.encode('utf-8'))
+         if not sig_ok or (time.time() - float(ts)) > TOKEN_EXPIRY_SEC:
+             abort(403)
+     except (ValueError, KeyError, AttributeError):
+         abort(403)
+
+     try:
+         offset = max(0, int(request.args.get('offset', 0)))
+     except (TypeError, ValueError):
+         # A malformed offset is a client error, not a server failure.
+         return jsonify({"error": "Invalid offset"}), 400
+
+     chunks_key = f"ai:job:{job_id}:chunks"
+     status_key = f"ai:job:{job_id}:status"
+
+     try:
+         vk = _get_valkey()
+         status = vk.get(status_key)
+         if status is None:
+             return jsonify({"error": "Job not found or expired"}), 404
+         raw_chunks = vk.lrange(chunks_key, offset, -1)
+     except Exception as e:
+         logger.error(f"{PLUGIN_NAME}: Valkey error in /ai-status: {e}", exc_info=True)
+         return jsonify({"error": "Stream service temporarily unavailable"}), 503
+
+     done = False
+     error = None
+     chunks = []
+     for chunk in raw_chunks:
+         # The sentinel entries mark the end of the stream; nothing after
+         # them is returned to the client.
+         if chunk == '__DONE__':
+             done = True
+             break
+         elif chunk.startswith('__ERROR__'):
+             error = chunk[9:]
+             done = True
+             break
+         else:
+             chunks.append(chunk)
+
+     return jsonify({"chunks": chunks, "done": done, "error": error})
- text, error = call_ollama()
- return jsonify({"text": text, "error": error})
return True
def _fetch_page_text(self, url: str, timeout: int = 5) -> str: