diff --git a/gemini_flash.py b/gemini_flash.py index 00080f5..0045360 100644 --- a/gemini_flash.py +++ b/gemini_flash.py @@ -1,5 +1,5 @@ -import json, http.client, ssl, os, logging, base64 -from flask import Response, request +import json, http.client, ssl, os, logging, base64, secrets, time +from flask import Response, request, abort from searx.plugins import Plugin, PluginInfo from searx.result_types import EngineResults from flask_babel import gettext @@ -24,11 +24,22 @@ class SXNGPlugin(Plugin): self.max_tokens = int(os.getenv('GEMINI_MAX_TOKENS', 500)) self.temperature = float(os.getenv('GEMINI_TEMPERATURE', 0.2)) self.base_url = os.getenv('OPENROUTER_BASE_URL', 'openrouter.ai') + self.valid_tokens = {} def init(self, app): @app.route('/gemini-stream', methods=['POST']) def g_stream(): data = request.json or {} + token = data.get('tk', '') + + # Maintenance: Token validation & cleanup + now = time.time() + self.valid_tokens = {k: v for k, v in self.valid_tokens.items() if v > now} + + if token not in self.valid_tokens: + abort(403) + del self.valid_tokens[token] + context_text = data.get('context', '') q = data.get('q', '') @@ -61,7 +72,6 @@ class SXNGPlugin(Plugin): chunk = res.read(128) if not chunk: break buffer += chunk.decode('utf-8') - while buffer: buffer = buffer.lstrip() if not buffer: break @@ -99,6 +109,7 @@ class SXNGPlugin(Plugin): res = conn.getresponse() if res.status != 200: return + decoder = json.JSONDecoder() buffer = "" while True: chunk = res.read(128) @@ -106,13 +117,12 @@ class SXNGPlugin(Plugin): buffer += chunk.decode('utf-8') while "\n" in buffer: line, buffer = buffer.split("\n", 1) - line = line.strip() if line.startswith("data: "): data_str = line[6:].strip() if data_str == "[DONE]": return try: - data_json = json.loads(data_str) - content = data_json.get("choices", [{}])[0].get("delta", {}).get("content", "") + obj, _ = decoder.raw_decode(data_str) + content = obj.get("choices", [{}])[0].get("delta", {}).get("content", "") if content: yield content except: pass conn.close() @@ -131,61 +141,73 @@ class SXNGPlugin(Plugin): context_list = [f"[{i+1}] {r.get('title')}: {r.get('content')}" for i, r in enumerate(raw_results[:6])] context_str = "\n".join(context_list) + # Handshake token + tk = secrets.token_hex(16) + self.valid_tokens[tk] = time.time() + 60 + b64_context = base64.b64encode(context_str.encode('utf-8')).decode('utf-8') js_q = json.dumps(search.search_query.query) html_payload = f''' -
-

+ ''' diff --git a/test_standalone.py b/test_standalone.py index 3b6bd7a..3725acf 100644 --- a/test_standalone.py +++ b/test_standalone.py @@ -114,6 +114,11 @@ class PluginTestCase(unittest.TestCase): self.assertIn('/gemini-stream', content) def test_stream_endpoint(self): + # Trigger index to generate a token in the plugin instance + self.app.get('/') + # Extract the last generated token + token = list(plugin.valid_tokens.keys())[-1] + # Check for the appropriate key based on provider key = os.getenv("OPENROUTER_API_KEY") if plugin.provider == 'openrouter' else os.getenv("GEMINI_API_KEY") if not key: @@ -121,7 +126,8 @@ class PluginTestCase(unittest.TestCase): payload = { "q": "why is the sky blue", - "context": "The sky is blue because of Rayleigh scattering." + "context": "The sky is blue because of Rayleigh scattering.", + "tk": token } response = self.app.post('/gemini-stream', json=payload)