Adding better AI response streaming logic
This commit is contained in:
+273
-120
@@ -1,4 +1,4 @@
|
||||
import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures
|
||||
import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures, threading
|
||||
from urllib.parse import urlparse
|
||||
from searx import network
|
||||
try:
|
||||
@@ -14,6 +14,14 @@ from markupsafe import Markup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import valkey as _valkey_mod
|
||||
_VALKEY_AVAILABLE = True
|
||||
except ImportError:
|
||||
_VALKEY_AVAILABLE = False
|
||||
_valkey_mod = None
|
||||
logger.warning("AI Answers: valkey package not found. Streaming via Valkey unavailable.")
|
||||
|
||||
TOKEN_EXPIRY_SEC = 3600
|
||||
STREAM_CHUNK_SIZE = 512
|
||||
STREAM_TIMEOUT_SEC = 60
|
||||
@@ -47,6 +55,134 @@ def _get_streaming_connection(url: str, verify_ssl: bool = True):
|
||||
return conn, path
|
||||
|
||||
|
||||
_VALKEY_POOL = None
|
||||
|
||||
|
||||
def _get_valkey_pool():
    """Return the shared Valkey connection pool, creating it on first use.

    The pool is cached in the module-level ``_VALKEY_POOL`` so later calls
    reuse it. Host and port are read from the ``VALKEY_HOST`` and
    ``VALKEY_PORT`` environment variables (falling back to ``searxng-valkey``
    and ``6379``); the pool is created with ``db=0`` and
    ``decode_responses=True``.

    Raises:
        RuntimeError: if the ``valkey`` module is not available.
        ValueError: if ``VALKEY_PORT`` is set to a non-integer value.
    """
    global _VALKEY_POOL
    if _VALKEY_POOL is None:
        # Explicit check rather than ``assert``: assertions are removed under
        # ``python -O``, which would turn a missing dependency into an opaque
        # AttributeError on ``None``.
        if _valkey_mod is None:
            raise RuntimeError("valkey package not installed")
        _VALKEY_POOL = _valkey_mod.ConnectionPool(
            host=os.getenv('VALKEY_HOST', 'searxng-valkey'),
            port=int(os.getenv('VALKEY_PORT', 6379)),
            db=0,
            decode_responses=True,
        )
    return _VALKEY_POOL
|
||||
|
||||
|
||||
def _get_valkey():
    """Build a Valkey client bound to the shared connection pool.

    Raises:
        RuntimeError: when the ``valkey`` package was not importable.
    """
    package_ready = _VALKEY_AVAILABLE and _valkey_mod is not None
    if package_ready:
        shared_pool = _get_valkey_pool()
        return _valkey_mod.Valkey(connection_pool=shared_pool)
    raise RuntimeError("valkey package not installed")
|
||||
|
||||
|
||||
def _held_back_len(text: str, tag: str) -> int:
    """Return how many trailing characters of *text* are a proper prefix of *tag*."""
    for size in range(min(len(tag) - 1, len(text)), 0, -1):
        if text.endswith(tag[:size]):
            return size
    return 0


def _filter_think(pending: str, in_think: bool):
    """Separate visible text from ``<think>...</think>`` blocks in *pending*.

    Returns a ``(visible_parts, remainder, in_think)`` tuple. ``visible_parts``
    is the list of strings that are safe to publish now; ``remainder`` is the
    unconsumed tail (only ever a possible partial tag) to prepend to the next
    token; ``in_think`` tells whether the stream is currently inside a block.
    """
    open_tag, close_tag = '<think>', '</think>'
    visible = []
    while True:
        if in_think:
            end = pending.find(close_tag)
            if end == -1:
                # Hidden text is discarded; keep just a possible split closing tag.
                keep = _held_back_len(pending, close_tag)
                return visible, pending[len(pending) - keep:], True
            pending = pending[end + len(close_tag):]
            in_think = False
            continue
        start = pending.find(open_tag)
        if start == -1:
            # Hold back a trailing fragment such as "<thi" so an opening tag
            # split across two tokens is not published as visible text.
            keep = _held_back_len(pending, open_tag)
            cut = len(pending) - keep
            if cut:
                visible.append(pending[:cut])
            return visible, pending[cut:], False
        if start:
            visible.append(pending[:start])
        pending = pending[start + len(open_tag):]
        in_think = True


