diff --git a/ollama_answers.py b/ollama_answers.py index 9467e54..bf0f57a 100644 --- a/ollama_answers.py +++ b/ollama_answers.py @@ -1,4 +1,4 @@ -import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures +import json, os, logging, base64, time, hashlib, re, http.client, ssl, concurrent.futures, threading from urllib.parse import urlparse from searx import network try: @@ -14,6 +14,14 @@ from markupsafe import Markup logger = logging.getLogger(__name__) +try: + import valkey as _valkey_mod + _VALKEY_AVAILABLE = True +except ImportError: + _VALKEY_AVAILABLE = False + _valkey_mod = None + logger.warning("AI Answers: valkey package not found. Streaming via Valkey unavailable.") + TOKEN_EXPIRY_SEC = 3600 STREAM_CHUNK_SIZE = 512 STREAM_TIMEOUT_SEC = 60 @@ -47,6 +55,134 @@ def _get_streaming_connection(url: str, verify_ssl: bool = True): return conn, path +_VALKEY_POOL = None + + +def _get_valkey_pool(): + global _VALKEY_POOL + if _VALKEY_POOL is None: + assert _valkey_mod is not None + _VALKEY_POOL = _valkey_mod.ConnectionPool( + host=os.getenv('VALKEY_HOST', 'searxng-valkey'), + port=int(os.getenv('VALKEY_PORT', 6379)), + db=0, + decode_responses=True, + ) + return _VALKEY_POOL + + +def _get_valkey(): + if not _VALKEY_AVAILABLE or _valkey_mod is None: + raise RuntimeError("valkey package not installed") + return _valkey_mod.Valkey(connection_pool=_get_valkey_pool()) + + +def stream_to_valkey(job_id: str, payload: str, headers: dict, endpoint_url: str, model: str): + chunks_key = f"ai:job:{job_id}:chunks" + status_key = f"ai:job:{job_id}:status" + conn = None + try: + vk = _get_valkey() + url = endpoint_url + res = None + for _ in range(3): + conn, path = _get_streaming_connection(url) + conn.request("POST", path, body=payload.encode('utf-8'), headers=headers) + res = conn.getresponse() + if res.status in (301, 302, 307, 308): + location = res.getheader('Location', '') + res.read() + conn.close() + conn = None + if not location: + raise RuntimeError(f"Redirect {res.status} with no Location") + url = location if location.startswith('http') else \ + f"{urlparse(url).scheme}://{urlparse(url).netloc}{location}" + continue + break + else: + raise RuntimeError("Too many redirects to Ollama endpoint") + + if res.status != 200: + body = res.read(1024).decode('utf-8', errors='replace') + raise RuntimeError(f"Ollama error {res.status}: {body[:200]}") + + think_depth = 0 + pending = '' + + while True: + raw_line = res.readline() + if not raw_line: + break + line = raw_line.decode('utf-8', errors='replace').rstrip('\r\n') + if not line or not line.startswith('data: '): + continue + data_str = line[6:] + if data_str == '[DONE]': + break + try: + obj = json.loads(data_str) + except (json.JSONDecodeError, ValueError): + continue + choices = obj.get('choices', []) + if not choices: + continue + delta = choices[0].get('delta', {}) + token = delta.get('content') or '' + if not token: + continue + + pending += token + # Filter ... blocks, push clean content immediately + while True: + if think_depth == 0: + think_start = pending.find('') + if think_start == -1: + if pending: + vk.rpush(chunks_key, pending) + vk.expire(chunks_key, 120) + pending = '' + break + else: + before = pending[:think_start] + if before: + vk.rpush(chunks_key, before) + vk.expire(chunks_key, 120) + pending = pending[think_start + 7:] + think_depth = 1 + else: + think_end = pending.find('') + if think_end == -1: + break + else: + pending = pending[think_end + 8:] + think_depth = 0 + + if think_depth == 0 and pending: + vk.rpush(chunks_key, pending) + vk.expire(chunks_key, 120) + + vk.rpush(chunks_key, '__DONE__') + vk.expire(chunks_key, 120) + vk.set(status_key, 'done', ex=120) + + except Exception as e: + logger.error(f"AI Answers: stream_to_valkey error for job {job_id}: {e}", exc_info=True) + try: + vk2 = _get_valkey() + vk2.rpush(chunks_key, f"__ERROR__{e}") + vk2.expire(chunks_key, 120) + vk2.set(status_key, 'error', ex=120) + except Exception: + pass + finally: + if conn: + try: + conn.close() + except Exception: + pass + + PLUGIN_NAME = "AI Answers" DEFAULT_TABS = "general,science,it,news" @@ -577,9 +713,9 @@ FRONTEND_JS_TEMPLATE = r""" box.style.display = 'block'; const controller = new AbortController(); - let timeoutId = setTimeout(() => controller.abort(), 60000); + let timeoutId = setTimeout(() => controller.abort(), 90000); const finalQ = __STREAM_Q__; - + const _selMdl = (document.getElementById('sxng-model-select') || {value: ''}).value; const bodyObj = { q: finalQ, lang: lang_init, context: ctx, tk: tk_init, model: _selMdl__STREAM_BODY__ }; const res = await fetch(script_root + '/ai-stream', { @@ -610,34 +746,17 @@ FRONTEND_JS_TEMPLATE = r""" return; } - const fullText = (respJson.text || '').trim(); - - if (!fullText) { + const jobId = respJson.job_id; + if (!jobId) { const cursorErr = data.querySelector('.sxng-cursor'); if (cursorErr) cursorErr.remove(); const errSpan = document.createElement('span'); errSpan.style.color = '#bf616a'; - errSpan.textContent = 'No response received. Check API configuration and server logs.'; + errSpan.textContent = 'No job ID returned. Check server logs.'; data.appendChild(errSpan); return; } - let mainText = fullText; - const thinkMatch = mainText.match(/^([\s\S]*?)<\/think>\s*/); - if (thinkMatch) { - const cursorTh = data.querySelector('.sxng-cursor'); - const details = document.createElement('details'); - details.className = 'sxng-reasoning'; - details.innerHTML = 'Thought Process'; - const thoughtDiv = document.createElement('div'); - thoughtDiv.className = 'sxng-thought-content'; - thoughtDiv.textContent = thinkMatch[1]; - details.appendChild(thoughtDiv); - if (cursorTh) cursorTh.before(details); - else data.appendChild(details); - mainText = mainText.substring(thinkMatch[0].length); - } - let cursor = data.querySelector('.sxng-cursor'); if (!cursor) { cursor = document.createElement('span'); @@ -646,9 +765,10 @@ FRONTEND_JS_TEMPLATE = r""" } let buffer = ''; + let fullText = ''; const flushBuffer = (force = false) => { if (!buffer) return; - + if (force) { const fragment = renderCitations(buffer, urls); if (cursor) cursor.before(fragment); @@ -659,9 +779,9 @@ FRONTEND_JS_TEMPLATE = r""" while (true) { const match = buffer.match(/(\[\d+(?:,\s*\d+)*\])/); - + if (!match) break; - + const preText = buffer.substring(0, match.index); if (preText) { const s = document.createElement('span'); @@ -695,7 +815,7 @@ FRONTEND_JS_TEMPLATE = r""" cursor.before(s); } buffer = buffer.substring(openIdx); - + if (buffer.length > 50) { const s = document.createElement('span'); s.className = 'sxng-chunk'; @@ -706,23 +826,61 @@ FRONTEND_JS_TEMPLATE = r""" } }; - let twPos = 0; - const twBatch = 4; - await new Promise(resolve => { - function twTick() { - if (twPos >= mainText.length) { - flushBuffer(true); - resolve(); - return; - } - const end = Math.min(twPos + twBatch, mainText.length); - buffer += mainText.substring(twPos, end); - twPos = end; - flushBuffer(false); - setTimeout(twTick, 8); + let offset = 0; + const maxPolls = 600; + let polls = 0; + + while (polls < maxPolls) { + polls++; + await new Promise(r => setTimeout(r, 150)); + + let statusRes; + try { + statusRes = await fetch( + `${script_root}/ai-status/${jobId}?tk=${encodeURIComponent(tk_init)}&offset=${offset}`, + { signal: controller.signal } + ); + } catch (fetchErr) { + if (fetchErr.name === 'AbortError') throw fetchErr; + continue; } - twTick(); - }); + + if (statusRes.status === 404) { + const cursorE = data.querySelector('.sxng-cursor'); + if (cursorE) cursorE.remove(); + const expiredSpan = document.createElement('span'); + expiredSpan.style.color = '#bf616a'; + expiredSpan.textContent = 'Response expired. Please search again.'; + data.appendChild(expiredSpan); + return; + } + + if (!statusRes.ok) continue; + + const statusData = await statusRes.json(); + + if (statusData.error) { + const cursorE = data.querySelector('.sxng-cursor'); + if (cursorE) cursorE.remove(); + const errSpan2 = document.createElement('span'); + errSpan2.style.color = '#bf616a'; + errSpan2.textContent = '⚠️ ' + statusData.error; + data.appendChild(errSpan2); + return; + } + + for (const chunk of (statusData.chunks || [])) { + fullText += chunk; + buffer += chunk; + flushBuffer(false); + } + offset += (statusData.chunks || []).length; + + if (statusData.done) { + flushBuffer(true); + break; + } + } if (cursor) cursor.remove(); @@ -738,17 +896,16 @@ FRONTEND_JS_TEMPLATE = r""" } } - renderCitationFooter(mainText, urls, data); + renderCitationFooter(fullText, urls, data); - const collectedResponse = mainText; + const collectedResponse = fullText; __INTERACTIVE_JS_COMPLETE__ if (collectedResponse) { conversation.turns.push({role: 'assistant', content: collectedResponse.trim(), ts: Date.now()}); } - - // Save state if this was an initial generation or a regeneration + if (arguments.length === 0 && typeof updateState === 'function') { updateState(); } @@ -1094,86 +1251,82 @@ class SXNGPlugin(Plugin): {numbered_instructions} """ - def call_ollama(): - conn = None - try: - payload_dict = { - "model": effective_model, - "messages": [ - {"role": "system", "content": SYSTEM}, - {"role": "user", "content": prompt}, - {"role": "assistant", "content": ""}, - ], - "stream": False, - "max_tokens": self.max_tokens, - "temperature": self.temperature, - } - payload = json.dumps(payload_dict) - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } - url = self.endpoint_url - res = None # type: ignore[assignment] - for _ in range(3): - conn, path = _get_streaming_connection(url) - conn.request("POST", path, body=payload.encode('utf-8'), headers=headers) - res = conn.getresponse() - if res.status in (301, 302, 307, 308): - location = res.getheader('Location', '') - res.read() - conn.close() - conn = None - if not location: - return '', f"Redirect {res.status} with no Location header" - url = location if location.startswith('http') else f"{urlparse(url).scheme}://{urlparse(url).netloc}{location}" - logger.info(f"{PLUGIN_NAME}: Following redirect to {url}") - continue - break - if res.status != 200: - body = res.read(1024).decode('utf-8', errors='replace') - logger.error(f"{PLUGIN_NAME}: Ollama {res.status}: {body}") - return '', f"Ollama error {res.status}" - obj = json.loads(res.read().decode('utf-8', errors='replace')) - if 'error' in obj: - err = obj['error'] - msg = err.get('message', str(err)) if isinstance(err, dict) else str(err) - return '', msg - choices = obj.get('choices', []) - if not choices: - return '', "No choices in Ollama response." - message = choices[0].get('message', {}) - content = message.get('content') or '' - reasoning = message.get('reasoning') or message.get('reasoning_content') or '' - content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() - if not content and reasoning: - logger.warning(f"{PLUGIN_NAME}: content empty, extracting from reasoning field") - lines = reasoning.splitlines() - header_re = re.compile(r'^\s*\*?\*?[A-Z][^:]{0,40}:\*?\*?\s*$') - last_header_idx = -1 - for i, line in enumerate(lines): - if header_re.match(line): - last_header_idx = i - if last_header_idx >= 0 and last_header_idx < len(lines) - 1: - content = '\n'.join(lines[last_header_idx + 1:]).strip() - if not content: - paragraphs = [p.strip() for p in reasoning.split('\n\n') if p.strip()] - content = '\n\n'.join(paragraphs[-2:]) if len(paragraphs) >= 2 else paragraphs[-1] if paragraphs else '' - if reasoning and content: - full = (f"\n{reasoning}\n\n\n" if reasoning else "") + content - else: - full = content - full = re.sub(r'.*?', '', full, flags=re.DOTALL).strip() - return full, None - except Exception as e: - logger.error(f"{PLUGIN_NAME}: Ollama call error: {e}", exc_info=True) - return '', f"Connection Error: {e}" - finally: - if conn: - conn.close() + job_id = hashlib.sha256(f"{time.time()}{q}".encode()).hexdigest()[:16] + + payload_dict = { + "model": effective_model, + "messages": [ + {"role": "system", "content": SYSTEM}, + {"role": "user", "content": prompt}, + {"role": "assistant", "content": ""}, + ], + "stream": True, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + } + stream_payload = json.dumps(payload_dict) + stream_headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + } + + try: + vk = _get_valkey() + vk.set(f"ai:job:{job_id}:status", "running", ex=120) + except Exception as e: + logger.error(f"{PLUGIN_NAME}: Valkey unavailable: {e}", exc_info=True) + return jsonify({"error": "Streaming service unavailable (Valkey connection failed)."}), 503 + + t = threading.Thread( + target=stream_to_valkey, + args=(job_id, stream_payload, stream_headers, self.endpoint_url, effective_model), + daemon=True, + ) + t.start() + + return jsonify({"job_id": job_id}) + + @app.route('/ai-status/', methods=['GET']) + def ai_status(job_id): + token = request.args.get('tk', '') + try: + ts, sig = token.rsplit('.', 1) + expected = hashlib.sha256(f"{ts}{self.secret}".encode()).hexdigest() + if sig != expected or (time.time() - float(ts)) > TOKEN_EXPIRY_SEC: + abort(403) + except (ValueError, KeyError, AttributeError): + abort(403) + + offset = max(0, int(request.args.get('offset', 0))) + chunks_key = f"ai:job:{job_id}:chunks" + status_key = f"ai:job:{job_id}:status" + + try: + vk = _get_valkey() + status = vk.get(status_key) + if status is None: + return jsonify({"error": "Job not found or expired"}), 404 + raw_chunks = vk.lrange(chunks_key, offset, -1) + except Exception as e: + logger.error(f"{PLUGIN_NAME}: Valkey error in /ai-status: {e}", exc_info=True) + return jsonify({"error": "Stream service temporarily unavailable"}), 503 + + done = False + error = None + chunks = [] + for chunk in raw_chunks: + if chunk == '__DONE__': + done = True + break + elif chunk.startswith('__ERROR__'): + error = chunk[9:] + done = True + break + else: + chunks.append(chunk) + + return jsonify({"chunks": chunks, "done": done, "error": error}) - text, error = call_ollama() - return jsonify({"text": text, "error": error}) return True def _fetch_page_text(self, url: str, timeout: int = 5) -> str: