diff --git a/ai_answers.py b/ai_answers.py
index 10f7317..7a5357d 100644
--- a/ai_answers.py
+++ b/ai_answers.py
@@ -45,6 +45,17 @@ def _get_streaming_connection(url: str):
 PLUGIN_NAME = "AI Answers"
 DEFAULT_TABS = "general,science,it,news"
 
+PROVIDER_PRESETS = {
+    'openai': {'url': 'https://api.openai.com/v1/chat/completions', 'model': 'gpt-4o-mini'},
+    'openrouter': {'url': 'https://openrouter.ai/api/v1/chat/completions', 'model': 'google/gemma-3-27b-it:free'},
+    'ollama': {'url': 'http://localhost:11434/v1/chat/completions', 'model': 'llama3.2'},
+    'localai': {'url': 'http://localhost:8080/v1/chat/completions', 'model': 'gpt-4'},
+    'lmstudio': {'url': 'http://localhost:1234/v1/chat/completions', 'model': 'local-model'},
+    'gemini': {'url': 'https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent', 'model': 'gemma-3-27b-it'},
+    'azure': {'url': None, 'model': 'azure-deployment'},  # no fixed endpoint: LLM_URL must be set
+    'huggingface': {'url': 'https://api-inference.huggingface.co/models/{model}/v1/chat/completions', 'model': 'meta-llama/Meta-Llama-3-8B-Instruct'}
+}
+
 # UI assets
 INTERACTIVE_CSS = '''
 
@@ -734,24 +745,90 @@ class SXNGPlugin(Plugin):
         self.info = PluginInfo(
             id=self.id,
             name=gettext(f"{PLUGIN_NAME} Plugin"),
-            description=gettext("Local AI search overviews powered by Ollama."),
+            description=gettext("Live AI search answers using LLM providers."),
             preference_section="general",
         )
 
         self._load_config()
 
 
+    def _ollama_unload_model(self) -> None:
+        """Best-effort POST with keep_alive=0 to the Ollama unload URL; failures are only logged."""
+        try:
+            if self.provider != 'ollama':
+                return
+            if not getattr(self, 'ollama_unload_after', False):
+                return
+            unload_url = (getattr(self, 'ollama_unload_url', '') or '').strip()
+            if not unload_url:
+                return
+
+            conn = None
+            try:
+                conn, path = _get_streaming_connection(unload_url)
+                conn.timeout = 2.0
+                payload = json.dumps({
+                    "model": self.model,
+                    "messages": [],
+                    "keep_alive": 0
+                })
+                headers = {"Content-Type": "application/json"}
+                if self.api_key and self.api_key not in ('none', 'ollama'):
+                    headers["Authorization"] = f"Bearer {self.api_key}"
+                conn.request("POST", path, body=payload, headers=headers)
+                res = conn.getresponse()
+                res.read()
+                if res.status >= 400:
+                    logger.warning(f"{PLUGIN_NAME}: Ollama unload failed: {res.status} {res.reason}")
+            finally:
+                if conn:
+                    conn.close()
+        except Exception as e:
+            logger.warning(f"{PLUGIN_NAME}: Ollama unload error: {e}")
+
     def _load_config(self):
         self.interactive = os.getenv('LLM_INTERACTIVE', 'true').lower().strip() in ('true', '1', 'yes', 'on')
         self.question_mark_required = os.getenv('LLM_QUESTION_MARK_REQUIRED', 'false').lower().strip() in ('true', '1', 'yes', 'on')
+        raw_provider = os.getenv('LLM_PROVIDER', '').lower().strip()
+
+        raw_url = os.getenv('LLM_URL', '').strip()
+        if not raw_provider and raw_url:
+            url_lower = raw_url.lower()
+            if 'openai.com' in url_lower:
+                raw_provider = 'openai'
+            elif 'openrouter.ai' in url_lower:
+                raw_provider = 'openrouter'
+            elif ':11434' in url_lower:
+                raw_provider = 'ollama'
+            elif 'generativelanguage.googleapis.com' in url_lower:
+                raw_provider = 'gemini'
+            elif 'openai.azure.com' in url_lower or '.azure.com' in url_lower:
+                raw_provider = 'azure'
+            elif 'huggingface.co' in url_lower:
+                raw_provider = 'huggingface'
+            else:
+                raw_provider = 'openai'
+                logger.info(f"{PLUGIN_NAME}: Using OpenAI-compatible mode for custom URL")
+
+        if not raw_provider:
+            self.provider = ''
+            self.model = ''
+            self.is_gemini = False
+            self.api_key = ''
+            return
+
+        if raw_provider not in PROVIDER_PRESETS:
+            logger.warning(f"{PLUGIN_NAME}: Unknown provider '{raw_provider}', falling back to 'openai'")
+        self.provider = raw_provider if raw_provider in PROVIDER_PRESETS else 'openai'
+        self.is_gemini = (self.provider == 'gemini')
+        preset = PROVIDER_PRESETS[self.provider]
 
-        raw_url = os.getenv('LLM_URL', 'http://ollama:11434/v1/chat/completions').strip()
-        if not raw_url.startswith(('http://', 'https://')):
-            raw_url = f"http://{raw_url}"
-        self.endpoint_url = raw_url
+        self.api_key = os.getenv('LLM_KEY', '')
+        if not self.api_key and self.provider in ('ollama', 'localai', 'lmstudio'):
+            self.api_key = 'none'
+        self.api_key = self.api_key.strip()
 
-        self.api_key = 'ollama'
-        self.model = os.getenv('LLM_MODEL', 'llama3.2').strip()
+        self.model = os.getenv('LLM_MODEL', preset['model']).strip()
 
         try:
             self.max_tokens = max(1, int(os.getenv('LLM_MAX_TOKENS', 200)))
@@ -775,10 +852,33 @@ class SXNGPlugin(Plugin):
         self.context_shallow_count = 15
 
         self.allowed_tabs = set(t.strip() for t in os.getenv('LLM_TABS', DEFAULT_TABS).split(','))
-        
+
+        preset_url = preset['url']
+        if preset_url and '{model}' in preset_url:
+            preset_url = preset_url.format(model=self.model)
+
+        raw_url = os.getenv('LLM_URL', '').strip() or preset_url or ''  # preset url is None for azure
+        if raw_url and not raw_url.startswith(('http://', 'https://')):
+            raw_url = ('http://' if self.provider in ('ollama', 'localai', 'lmstudio') else 'https://') + raw_url
+        if not raw_url:
+            logger.error(f"{PLUGIN_NAME}: No endpoint URL for provider '{self.provider}'; set LLM_URL")
+        self.endpoint_url = raw_url
+
+        self.ollama_unload_after = os.getenv('LLM_OLLAMA_UNLOAD_AFTER', 'false').lower().strip() in ('true', '1', 'yes', 'on')
+        self.ollama_unload_url = ''
+        if self.provider == 'ollama' and self.ollama_unload_after:
+            try:
+                p = urlparse(self.endpoint_url)
+                scheme = p.scheme or 'http'
+                host = p.hostname or 'localhost'
+                port = p.port
+                netloc = f"{host}:{port}" if port else host
+                self.ollama_unload_url = f"{scheme}://{netloc}/api/chat"
+            except Exception:
+                self.ollama_unload_url = "http://localhost:11434/api/chat"
         server_secret = settings.get('server', {}).get('secret_key', '')
         self.secret = hashlib.sha256(f"ai_answers_{server_secret}".encode()).hexdigest()
-        
+
         self.system_prompt = os.getenv('LLM_SYSTEM_PROMPT', '').strip()
 
     def _parse_aux_results(self, raw_results, raw_infoboxes, raw_answers):
@@ -826,6 +926,9 @@ class SXNGPlugin(Plugin):
 
 
     def init(self, app):
+        if not self.provider:
+            return False
+
         @app.route('/ai-auxiliary-search', methods=['POST'])
         def ai_auxiliary_search():
             if not self.api_key:
@@ -915,6 +1018,9 @@ class SXNGPlugin(Plugin):
             except (ValueError, KeyError, AttributeError):
                 abort(403)
 
+            if self.provider != 'ollama':
+                return jsonify({'models': [self.model] if self.model else []})
+
             conn = None
             try:
                 p = urlparse(self.endpoint_url)
@@ -1019,6 +1125,113 @@ class SXNGPlugin(Plugin):
 {numbered_instructions}
 """
 
+            def call_gemini():
+                """Send the prompt to Gemini's non-streaming generateContent endpoint; return (text, error)."""
+                url = self.endpoint_url.replace('streamGenerateContent', 'generateContent')
+                headers = {"Content-Type": "application/json", "x-goog-api-key": self.api_key}  # key in header, not URL
+                conn = None
+                try:
+                    conn, path = _get_streaming_connection(url)
+                    payload = json.dumps({
+                        "contents": [{"parts": [{"text": prompt}]}],
+                        "generationConfig": {"maxOutputTokens": min(self.max_tokens * 4, 8192), "temperature": self.temperature}
+                    })
+                    conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
+                    res = conn.getresponse()
+                    if res.status != 200:
+                        body = res.read(2048).decode('utf-8', errors='replace')[:500]
+                        logger.error(f"{PLUGIN_NAME}: Gemini API {res.status}: {body}")
+                        return '', f"API error {res.status}. Check server logs."
+                    obj = json.loads(res.read().decode('utf-8', errors='replace'))
+                    if obj.get('promptFeedback', {}).get('blockReason'):
+                        return '', f"Gemini blocked prompt: {obj['promptFeedback']['blockReason']}"
+                    candidates = obj.get('candidates', [])
+                    if not candidates:
+                        return '', "No candidates in Gemini response."
+                    first = candidates[0]
+                    if first.get('finishReason') == 'SAFETY':
+                        return '', "Gemini stopped generation due to safety filters."
+                    parts = first.get('content', {}).get('parts', [])
+                    text = ''.join(p.get('text', '') for p in parts if isinstance(p, dict))
+                    return text, None
+                except Exception as e:
+                    logger.error(f"{PLUGIN_NAME}: Gemini call error: {e}", exc_info=True)
+                    return '', f"Connection Error: {e}"
+                finally:
+                    if conn: conn.close()
+
+            def call_openai_compatible():
+                """POST a non-streaming chat completion to the configured endpoint; return (text, error)."""
+                conn = None
+                try:
+                    conn, path = _get_streaming_connection(self.endpoint_url)
+                    payload_dict = {
+                        "model": effective_model,
+                        "messages": [
+                            {"role": "system", "content": SYSTEM},
+                            {"role": "user", "content": prompt},
+                            {"role": "assistant", "content": ""},
+                        ],
+                        "stream": False,
+                        "max_tokens": self.max_tokens,
+                        "temperature": self.temperature
+                    }
+                    payload = json.dumps(payload_dict)
+                    headers = {
+                        "Content-Type": "application/json",
+                        "HTTP-Referer": "https://github.com/searxng/searxng",
+                        "X-Title": "SearXNG"
+                    }
+                    if self.provider == 'azure':
+                        headers['api-key'] = self.api_key
+                    else:
+                        headers['Authorization'] = f"Bearer {self.api_key}"
+                    conn.request("POST", path, body=payload.encode('utf-8'), headers=headers)
+                    res = conn.getresponse()
+                    if res.status != 200:
+                        body = res.read(2048).decode('utf-8', errors='replace')[:500]
+                        logger.error(f"{PLUGIN_NAME}: {self.provider} API {res.status}: {body}")
+                        return '', f"API error {res.status}. Check server logs."
+                    obj = json.loads(res.read().decode('utf-8', errors='replace'))
+                    if "error" in obj:
+                        err = obj["error"]
+                        msg = err.get("message", str(err)) if isinstance(err, dict) else str(err)
+                        return '', f"API Error: {msg}"
+                    choices = obj.get("choices", [])
+                    if not choices:
+                        return '', "No choices in API response."
+                    message = choices[0].get("message", {})
+                    content = re.sub(r'\x3cthink\x3e.*?\x3c/think\x3e', '', message.get("content") or "", flags=re.DOTALL).strip()  # \x3c / \x3e are the angle brackets of the think tag; hex keeps tag-stripping tools from erasing the pattern
+                    reasoning = message.get("reasoning") or message.get("reasoning_content") or ""
+                    if not content and reasoning:
+                        logger.warning(f"{PLUGIN_NAME}: {self.provider} returned empty content; extracting answer from reasoning field")
+                        header_pat = re.compile(r'^\s*\*?\*?[A-Z][^:]{0,40}:\*?\*?\s*$', re.MULTILINE)
+                        matches = list(header_pat.finditer(reasoning))
+                        if matches:
+                            answer = reasoning[matches[-1].end():].strip()
+                        else:
+                            paras = [p.strip() for p in re.split(r'\n{2,}', reasoning) if p.strip()]
+                            answer = paras[-1] if paras else reasoning.strip()
+                        full = answer
+                    else:
+                        full = (f"\x3cthink\x3e\n{reasoning}\n\x3c/think\x3e\n\n" if reasoning else "") + content
+                    full = re.sub(r'\x3cthink\x3e.*?\x3c/think\x3e', '', full, flags=re.DOTALL).strip()
+                    return full, None
+                except Exception as e:
+                    logger.error(f"{PLUGIN_NAME}: {self.provider} call error: {e}", exc_info=True)
+                    return '', f"Connection Error: {e}"
+                finally:
+                    if conn: conn.close()
+
+            call_fn = call_gemini if self.is_gemini else call_openai_compatible
+            text, error = call_fn()
+
+            if self.provider == 'ollama' and getattr(self, 'ollama_unload_after', False):
+                self._ollama_unload_model()
+
+            return jsonify({"text": text, "error": error})
+        return True
+
     def _assemble_context(self, clean_results, infoboxes, answers, offset=0) -> tuple[str, list]:
         """Builds context string from normalized search data. Returns (context_str, urls)."""
         context_parts = []