def stream_to_valkey(job_id: str, payload: str, headers: dict, endpoint_url: str, model: str):
    """POST *payload* to the chat endpoint and relay the streamed answer into Valkey.

    The response is read line by line as ``data: ...`` events whose JSON body
    carries the text in ``choices[0].delta.content``. Text outside
    ``<think>...</think>`` blocks is appended to the list
    ``ai:job:<job_id>:chunks``; the list ends with ``__DONE__`` on success or
    ``__ERROR__<message>`` on failure, and ``ai:job:<job_id>:status`` is set to
    ``done`` / ``error`` accordingly. Every key is given a 120 second expiry.

    Args:
        job_id: identifier used to build the Valkey key names.
        payload: JSON request body, sent UTF-8 encoded.
        headers: HTTP headers for the POST request.
        endpoint_url: URL to post to; up to three redirects are followed.
        model: not used by this function (kept for the existing call signature).

    Never raises: any exception is logged and reported through the Valkey keys.
    """
    # Local import: the visible top-level import only brings in ``urlparse``.
    from urllib.parse import urljoin

    chunks_key = f"ai:job:{job_id}:chunks"
    status_key = f"ai:job:{job_id}:status"
    conn = None
    try:
        vk = _get_valkey()

        def push(text):
            # Refresh the expiry on every append so an active stream stays alive.
            vk.rpush(chunks_key, text)
            vk.expire(chunks_key, 120)

        url = endpoint_url
        res = None
        for _ in range(3):
            conn, path = _get_streaming_connection(url)
            conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
            res = conn.getresponse()
            if res.status not in (301, 302, 307, 308):
                break
            location = res.getheader('Location', '')
            res.read()
            conn.close()
            conn = None
            if not location:
                raise RuntimeError(f"Redirect {res.status} with no Location")
            # urljoin resolves absolute, host-relative, path-relative and
            # protocol-relative Location values against the current URL.
            url = urljoin(url, location)
        else:
            raise RuntimeError("Too many redirects to Ollama endpoint")

        if res.status != 200:
            body = res.read(1024).decode('utf-8', errors='replace')
            raise RuntimeError(f"Ollama error {res.status}: {body[:200]}")

        in_think = False
        pending = ''

        # readline() returns b'' once the response is exhausted.
        for raw_line in iter(res.readline, b''):
            line = raw_line.decode('utf-8', errors='replace').rstrip('\r\n')
            if not line.startswith('data: '):
                continue
            data_str = line[6:]
            if data_str == '[DONE]':
                break
            try:
                obj = json.loads(data_str)
            except (json.JSONDecodeError, ValueError):
                continue
            choices = obj.get('choices', [])
            if not choices:
                continue
            # ``or {}`` also covers an explicit ``"delta": null``.
            delta = choices[0].get('delta') or {}
            token = delta.get('content') or ''
            if not token:
                continue

            pending += token
            visible, pending, in_think = _filter_think(pending, in_think)
            for part in visible:
                push(part)

        # A held-back fragment that never became a tag is ordinary text.
        if not in_think and pending:
            push(pending)

        push('__DONE__')
        vk.set(status_key, 'done', ex=120)

    except Exception as e:
        logger.error(f"AI Answers: stream_to_valkey error for job {job_id}: {e}", exc_info=True)
        try:
            vk2 = _get_valkey()
            vk2.rpush(chunks_key, f"__ERROR__{e}")
            vk2.expire(chunks_key, 120)
            vk2.set(status_key, 'error', ex=120)
        except Exception:
            # Best effort only: Valkey itself may be what failed.
            logger.warning(
                "AI Answers: could not record error state for job %s", job_id, exc_info=True
            )
    finally:
        if conn:
            try:
                conn.close()
            except Exception:
                pass
|
||||
|
||||
|
||||
|
||||
PLUGIN_NAME = "AI Answers"
|
||||
DEFAULT_TABS = "general,science,it,news"
|
||||
@@ -577,7 +713,7 @@ FRONTEND_JS_TEMPLATE = r"""
|
||||
box.style.display = 'block';
|
||||
|
||||
const controller = new AbortController();
|
||||
let timeoutId = setTimeout(() => controller.abort(), 60000);
|
||||
let timeoutId = setTimeout(() => controller.abort(), 90000);
|
||||
const finalQ = __STREAM_Q__;
|
||||
|
||||
const _selMdl = (document.getElementById('sxng-model-select') || {value: ''}).value;
|
||||
@@ -610,34 +746,17 @@ FRONTEND_JS_TEMPLATE = r"""
|
||||
return;
|
||||
}
|
||||
|
||||
const fullText = (respJson.text || '').trim();
|
||||
|
||||
if (!fullText) {
|
||||
const jobId = respJson.job_id;
|
||||
if (!jobId) {
|
||||
const cursorErr = data.querySelector('.sxng-cursor');
|
||||
if (cursorErr) cursorErr.remove();
|
||||
const errSpan = document.createElement('span');
|
||||
errSpan.style.color = '#bf616a';
|
||||
errSpan.textContent = 'No response received. Check API configuration and server logs.';
|
||||
errSpan.textContent = 'No job ID returned. Check server logs.';
|
||||
data.appendChild(errSpan);
|
||||
return;
|
||||
}
|
||||
|
||||
let mainText = fullText;
|
||||
const thinkMatch = mainText.match(/^<think>([\s\S]*?)<\/think>\s*/);
|
||||
if (thinkMatch) {
|
||||
const cursorTh = data.querySelector('.sxng-cursor');
|
||||
const details = document.createElement('details');
|
||||
details.className = 'sxng-reasoning';
|
||||
details.innerHTML = '<summary>Thought Process</summary>';
|
||||
const thoughtDiv = document.createElement('div');
|
||||
thoughtDiv.className = 'sxng-thought-content';
|
||||
thoughtDiv.textContent = thinkMatch[1];
|
||||
details.appendChild(thoughtDiv);
|
||||
if (cursorTh) cursorTh.before(details);
|
||||
else data.appendChild(details);
|
||||
mainText = mainText.substring(thinkMatch[0].length);
|
||||
}
|
||||
|
||||
let cursor = data.querySelector('.sxng-cursor');
|
||||
if (!cursor) {
|
||||
cursor = document.createElement('span');
|
||||
@@ -646,6 +765,7 @@ FRONTEND_JS_TEMPLATE = r"""
|
||||
}
|
||||
|
||||
let buffer = '';
|
||||
let fullText = '';
|
||||
const flushBuffer = (force = false) => {
|
||||
if (!buffer) return;
|
||||
|
||||
@@ -706,23 +826,61 @@ FRONTEND_JS_TEMPLATE = r"""
|
||||
}
|
||||
};
|
||||
|
||||
let twPos = 0;
|
||||
const twBatch = 4;
|
||||
await new Promise(resolve => {
|
||||
function twTick() {
|
||||
if (twPos >= mainText.length) {
|
||||
flushBuffer(true);
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
const end = Math.min(twPos + twBatch, mainText.length);
|
||||
buffer += mainText.substring(twPos, end);
|
||||
twPos = end;
|
||||
flushBuffer(false);
|
||||
setTimeout(twTick, 8);
|
||||
let offset = 0;
|
||||
const maxPolls = 600;
|
||||
let polls = 0;
|
||||
|
||||
while (polls < maxPolls) {
|
||||
polls++;
|
||||
await new Promise(r => setTimeout(r, 150));
|
||||
|
||||
let statusRes;
|
||||
try {
|
||||
statusRes = await fetch(
|
||||
`${script_root}/ai-status/${jobId}?tk=${encodeURIComponent(tk_init)}&offset=${offset}`,
|
||||
{ signal: controller.signal }
|
||||
);
|
||||
} catch (fetchErr) {
|
||||
if (fetchErr.name === 'AbortError') throw fetchErr;
|
||||
continue;
|
||||
}
|
||||
twTick();
|
||||
});
|
||||
|
||||
if (statusRes.status === 404) {
|
||||
const cursorE = data.querySelector('.sxng-cursor');
|
||||
if (cursorE) cursorE.remove();
|
||||
const expiredSpan = document.createElement('span');
|
||||
expiredSpan.style.color = '#bf616a';
|
||||
expiredSpan.textContent = 'Response expired. Please search again.';
|
||||
data.appendChild(expiredSpan);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!statusRes.ok) continue;
|
||||
|
||||
const statusData = await statusRes.json();
|
||||
|
||||
if (statusData.error) {
|
||||
const cursorE = data.querySelector('.sxng-cursor');
|
||||
if (cursorE) cursorE.remove();
|
||||
const errSpan2 = document.createElement('span');
|
||||
errSpan2.style.color = '#bf616a';
|
||||
errSpan2.textContent = '⚠️ ' + statusData.error;
|
||||
data.appendChild(errSpan2);
|
||||
return;
|
||||
}
|
||||
|
||||
for (const chunk of (statusData.chunks || [])) {
|
||||
fullText += chunk;
|
||||
buffer += chunk;
|
||||
flushBuffer(false);
|
||||
}
|
||||
offset += (statusData.chunks || []).length;
|
||||
|
||||
if (statusData.done) {
|
||||
flushBuffer(true);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (cursor) cursor.remove();
|
||||
|
||||
@@ -738,9 +896,9 @@ FRONTEND_JS_TEMPLATE = r"""
|
||||
}
|
||||
}
|
||||
|
||||
renderCitationFooter(mainText, urls, data);
|
||||
renderCitationFooter(fullText, urls, data);
|
||||
|
||||
const collectedResponse = mainText;
|
||||
const collectedResponse = fullText;
|
||||
|
||||
__INTERACTIVE_JS_COMPLETE__
|
||||
|
||||
@@ -748,7 +906,6 @@ FRONTEND_JS_TEMPLATE = r"""
|
||||
conversation.turns.push({role: 'assistant', content: collectedResponse.trim(), ts: Date.now()});
|
||||
}
|
||||
|
||||
// Save state if this was an initial generation or a regeneration
|
||||
if (arguments.length === 0 && typeof updateState === 'function') {
|
||||
updateState();
|
||||
}
|
||||
@@ -1094,86 +1251,82 @@ class SXNGPlugin(Plugin):
|
||||
{numbered_instructions}
|
||||
</CORE_DIRECTIVES>"""
|
||||
|
||||
def call_ollama():
    """Send one non-streaming chat request and return ``(text, error)``.

    On success ``text`` is the reply with every ``<think>...</think>`` block
    removed and ``error`` is ``None``. On failure ``text`` is ``''`` and
    ``error`` is a short message. When the reply's ``content`` is empty but a
    ``reasoning`` / ``reasoning_content`` field is present, the answer is taken
    from the text after the last heading-like line of the reasoning, or else
    from its last two paragraphs.

    Never raises: exceptions are logged and returned as a ``Connection Error``.
    """
    # Local import: the visible top-level import only brings in ``urlparse``.
    from urllib.parse import urljoin

    think_re = re.compile(r'<think>.*?</think>', flags=re.DOTALL)
    conn = None
    try:
        payload = json.dumps({
            "model": effective_model,
            "messages": [
                {"role": "system", "content": SYSTEM},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": ""},
            ],
            "stream": False,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        })
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        url = self.endpoint_url
        res = None
        for _ in range(3):
            conn, path = _get_streaming_connection(url)
            conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
            res = conn.getresponse()
            if res.status not in (301, 302, 307, 308):
                break
            location = res.getheader('Location', '')
            res.read()
            conn.close()
            conn = None
            if not location:
                return '', f"Redirect {res.status} with no Location header"
            # urljoin resolves relative and protocol-relative Location values.
            url = urljoin(url, location)
            logger.info(f"{PLUGIN_NAME}: Following redirect to {url}")
        else:
            # Every attempt was a redirect; report that instead of a bare 30x.
            return '', "Too many redirects to Ollama endpoint"

        if res.status != 200:
            body = res.read(1024).decode('utf-8', errors='replace')
            logger.error(f"{PLUGIN_NAME}: Ollama {res.status}: {body}")
            return '', f"Ollama error {res.status}"

        obj = json.loads(res.read().decode('utf-8', errors='replace'))
        if 'error' in obj:
            err = obj['error']
            msg = err.get('message', str(err)) if isinstance(err, dict) else str(err)
            return '', msg

        choices = obj.get('choices', [])
        if not choices:
            return '', "No choices in Ollama response."

        # ``or {}`` also covers an explicit ``"message": null``.
        message = choices[0].get('message') or {}
        content = message.get('content') or ''
        reasoning = message.get('reasoning') or message.get('reasoning_content') or ''
        content = think_re.sub('', content).strip()

        if not content and reasoning:
            logger.warning(f"{PLUGIN_NAME}: content empty, extracting from reasoning field")
            lines = reasoning.splitlines()
            header_re = re.compile(r'^\s*\*?\*?[A-Z][^:]{0,40}:\*?\*?\s*$')
            last_header_idx = -1
            for i, line in enumerate(lines):
                if header_re.match(line):
                    last_header_idx = i
            if 0 <= last_header_idx < len(lines) - 1:
                content = '\n'.join(lines[last_header_idx + 1:]).strip()
            if not content:
                paragraphs = [p.strip() for p in reasoning.split('\n\n') if p.strip()]
                # Joining the last two covers the 0, 1 and 2+ paragraph cases.
                content = '\n\n'.join(paragraphs[-2:])

        # The reasoning is never returned, so only the answer needs a final
        # strip of think blocks (the fallback text may contain some).
        return think_re.sub('', content).strip(), None
    except Exception as e:
        logger.error(f"{PLUGIN_NAME}: Ollama call error: {e}", exc_info=True)
        return '', f"Connection Error: {e}"
    finally:
        if conn:
            conn.close()
|
||||
job_id = hashlib.sha256(f"{time.time()}{q}".encode()).hexdigest()[:16]
|
||||
|
||||
payload_dict = {
|
||||
"model": effective_model,
|
||||
"messages": [
|
||||
{"role": "system", "content": SYSTEM},
|
||||
{"role": "user", "content": prompt},
|
||||
{"role": "assistant", "content": ""},
|
||||
],
|
||||
"stream": True,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
stream_payload = json.dumps(payload_dict)
|
||||
stream_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
try:
|
||||
vk = _get_valkey()
|
||||
vk.set(f"ai:job:{job_id}:status", "running", ex=120)
|
||||
except Exception as e:
|
||||
logger.error(f"{PLUGIN_NAME}: Valkey unavailable: {e}", exc_info=True)
|
||||
return jsonify({"error": "Streaming service unavailable (Valkey connection failed)."}), 503
|
||||
|
||||
t = threading.Thread(
|
||||
target=stream_to_valkey,
|
||||
args=(job_id, stream_payload, stream_headers, self.endpoint_url, effective_model),
|
||||
daemon=True,
|
||||
)
|
||||
t.start()
|
||||
|
||||
return jsonify({"job_id": job_id})
|
||||
|
||||
@app.route('/ai-status/<job_id>', methods=['GET'])
def ai_status(job_id):
    """Return the chunks of a streaming job from a given list offset.

    Query parameters:
        tk: token of the form ``<timestamp>.<signature>``; the signature must
            equal ``sha256(timestamp + secret)`` and the timestamp must be no
            older than ``TOKEN_EXPIRY_SEC`` seconds.
        offset: index of the first chunk to return (negative values become 0).

    Responses:
        200 with ``{"chunks": [...], "done": bool, "error": str | None}``;
        400 when ``offset`` is not an integer; 403 when the token is invalid
        or expired; 404 when the job's status key does not exist; 503 when
        Valkey cannot be reached.
    """
    # Local import so this handler does not depend on a file-level ``hmac`` import.
    import hmac

    token = request.args.get('tk', '')
    try:
        ts, sig = token.rsplit('.', 1)
        expected = hashlib.sha256(f"{ts}{self.secret}".encode()).hexdigest()
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII ``str`` input, which a client could send.
        sig_ok = hmac.compare_digest(sig.encode('utf-8'), expected.encode('utf-8'))
        if not sig_ok or (time.time() - float(ts)) > TOKEN_EXPIRY_SEC:
            abort(403)
    except (ValueError, KeyError, AttributeError):
        abort(403)

    try:
        offset = max(0, int(request.args.get('offset', 0)))
    except (TypeError, ValueError):
        # A malformed offset is a client error, not a server failure.
        abort(400)

    chunks_key = f"ai:job:{job_id}:chunks"
    status_key = f"ai:job:{job_id}:status"

    try:
        vk = _get_valkey()
        status = vk.get(status_key)
        if status is None:
            return jsonify({"error": "Job not found or expired"}), 404
        raw_chunks = vk.lrange(chunks_key, offset, -1)
    except Exception as e:
        logger.error(f"{PLUGIN_NAME}: Valkey error in /ai-status: {e}", exc_info=True)
        return jsonify({"error": "Stream service temporarily unavailable"}), 503

    done = False
    error = None
    chunks = []
    for chunk in raw_chunks:
        # Sentinel entries mark the end of the stream; nothing after them is returned.
        if chunk == '__DONE__':
            done = True
            break
        if chunk.startswith('__ERROR__'):
            error = chunk[9:]
            done = True
            break
        chunks.append(chunk)

    return jsonify({"chunks": chunks, "done": done, "error": error})
|
||||
|
||||
text, error = call_ollama()
|
||||
return jsonify({"text": text, "error": error})
|
||||
return True
|
||||
|
||||
def _fetch_page_text(self, url: str, timeout: int = 5) -> str:
|
||||
|
||||
Reference in New Issue
Block a user