diff --git a/gemini_flash.py b/gemini_flash.py
index 00080f5..0045360 100644
--- a/gemini_flash.py
+++ b/gemini_flash.py
@@ -1,5 +1,5 @@
-import json, http.client, ssl, os, logging, base64
-from flask import Response, request
+import json, http.client, ssl, os, logging, base64, secrets, time
+from flask import Response, request, abort
from searx.plugins import Plugin, PluginInfo
from searx.result_types import EngineResults
from flask_babel import gettext
@@ -24,11 +24,22 @@ class SXNGPlugin(Plugin):
self.max_tokens = int(os.getenv('GEMINI_MAX_TOKENS', 500))
self.temperature = float(os.getenv('GEMINI_TEMPERATURE', 0.2))
self.base_url = os.getenv('OPENROUTER_BASE_URL', 'openrouter.ai')
+ self.valid_tokens = {}
def init(self, app):
@app.route('/gemini-stream', methods=['POST'])
def g_stream():
data = request.json or {}
+ token = data.get('tk', '')
+
+ # Maintenance: Token validation & cleanup
+ now = time.time()
+ self.valid_tokens = {k: v for k, v in self.valid_tokens.items() if v > now}
+
+ if token not in self.valid_tokens:
+ abort(403)
+ del self.valid_tokens[token]
+
context_text = data.get('context', '')
q = data.get('q', '')
@@ -61,7 +72,6 @@ class SXNGPlugin(Plugin):
chunk = res.read(128)
if not chunk: break
buffer += chunk.decode('utf-8')
-
while buffer:
buffer = buffer.lstrip()
if not buffer: break
@@ -99,6 +109,7 @@ class SXNGPlugin(Plugin):
res = conn.getresponse()
if res.status != 200: return
+ decoder = json.JSONDecoder()
buffer = ""
while True:
chunk = res.read(128)
@@ -106,13 +117,12 @@ class SXNGPlugin(Plugin):
buffer += chunk.decode('utf-8')
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
- line = line.strip()
if line.startswith("data: "):
data_str = line[6:].strip()
if data_str == "[DONE]": return
try:
- data_json = json.loads(data_str)
- content = data_json.get("choices", [{}])[0].get("delta", {}).get("content", "")
+ obj, _ = decoder.raw_decode(data_str)
+ content = obj.get("choices", [{}])[0].get("delta", {}).get("content", "")
if content: yield content
except: pass
conn.close()
@@ -131,61 +141,73 @@ class SXNGPlugin(Plugin):
context_list = [f"[{i+1}] {r.get('title')}: {r.get('content')}" for i, r in enumerate(raw_results[:6])]
context_str = "\n".join(context_list)
+ # Handshake token
+ tk = secrets.token_hex(16)
+ self.valid_tokens[tk] = time.time() + 60
+
b64_context = base64.b64encode(context_str.encode('utf-8')).decode('utf-8')
js_q = json.dumps(search.search_query.query)
html_payload = f'''
-
-
+
+
+ Thinking...
+
'''
diff --git a/test_standalone.py b/test_standalone.py
index 3b6bd7a..3725acf 100644
--- a/test_standalone.py
+++ b/test_standalone.py
@@ -114,6 +114,11 @@ class PluginTestCase(unittest.TestCase):
self.assertIn('/gemini-stream', content)
def test_stream_endpoint(self):
+ # Trigger index to generate a token in the plugin instance
+ self.app.get('/')
+ # Extract the last generated token
+ token = list(plugin.valid_tokens.keys())[-1]
+
# Check for the appropriate key based on provider
key = os.getenv("OPENROUTER_API_KEY") if plugin.provider == 'openrouter' else os.getenv("GEMINI_API_KEY")
if not key:
@@ -121,7 +126,8 @@ class PluginTestCase(unittest.TestCase):
payload = {
"q": "why is the sky blue",
- "context": "The sky is blue because of Rayleigh scattering."
+ "context": "The sky is blue because of Rayleigh scattering.",
+ "tk": token
}
response = self.app.post('/gemini-stream', json=payload)