Adding better AI response streaming logic
This commit is contained in:
+273
-120
@@ -1,4 +1,4 @@
|
|||||||
import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures
|
import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures, threading
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from searx import network
|
from searx import network
|
||||||
try:
|
try:
|
||||||
@@ -14,6 +14,14 @@ from markupsafe import Markup
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import valkey as _valkey_mod
|
||||||
|
_VALKEY_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
_VALKEY_AVAILABLE = False
|
||||||
|
_valkey_mod = None
|
||||||
|
logger.warning("AI Answers: valkey package not found. Streaming via Valkey unavailable.")
|
||||||
|
|
||||||
TOKEN_EXPIRY_SEC = 3600
|
TOKEN_EXPIRY_SEC = 3600
|
||||||
STREAM_CHUNK_SIZE = 512
|
STREAM_CHUNK_SIZE = 512
|
||||||
STREAM_TIMEOUT_SEC = 60
|
STREAM_TIMEOUT_SEC = 60
|
||||||
@@ -47,6 +55,134 @@ def _get_streaming_connection(url: str, verify_ssl: bool = True):
|
|||||||
return conn, path
|
return conn, path
|
||||||
|
|
||||||
|
|
||||||
|
_VALKEY_POOL = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_valkey_pool():
|
||||||
|
global _VALKEY_POOL
|
||||||
|
if _VALKEY_POOL is None:
|
||||||
|
assert _valkey_mod is not None
|
||||||
|
_VALKEY_POOL = _valkey_mod.ConnectionPool(
|
||||||
|
host=os.getenv('VALKEY_HOST', 'searxng-valkey'),
|
||||||
|
port=int(os.getenv('VALKEY_PORT', 6379)),
|
||||||
|
db=0,
|
||||||
|
decode_responses=True,
|
||||||
|
)
|
||||||
|
return _VALKEY_POOL
|
||||||
|
|
||||||
|
|
||||||
|
def _get_valkey():
|
||||||
|
if not _VALKEY_AVAILABLE or _valkey_mod is None:
|
||||||
|
raise RuntimeError("valkey package not installed")
|
||||||
|
return _valkey_mod.Valkey(connection_pool=_get_valkey_pool())
|
||||||
|
|
||||||
|
|
||||||
|
def stream_to_valkey(job_id: str, payload: str, headers: dict, endpoint_url: str, model: str):
|
||||||
|
chunks_key = f"ai:job:{job_id}:chunks"
|
||||||
|
status_key = f"ai:job:{job_id}:status"
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
vk = _get_valkey()
|
||||||
|
url = endpoint_url
|
||||||
|
res = None
|
||||||
|
for _ in range(3):
|
||||||
|
conn, path = _get_streaming_connection(url)
|
||||||
|
conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
|
||||||
|
res = conn.getresponse()
|
||||||
|
if res.status in (301, 302, 307, 308):
|
||||||
|
location = res.getheader('Location', '')
|
||||||
|
res.read()
|
||||||
|
conn.close()
|
||||||
|
conn = None
|
||||||
|
if not location:
|
||||||
|
raise RuntimeError(f"Redirect {res.status} with no Location")
|
||||||
|
url = location if location.startswith('http') else \
|
||||||
|
f"{urlparse(url).scheme}://{urlparse(url).netloc}{location}"
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Too many redirects to Ollama endpoint")
|
||||||
|
|
||||||
|
if res.status != 200:
|
||||||
|
body = res.read(1024).decode('utf-8', errors='replace')
|
||||||
|
raise RuntimeError(f"Ollama error {res.status}: {body[:200]}")
|
||||||
|
|
||||||
|
think_depth = 0
|
||||||
|
pending = ''
|
||||||
|
|
||||||
|
while True:
|
||||||
|
raw_line = res.readline()
|
||||||
|
if not raw_line:
|
||||||
|
break
|
||||||
|
line = raw_line.decode('utf-8', errors='replace').rstrip('\r\n')
|
||||||
|
if not line or not line.startswith('data: '):
|
||||||
|
continue
|
||||||
|
data_str = line[6:]
|
||||||
|
if data_str == '[DONE]':
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
obj = json.loads(data_str)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
continue
|
||||||
|
choices = obj.get('choices', [])
|
||||||
|
if not choices:
|
||||||
|
continue
|
||||||
|
delta = choices[0].get('delta', {})
|
||||||
|
token = delta.get('content') or ''
|
||||||
|
if not token:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pending += token
|
||||||
|
# Filter <think>...</think> blocks, push clean content immediately
|
||||||
|
while True:
|
||||||
|
if think_depth == 0:
|
||||||
|
think_start = pending.find('<think>')
|
||||||
|
if think_start == -1:
|
||||||
|
if pending:
|
||||||
|
vk.rpush(chunks_key, pending)
|
||||||
|
vk.expire(chunks_key, 120)
|
||||||
|
pending = ''
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
before = pending[:think_start]
|
||||||
|
if before:
|
||||||
|
vk.rpush(chunks_key, before)
|
||||||
|
vk.expire(chunks_key, 120)
|
||||||
|
pending = pending[think_start + 7:]
|
||||||
|
think_depth = 1
|
||||||
|
else:
|
||||||
|
think_end = pending.find('</think>')
|
||||||
|
if think_end == -1:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
pending = pending[think_end + 8:]
|
||||||
|
think_depth = 0
|
||||||
|
|
||||||
|
if think_depth == 0 and pending:
|
||||||
|
vk.rpush(chunks_key, pending)
|
||||||
|
vk.expire(chunks_key, 120)
|
||||||
|
|
||||||
|
vk.rpush(chunks_key, '__DONE__')
|
||||||
|
vk.expire(chunks_key, 120)
|
||||||
|
vk.set(status_key, 'done', ex=120)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"AI Answers: stream_to_valkey error for job {job_id}: {e}", exc_info=True)
|
||||||
|
try:
|
||||||
|
vk2 = _get_valkey()
|
||||||
|
vk2.rpush(chunks_key, f"__ERROR__{e}")
|
||||||
|
vk2.expire(chunks_key, 120)
|
||||||
|
vk2.set(status_key, 'error', ex=120)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
PLUGIN_NAME = "AI Answers"
|
PLUGIN_NAME = "AI Answers"
|
||||||
DEFAULT_TABS = "general,science,it,news"
|
DEFAULT_TABS = "general,science,it,news"
|
||||||
@@ -577,7 +713,7 @@ FRONTEND_JS_TEMPLATE = r"""
|
|||||||
box.style.display = 'block';
|
box.style.display = 'block';
|
||||||
|
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
let timeoutId = setTimeout(() => controller.abort(), 60000);
|
let timeoutId = setTimeout(() => controller.abort(), 90000);
|
||||||
const finalQ = __STREAM_Q__;
|
const finalQ = __STREAM_Q__;
|
||||||
|
|
||||||
const _selMdl = (document.getElementById('sxng-model-select') || {value: ''}).value;
|
const _selMdl = (document.getElementById('sxng-model-select') || {value: ''}).value;
|
||||||
@@ -610,34 +746,17 @@ FRONTEND_JS_TEMPLATE = r"""
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const fullText = (respJson.text || '').trim();
|
const jobId = respJson.job_id;
|
||||||
|
if (!jobId) {
|
||||||
if (!fullText) {
|
|
||||||
const cursorErr = data.querySelector('.sxng-cursor');
|
const cursorErr = data.querySelector('.sxng-cursor');
|
||||||
if (cursorErr) cursorErr.remove();
|
if (cursorErr) cursorErr.remove();
|
||||||
const errSpan = document.createElement('span');
|
const errSpan = document.createElement('span');
|
||||||
errSpan.style.color = '#bf616a';
|
errSpan.style.color = '#bf616a';
|
||||||
errSpan.textContent = 'No response received. Check API configuration and server logs.';
|
errSpan.textContent = 'No job ID returned. Check server logs.';
|
||||||
data.appendChild(errSpan);
|
data.appendChild(errSpan);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mainText = fullText;
|
|
||||||
const thinkMatch = mainText.match(/^<think>([\s\S]*?)<\/think>\s*/);
|
|
||||||
if (thinkMatch) {
|
|
||||||
const cursorTh = data.querySelector('.sxng-cursor');
|
|
||||||
const details = document.createElement('details');
|
|
||||||
details.className = 'sxng-reasoning';
|
|
||||||
details.innerHTML = '<summary>Thought Process</summary>';
|
|
||||||
const thoughtDiv = document.createElement('div');
|
|
||||||
thoughtDiv.className = 'sxng-thought-content';
|
|
||||||
thoughtDiv.textContent = thinkMatch[1];
|
|
||||||
details.appendChild(thoughtDiv);
|
|
||||||
if (cursorTh) cursorTh.before(details);
|
|
||||||
else data.appendChild(details);
|
|
||||||
mainText = mainText.substring(thinkMatch[0].length);
|
|
||||||
}
|
|
||||||
|
|
||||||
let cursor = data.querySelector('.sxng-cursor');
|
let cursor = data.querySelector('.sxng-cursor');
|
||||||
if (!cursor) {
|
if (!cursor) {
|
||||||
cursor = document.createElement('span');
|
cursor = document.createElement('span');
|
||||||
@@ -646,6 +765,7 @@ FRONTEND_JS_TEMPLATE = r"""
|
|||||||
}
|
}
|
||||||
|
|
||||||
let buffer = '';
|
let buffer = '';
|
||||||
|
let fullText = '';
|
||||||
const flushBuffer = (force = false) => {
|
const flushBuffer = (force = false) => {
|
||||||
if (!buffer) return;
|
if (!buffer) return;
|
||||||
|
|
||||||
@@ -706,23 +826,61 @@ FRONTEND_JS_TEMPLATE = r"""
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let twPos = 0;
|
let offset = 0;
|
||||||
const twBatch = 4;
|
const maxPolls = 600;
|
||||||
await new Promise(resolve => {
|
let polls = 0;
|
||||||
function twTick() {
|
|
||||||
if (twPos >= mainText.length) {
|
while (polls < maxPolls) {
|
||||||
flushBuffer(true);
|
polls++;
|
||||||
resolve();
|
await new Promise(r => setTimeout(r, 150));
|
||||||
return;
|
|
||||||
}
|
let statusRes;
|
||||||
const end = Math.min(twPos + twBatch, mainText.length);
|
try {
|
||||||
buffer += mainText.substring(twPos, end);
|
statusRes = await fetch(
|
||||||
twPos = end;
|
`${script_root}/ai-status/${jobId}?tk=${encodeURIComponent(tk_init)}&offset=${offset}`,
|
||||||
flushBuffer(false);
|
{ signal: controller.signal }
|
||||||
setTimeout(twTick, 8);
|
);
|
||||||
|
} catch (fetchErr) {
|
||||||
|
if (fetchErr.name === 'AbortError') throw fetchErr;
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
twTick();
|
|
||||||
});
|
if (statusRes.status === 404) {
|
||||||
|
const cursorE = data.querySelector('.sxng-cursor');
|
||||||
|
if (cursorE) cursorE.remove();
|
||||||
|
const expiredSpan = document.createElement('span');
|
||||||
|
expiredSpan.style.color = '#bf616a';
|
||||||
|
expiredSpan.textContent = 'Response expired. Please search again.';
|
||||||
|
data.appendChild(expiredSpan);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!statusRes.ok) continue;
|
||||||
|
|
||||||
|
const statusData = await statusRes.json();
|
||||||
|
|
||||||
|
if (statusData.error) {
|
||||||
|
const cursorE = data.querySelector('.sxng-cursor');
|
||||||
|
if (cursorE) cursorE.remove();
|
||||||
|
const errSpan2 = document.createElement('span');
|
||||||
|
errSpan2.style.color = '#bf616a';
|
||||||
|
errSpan2.textContent = '⚠️ ' + statusData.error;
|
||||||
|
data.appendChild(errSpan2);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const chunk of (statusData.chunks || [])) {
|
||||||
|
fullText += chunk;
|
||||||
|
buffer += chunk;
|
||||||
|
flushBuffer(false);
|
||||||
|
}
|
||||||
|
offset += (statusData.chunks || []).length;
|
||||||
|
|
||||||
|
if (statusData.done) {
|
||||||
|
flushBuffer(true);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (cursor) cursor.remove();
|
if (cursor) cursor.remove();
|
||||||
|
|
||||||
@@ -738,9 +896,9 @@ FRONTEND_JS_TEMPLATE = r"""
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
renderCitationFooter(mainText, urls, data);
|
renderCitationFooter(fullText, urls, data);
|
||||||
|
|
||||||
const collectedResponse = mainText;
|
const collectedResponse = fullText;
|
||||||
|
|
||||||
__INTERACTIVE_JS_COMPLETE__
|
__INTERACTIVE_JS_COMPLETE__
|
||||||
|
|
||||||
@@ -748,7 +906,6 @@ FRONTEND_JS_TEMPLATE = r"""
|
|||||||
conversation.turns.push({role: 'assistant', content: collectedResponse.trim(), ts: Date.now()});
|
conversation.turns.push({role: 'assistant', content: collectedResponse.trim(), ts: Date.now()});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save state if this was an initial generation or a regeneration
|
|
||||||
if (arguments.length === 0 && typeof updateState === 'function') {
|
if (arguments.length === 0 && typeof updateState === 'function') {
|
||||||
updateState();
|
updateState();
|
||||||
}
|
}
|
||||||
@@ -1094,86 +1251,82 @@ class SXNGPlugin(Plugin):
|
|||||||
{numbered_instructions}
|
{numbered_instructions}
|
||||||
</CORE_DIRECTIVES>"""
|
</CORE_DIRECTIVES>"""
|
||||||
|
|
||||||
def call_ollama():
|
job_id = hashlib.sha256(f"{time.time()}{q}".encode()).hexdigest()[:16]
|
||||||
conn = None
|
|
||||||
try:
|
payload_dict = {
|
||||||
payload_dict = {
|
"model": effective_model,
|
||||||
"model": effective_model,
|
"messages": [
|
||||||
"messages": [
|
{"role": "system", "content": SYSTEM},
|
||||||
{"role": "system", "content": SYSTEM},
|
{"role": "user", "content": prompt},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "assistant", "content": ""},
|
||||||
{"role": "assistant", "content": ""},
|
],
|
||||||
],
|
"stream": True,
|
||||||
"stream": False,
|
"max_tokens": self.max_tokens,
|
||||||
"max_tokens": self.max_tokens,
|
"temperature": self.temperature,
|
||||||
"temperature": self.temperature,
|
}
|
||||||
}
|
stream_payload = json.dumps(payload_dict)
|
||||||
payload = json.dumps(payload_dict)
|
stream_headers = {
|
||||||
headers = {
|
"Content-Type": "application/json",
|
||||||
"Content-Type": "application/json",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
}
|
||||||
}
|
|
||||||
url = self.endpoint_url
|
try:
|
||||||
res = None # type: ignore[assignment]
|
vk = _get_valkey()
|
||||||
for _ in range(3):
|
vk.set(f"ai:job:{job_id}:status", "running", ex=120)
|
||||||
conn, path = _get_streaming_connection(url)
|
except Exception as e:
|
||||||
conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
|
logger.error(f"{PLUGIN_NAME}: Valkey unavailable: {e}", exc_info=True)
|
||||||
res = conn.getresponse()
|
return jsonify({"error": "Streaming service unavailable (Valkey connection failed)."}), 503
|
||||||
if res.status in (301, 302, 307, 308):
|
|
||||||
location = res.getheader('Location', '')
|
t = threading.Thread(
|
||||||
res.read()
|
target=stream_to_valkey,
|
||||||
conn.close()
|
args=(job_id, stream_payload, stream_headers, self.endpoint_url, effective_model),
|
||||||
conn = None
|
daemon=True,
|
||||||
if not location:
|
)
|
||||||
return '', f"Redirect {res.status} with no Location header"
|
t.start()
|
||||||
url = location if location.startswith('http') else f"{urlparse(url).scheme}://{urlparse(url).netloc}{location}"
|
|
||||||
logger.info(f"{PLUGIN_NAME}: Following redirect to {url}")
|
return jsonify({"job_id": job_id})
|
||||||
continue
|
|
||||||
break
|
@app.route('/ai-status/<job_id>', methods=['GET'])
|
||||||
if res.status != 200:
|
def ai_status(job_id):
|
||||||
body = res.read(1024).decode('utf-8', errors='replace')
|
token = request.args.get('tk', '')
|
||||||
logger.error(f"{PLUGIN_NAME}: Ollama {res.status}: {body}")
|
try:
|
||||||
return '', f"Ollama error {res.status}"
|
ts, sig = token.rsplit('.', 1)
|
||||||
obj = json.loads(res.read().decode('utf-8', errors='replace'))
|
expected = hashlib.sha256(f"{ts}{self.secret}".encode()).hexdigest()
|
||||||
if 'error' in obj:
|
if sig != expected or (time.time() - float(ts)) > TOKEN_EXPIRY_SEC:
|
||||||
err = obj['error']
|
abort(403)
|
||||||
msg = err.get('message', str(err)) if isinstance(err, dict) else str(err)
|
except (ValueError, KeyError, AttributeError):
|
||||||
return '', msg
|
abort(403)
|
||||||
choices = obj.get('choices', [])
|
|
||||||
if not choices:
|
offset = max(0, int(request.args.get('offset', 0)))
|
||||||
return '', "No choices in Ollama response."
|
chunks_key = f"ai:job:{job_id}:chunks"
|
||||||
message = choices[0].get('message', {})
|
status_key = f"ai:job:{job_id}:status"
|
||||||
content = message.get('content') or ''
|
|
||||||
reasoning = message.get('reasoning') or message.get('reasoning_content') or ''
|
try:
|
||||||
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
vk = _get_valkey()
|
||||||
if not content and reasoning:
|
status = vk.get(status_key)
|
||||||
logger.warning(f"{PLUGIN_NAME}: content empty, extracting from reasoning field")
|
if status is None:
|
||||||
lines = reasoning.splitlines()
|
return jsonify({"error": "Job not found or expired"}), 404
|
||||||
header_re = re.compile(r'^\s*\*?\*?[A-Z][^:]{0,40}:\*?\*?\s*$')
|
raw_chunks = vk.lrange(chunks_key, offset, -1)
|
||||||
last_header_idx = -1
|
except Exception as e:
|
||||||
for i, line in enumerate(lines):
|
logger.error(f"{PLUGIN_NAME}: Valkey error in /ai-status: {e}", exc_info=True)
|
||||||
if header_re.match(line):
|
return jsonify({"error": "Stream service temporarily unavailable"}), 503
|
||||||
last_header_idx = i
|
|
||||||
if last_header_idx >= 0 and last_header_idx < len(lines) - 1:
|
done = False
|
||||||
content = '\n'.join(lines[last_header_idx + 1:]).strip()
|
error = None
|
||||||
if not content:
|
chunks = []
|
||||||
paragraphs = [p.strip() for p in reasoning.split('\n\n') if p.strip()]
|
for chunk in raw_chunks:
|
||||||
content = '\n\n'.join(paragraphs[-2:]) if len(paragraphs) >= 2 else paragraphs[-1] if paragraphs else ''
|
if chunk == '__DONE__':
|
||||||
if reasoning and content:
|
done = True
|
||||||
full = (f"<think>\n{reasoning}\n</think>\n\n" if reasoning else "") + content
|
break
|
||||||
else:
|
elif chunk.startswith('__ERROR__'):
|
||||||
full = content
|
error = chunk[9:]
|
||||||
full = re.sub(r'<think>.*?</think>', '', full, flags=re.DOTALL).strip()
|
done = True
|
||||||
return full, None
|
break
|
||||||
except Exception as e:
|
else:
|
||||||
logger.error(f"{PLUGIN_NAME}: Ollama call error: {e}", exc_info=True)
|
chunks.append(chunk)
|
||||||
return '', f"Connection Error: {e}"
|
|
||||||
finally:
|
return jsonify({"chunks": chunks, "done": done, "error": error})
|
||||||
if conn:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
text, error = call_ollama()
|
|
||||||
return jsonify({"text": text, "error": error})
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _fetch_page_text(self, url: str, timeout: int = 5) -> str:
|
def _fetch_page_text(self, url: str, timeout: int = 5) -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user