Adding better AI response streaming logic

This commit is contained in:
Tyler
2026-05-17 15:11:01 -04:00
parent 59c46222b5
commit 332834a126
+258 -105
View File
@@ -1,4 +1,4 @@
import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures
import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures, threading
from urllib.parse import urlparse
from searx import network
try:
@@ -14,6 +14,14 @@ from markupsafe import Markup
logger = logging.getLogger(__name__)
try:
import valkey as _valkey_mod
_VALKEY_AVAILABLE = True
except ImportError:
_VALKEY_AVAILABLE = False
_valkey_mod = None
logger.warning("AI Answers: valkey package not found. Streaming via Valkey unavailable.")
TOKEN_EXPIRY_SEC = 3600
STREAM_CHUNK_SIZE = 512
STREAM_TIMEOUT_SEC = 60
@@ -47,6 +55,134 @@ def _get_streaming_connection(url: str, verify_ssl: bool = True):
return conn, path
_VALKEY_POOL = None
def _get_valkey_pool():
global _VALKEY_POOL
if _VALKEY_POOL is None:
assert _valkey_mod is not None
_VALKEY_POOL = _valkey_mod.ConnectionPool(
host=os.getenv('VALKEY_HOST', 'searxng-valkey'),
port=int(os.getenv('VALKEY_PORT', 6379)),
db=0,
decode_responses=True,
)
return _VALKEY_POOL
def _get_valkey():
if not _VALKEY_AVAILABLE or _valkey_mod is None:
raise RuntimeError("valkey package not installed")
return _valkey_mod.Valkey(connection_pool=_get_valkey_pool())
def stream_to_valkey(job_id: str, payload: str, headers: dict, endpoint_url: str, model: str):
chunks_key = f"ai:job:{job_id}:chunks"
status_key = f"ai:job:{job_id}:status"
conn = None
try:
vk = _get_valkey()
url = endpoint_url
res = None
for _ in range(3):
conn, path = _get_streaming_connection(url)
conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
res = conn.getresponse()
if res.status in (301, 302, 307, 308):
location = res.getheader('Location', '')
res.read()
conn.close()
conn = None
if not location:
raise RuntimeError(f"Redirect {res.status} with no Location")
url = location if location.startswith('http') else \
f"{urlparse(url).scheme}://{urlparse(url).netloc}{location}"
continue
break
else:
raise RuntimeError("Too many redirects to Ollama endpoint")
if res.status != 200:
body = res.read(1024).decode('utf-8', errors='replace')
raise RuntimeError(f"Ollama error {res.status}: {body[:200]}")
think_depth = 0
pending = ''
while True:
raw_line = res.readline()
if not raw_line:
break
line = raw_line.decode('utf-8', errors='replace').rstrip('\r\n')
if not line or not line.startswith('data: '):
continue
data_str = line[6:]
if data_str == '[DONE]':
break
try:
obj = json.loads(data_str)
except (json.JSONDecodeError, ValueError):
continue
choices = obj.get('choices', [])
if not choices:
continue
delta = choices[0].get('delta', {})
token = delta.get('content') or ''
if not token:
continue
pending += token
# Filter <think>...</think> blocks, push clean content immediately
while True:
if think_depth == 0:
think_start = pending.find('<think>')
if think_start == -1:
if pending:
vk.rpush(chunks_key, pending)
vk.expire(chunks_key, 120)
pending = ''
break
else:
before = pending[:think_start]
if before:
vk.rpush(chunks_key, before)
vk.expire(chunks_key, 120)
pending = pending[think_start + 7:]
think_depth = 1
else:
think_end = pending.find('</think>')
if think_end == -1:
break
else:
pending = pending[think_end + 8:]
think_depth = 0
if think_depth == 0 and pending:
vk.rpush(chunks_key, pending)
vk.expire(chunks_key, 120)
vk.rpush(chunks_key, '__DONE__')
vk.expire(chunks_key, 120)
vk.set(status_key, 'done', ex=120)
except Exception as e:
logger.error(f"AI Answers: stream_to_valkey error for job {job_id}: {e}", exc_info=True)
try:
vk2 = _get_valkey()
vk2.rpush(chunks_key, f"__ERROR__{e}")
vk2.expire(chunks_key, 120)
vk2.set(status_key, 'error', ex=120)
except Exception:
pass
finally:
if conn:
try:
conn.close()
except Exception:
pass
PLUGIN_NAME = "AI Answers"
DEFAULT_TABS = "general,science,it,news"
@@ -577,7 +713,7 @@ FRONTEND_JS_TEMPLATE = r"""
box.style.display = 'block';
const controller = new AbortController();
let timeoutId = setTimeout(() => controller.abort(), 60000);
let timeoutId = setTimeout(() => controller.abort(), 90000);
const finalQ = __STREAM_Q__;
const _selMdl = (document.getElementById('sxng-model-select') || {value: ''}).value;
@@ -610,34 +746,17 @@ FRONTEND_JS_TEMPLATE = r"""
return;
}
const fullText = (respJson.text || '').trim();
if (!fullText) {
const jobId = respJson.job_id;
if (!jobId) {
const cursorErr = data.querySelector('.sxng-cursor');
if (cursorErr) cursorErr.remove();
const errSpan = document.createElement('span');
errSpan.style.color = '#bf616a';
errSpan.textContent = 'No response received. Check API configuration and server logs.';
errSpan.textContent = 'No job ID returned. Check server logs.';
data.appendChild(errSpan);
return;
}
let mainText = fullText;
const thinkMatch = mainText.match(/^<think>([\s\S]*?)<\/think>\s*/);
if (thinkMatch) {
const cursorTh = data.querySelector('.sxng-cursor');
const details = document.createElement('details');
details.className = 'sxng-reasoning';
details.innerHTML = '<summary>Thought Process</summary>';
const thoughtDiv = document.createElement('div');
thoughtDiv.className = 'sxng-thought-content';
thoughtDiv.textContent = thinkMatch[1];
details.appendChild(thoughtDiv);
if (cursorTh) cursorTh.before(details);
else data.appendChild(details);
mainText = mainText.substring(thinkMatch[0].length);
}
let cursor = data.querySelector('.sxng-cursor');
if (!cursor) {
cursor = document.createElement('span');
@@ -646,6 +765,7 @@ FRONTEND_JS_TEMPLATE = r"""
}
let buffer = '';
let fullText = '';
const flushBuffer = (force = false) => {
if (!buffer) return;
@@ -706,23 +826,61 @@ FRONTEND_JS_TEMPLATE = r"""
}
};
let twPos = 0;
const twBatch = 4;
await new Promise(resolve => {
function twTick() {
if (twPos >= mainText.length) {
flushBuffer(true);
resolve();
let offset = 0;
const maxPolls = 600;
let polls = 0;
while (polls < maxPolls) {
polls++;
await new Promise(r => setTimeout(r, 150));
let statusRes;
try {
statusRes = await fetch(
`${script_root}/ai-status/${jobId}?tk=${encodeURIComponent(tk_init)}&offset=${offset}`,
{ signal: controller.signal }
);
} catch (fetchErr) {
if (fetchErr.name === 'AbortError') throw fetchErr;
continue;
}
if (statusRes.status === 404) {
const cursorE = data.querySelector('.sxng-cursor');
if (cursorE) cursorE.remove();
const expiredSpan = document.createElement('span');
expiredSpan.style.color = '#bf616a';
expiredSpan.textContent = 'Response expired. Please search again.';
data.appendChild(expiredSpan);
return;
}
const end = Math.min(twPos + twBatch, mainText.length);
buffer += mainText.substring(twPos, end);
twPos = end;
flushBuffer(false);
setTimeout(twTick, 8);
if (!statusRes.ok) continue;
const statusData = await statusRes.json();
if (statusData.error) {
const cursorE = data.querySelector('.sxng-cursor');
if (cursorE) cursorE.remove();
const errSpan2 = document.createElement('span');
errSpan2.style.color = '#bf616a';
errSpan2.textContent = '⚠️ ' + statusData.error;
data.appendChild(errSpan2);
return;
}
for (const chunk of (statusData.chunks || [])) {
fullText += chunk;
buffer += chunk;
flushBuffer(false);
}
offset += (statusData.chunks || []).length;
if (statusData.done) {
flushBuffer(true);
break;
}
}
twTick();
});
if (cursor) cursor.remove();
@@ -738,9 +896,9 @@ FRONTEND_JS_TEMPLATE = r"""
}
}
renderCitationFooter(mainText, urls, data);
renderCitationFooter(fullText, urls, data);
const collectedResponse = mainText;
const collectedResponse = fullText;
__INTERACTIVE_JS_COMPLETE__
@@ -748,7 +906,6 @@ FRONTEND_JS_TEMPLATE = r"""
conversation.turns.push({role: 'assistant', content: collectedResponse.trim(), ts: Date.now()});
}
// Save state if this was an initial generation or a regeneration
if (arguments.length === 0 && typeof updateState === 'function') {
updateState();
}
@@ -1094,9 +1251,8 @@ class SXNGPlugin(Plugin):
{numbered_instructions}
</CORE_DIRECTIVES>"""
def call_ollama():
conn = None
try:
job_id = hashlib.sha256(f"{time.time()}{q}".encode()).hexdigest()[:16]
payload_dict = {
"model": effective_model,
"messages": [
@@ -1104,76 +1260,73 @@ class SXNGPlugin(Plugin):
{"role": "user", "content": prompt},
{"role": "assistant", "content": ""},
],
"stream": False,
"stream": True,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
}
payload = json.dumps(payload_dict)
headers = {
stream_payload = json.dumps(payload_dict)
stream_headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
url = self.endpoint_url
res = None # type: ignore[assignment]
for _ in range(3):
conn, path = _get_streaming_connection(url)
conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
res = conn.getresponse()
if res.status in (301, 302, 307, 308):
location = res.getheader('Location', '')
res.read()
conn.close()
conn = None
if not location:
return '', f"Redirect {res.status} with no Location header"
url = location if location.startswith('http') else f"{urlparse(url).scheme}://{urlparse(url).netloc}{location}"
logger.info(f"{PLUGIN_NAME}: Following redirect to {url}")
continue
break
if res.status != 200:
body = res.read(1024).decode('utf-8', errors='replace')
logger.error(f"{PLUGIN_NAME}: Ollama {res.status}: {body}")
return '', f"Ollama error {res.status}"
obj = json.loads(res.read().decode('utf-8', errors='replace'))
if 'error' in obj:
err = obj['error']
msg = err.get('message', str(err)) if isinstance(err, dict) else str(err)
return '', msg
choices = obj.get('choices', [])
if not choices:
return '', "No choices in Ollama response."
message = choices[0].get('message', {})
content = message.get('content') or ''
reasoning = message.get('reasoning') or message.get('reasoning_content') or ''
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
if not content and reasoning:
logger.warning(f"{PLUGIN_NAME}: content empty, extracting from reasoning field")
lines = reasoning.splitlines()
header_re = re.compile(r'^\s*\*?\*?[A-Z][^:]{0,40}:\*?\*?\s*$')
last_header_idx = -1
for i, line in enumerate(lines):
if header_re.match(line):
last_header_idx = i
if last_header_idx >= 0 and last_header_idx < len(lines) - 1:
content = '\n'.join(lines[last_header_idx + 1:]).strip()
if not content:
paragraphs = [p.strip() for p in reasoning.split('\n\n') if p.strip()]
content = '\n\n'.join(paragraphs[-2:]) if len(paragraphs) >= 2 else paragraphs[-1] if paragraphs else ''
if reasoning and content:
full = (f"<think>\n{reasoning}\n</think>\n\n" if reasoning else "") + content
else:
full = content
full = re.sub(r'<think>.*?</think>', '', full, flags=re.DOTALL).strip()
return full, None
except Exception as e:
logger.error(f"{PLUGIN_NAME}: Ollama call error: {e}", exc_info=True)
return '', f"Connection Error: {e}"
finally:
if conn:
conn.close()
text, error = call_ollama()
return jsonify({"text": text, "error": error})
try:
vk = _get_valkey()
vk.set(f"ai:job:{job_id}:status", "running", ex=120)
except Exception as e:
logger.error(f"{PLUGIN_NAME}: Valkey unavailable: {e}", exc_info=True)
return jsonify({"error": "Streaming service unavailable (Valkey connection failed)."}), 503
t = threading.Thread(
target=stream_to_valkey,
args=(job_id, stream_payload, stream_headers, self.endpoint_url, effective_model),
daemon=True,
)
t.start()
return jsonify({"job_id": job_id})
@app.route('/ai-status/<job_id>', methods=['GET'])
def ai_status(job_id):
token = request.args.get('tk', '')
try:
ts, sig = token.rsplit('.', 1)
expected = hashlib.sha256(f"{ts}{self.secret}".encode()).hexdigest()
if sig != expected or (time.time() - float(ts)) > TOKEN_EXPIRY_SEC:
abort(403)
except (ValueError, KeyError, AttributeError):
abort(403)
offset = max(0, int(request.args.get('offset', 0)))
chunks_key = f"ai:job:{job_id}:chunks"
status_key = f"ai:job:{job_id}:status"
try:
vk = _get_valkey()
status = vk.get(status_key)
if status is None:
return jsonify({"error": "Job not found or expired"}), 404
raw_chunks = vk.lrange(chunks_key, offset, -1)
except Exception as e:
logger.error(f"{PLUGIN_NAME}: Valkey error in /ai-status: {e}", exc_info=True)
return jsonify({"error": "Stream service temporarily unavailable"}), 503
done = False
error = None
chunks = []
for chunk in raw_chunks:
if chunk == '__DONE__':
done = True
break
elif chunk.startswith('__ERROR__'):
error = chunk[9:]
done = True
break
else:
chunks.append(chunk)
return jsonify({"chunks": chunks, "done": done, "error": error})
return True
def _fetch_page_text(self, url: str, timeout: int = 5) -> str: