refactor: improve stream parsing
This commit is contained in:
+23
-19
@@ -51,29 +51,35 @@ class SXNGPlugin(Plugin):
|
|||||||
conn.request("POST", path, body=json.dumps(payload), headers={"Content-Type": "application/json"})
|
conn.request("POST", path, body=json.dumps(payload), headers={"Content-Type": "application/json"})
|
||||||
res = conn.getresponse()
|
res = conn.getresponse()
|
||||||
|
|
||||||
|
if res.status != 200:
|
||||||
|
yield f" [Error: {res.status} {res.reason} - {res.read().decode('utf-8')}]"
|
||||||
|
return
|
||||||
|
|
||||||
|
decoder = json.JSONDecoder()
|
||||||
buffer = ""
|
buffer = ""
|
||||||
|
|
||||||
for chunk in res:
|
for chunk in res:
|
||||||
if not chunk: continue
|
if not chunk: continue
|
||||||
buffer += chunk.decode('utf-8')
|
buffer += chunk.decode('utf-8')
|
||||||
while True:
|
|
||||||
start = buffer.find('{')
|
while buffer:
|
||||||
if start == -1: break
|
buffer = buffer.lstrip()
|
||||||
brace_count, end = 0, -1
|
if not buffer: break
|
||||||
for i in range(start, len(buffer)):
|
|
||||||
if buffer[i] == '{': brace_count += 1
|
|
||||||
elif buffer[i] == '}': brace_count -= 1
|
|
||||||
if brace_count == 0:
|
|
||||||
end = i + 1
|
|
||||||
break
|
|
||||||
if end == -1: break
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(buffer[start:end])
|
obj, idx = decoder.raw_decode(buffer)
|
||||||
candidates = data.get('candidates', [])
|
candidates = obj.get('candidates', [])
|
||||||
if candidates:
|
if candidates:
|
||||||
text = candidates[0]['content']['parts'][0]['text']
|
content = candidates[0].get('content', {})
|
||||||
|
parts = content.get('parts', [])
|
||||||
|
if parts:
|
||||||
|
text = parts[0].get('text', '')
|
||||||
if text: yield text
|
if text: yield text
|
||||||
except: pass
|
|
||||||
buffer = buffer[end:]
|
buffer = buffer[idx:]
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
break
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield f" [Error: {str(e)}]"
|
yield f" [Error: {str(e)}]"
|
||||||
@@ -90,7 +96,6 @@ class SXNGPlugin(Plugin):
|
|||||||
context_list = [f"[{i+1}] {r.get('title')}: {r.get('content')}" for i, r in enumerate(raw_results[:6])]
|
context_list = [f"[{i+1}] {r.get('title')}: {r.get('content')}" for i, r in enumerate(raw_results[:6])]
|
||||||
context_str = "\n".join(context_list)
|
context_str = "\n".join(context_list)
|
||||||
|
|
||||||
# Base64 Encode to ensure HTML safety
|
|
||||||
b64_context = base64.b64encode(context_str.encode('utf-8')).decode('utf-8')
|
b64_context = base64.b64encode(context_str.encode('utf-8')).decode('utf-8')
|
||||||
js_q = json.dumps(search.search_query.query)
|
js_q = json.dumps(search.search_query.query)
|
||||||
|
|
||||||
@@ -109,7 +114,6 @@ class SXNGPlugin(Plugin):
|
|||||||
if (container && shell) {{ container.prepend(shell); shell.style.display = 'block'; }}
|
if (container && shell) {{ container.prepend(shell); shell.style.display = 'block'; }}
|
||||||
|
|
||||||
try {{
|
try {{
|
||||||
// Decode context client-side
|
|
||||||
const ctx = new TextDecoder().decode(Uint8Array.from(atob(b64), c => c.charCodeAt(0)));
|
const ctx = new TextDecoder().decode(Uint8Array.from(atob(b64), c => c.charCodeAt(0)));
|
||||||
|
|
||||||
const res = await fetch('/gemini-stream', {{
|
const res = await fetch('/gemini-stream', {{
|
||||||
@@ -131,5 +135,5 @@ class SXNGPlugin(Plugin):
|
|||||||
}})();
|
}})();
|
||||||
</script>
|
</script>
|
||||||
'''
|
'''
|
||||||
results.add(results.types.Answer(answer=Markup(html_payload)))
|
search.result_container.answers.add(results.types.Answer(answer=Markup(html_payload)))
|
||||||
return results
|
return results
|
||||||
|
|||||||
+40
-45
@@ -5,87 +5,62 @@ from types import ModuleType
|
|||||||
from flask import Flask, request
|
from flask import Flask, request
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# Configure logging to show INFO messages
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
# Load environment variables from .env file
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# --- 1. Mock SearXNG dependencies BEFORE importing the plugin ---
|
|
||||||
# We create fake modules so gemini_flash.py can import 'searx.plugins' etc. without error.
|
|
||||||
|
|
||||||
searx = ModuleType("searx")
|
searx = ModuleType("searx")
|
||||||
searx_plugins = ModuleType("searx.plugins")
|
searx_plugins = ModuleType("searx.plugins")
|
||||||
searx_results = ModuleType("searx.result_types")
|
searx_results = ModuleType("searx.result_types")
|
||||||
|
|
||||||
class MockPlugin:
|
class MockPlugin:
|
||||||
"""Mocks searx.plugins.Plugin"""
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
pass
|
self.active = getattr(cfg, 'active', True)
|
||||||
|
|
||||||
class MockPluginInfo:
|
class MockPluginInfo:
|
||||||
"""Mocks searx.plugins.PluginInfo"""
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.meta = kwargs
|
self.meta = kwargs
|
||||||
|
|
||||||
class MockEngineResults:
|
class MockEngineResults:
|
||||||
"""Mocks searx.result_types.EngineResults"""
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# We need a 'types' object that has an 'Answer' class
|
|
||||||
self.types = ModuleType("types")
|
self.types = ModuleType("types")
|
||||||
# Handle both positional and keyword arguments for Answer
|
|
||||||
self.types.Answer = lambda *args, **kwargs: kwargs.get('answer', args[0] if args else "")
|
self.types.Answer = lambda *args, **kwargs: kwargs.get('answer', args[0] if args else "")
|
||||||
self._results = []
|
self._results = []
|
||||||
|
|
||||||
def add(self, res):
|
def add(self, res):
|
||||||
self._results.append(res)
|
self._results.append(res)
|
||||||
|
|
||||||
# Assign mocks to the fake modules
|
|
||||||
searx_plugins.Plugin = MockPlugin
|
searx_plugins.Plugin = MockPlugin
|
||||||
searx_plugins.PluginInfo = MockPluginInfo
|
searx_plugins.PluginInfo = MockPluginInfo
|
||||||
searx_results.EngineResults = MockEngineResults
|
searx_results.EngineResults = MockEngineResults
|
||||||
|
|
||||||
# Inject them into sys.modules
|
|
||||||
sys.modules["searx"] = searx
|
sys.modules["searx"] = searx
|
||||||
sys.modules["searx.plugins"] = searx_plugins
|
sys.modules["searx.plugins"] = searx_plugins
|
||||||
sys.modules["searx.result_types"] = searx_results
|
sys.modules["searx.result_types"] = searx_results
|
||||||
|
|
||||||
# --- 2. Import the actual plugin code ---
|
|
||||||
# Now that dependencies are mocked, we can import the file.
|
|
||||||
from gemini_flash import SXNGPlugin
|
from gemini_flash import SXNGPlugin
|
||||||
from flask_babel import Babel
|
from flask_babel import Babel
|
||||||
|
|
||||||
# --- 3. Setup the Test Harness ---
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
babel = Babel(app) # Initialize Babel to handle gettext calls if needed
|
babel = Babel(app)
|
||||||
|
|
||||||
# Mock the configuration object expected by the plugin
|
|
||||||
class MockConfig:
|
class MockConfig:
|
||||||
active = True
|
active = True
|
||||||
|
|
||||||
# Initialize the plugin
|
|
||||||
print("Initializing Plugin...")
|
|
||||||
if not os.getenv("GEMINI_API_KEY"):
|
|
||||||
print("WARNING: GEMINI_API_KEY environment variable is NOT set. The stream will likely fail.")
|
|
||||||
|
|
||||||
plugin = SXNGPlugin(MockConfig())
|
plugin = SXNGPlugin(MockConfig())
|
||||||
plugin.init(app) # This registers the /gemini-stream route
|
plugin.init(app)
|
||||||
|
|
||||||
@app.route("/")
|
@app.route("/")
|
||||||
def index():
|
def index():
|
||||||
print(">>> INDEX ROUTE HIT <<<")
|
|
||||||
"""
|
|
||||||
Simulates a search result page.
|
|
||||||
It calls post_search() to get the script, then embeds it in a basic HTML page.
|
|
||||||
"""
|
|
||||||
# 1. Create a Mock Search Object
|
|
||||||
class MockSearchQuery:
|
class MockSearchQuery:
|
||||||
pageno = 1
|
pageno = 1
|
||||||
query = request.args.get("q", "why is the sky blue") # Allow query via url param
|
query = request.args.get("q", "why is the sky blue")
|
||||||
|
|
||||||
class MockSearch:
|
class MockSearch:
|
||||||
search_query = MockSearchQuery()
|
search_query = MockSearchQuery()
|
||||||
class MockResultContainer:
|
class MockResultContainer:
|
||||||
|
def __init__(self):
|
||||||
|
self.answers = set()
|
||||||
|
|
||||||
def get_ordered_results(self):
|
def get_ordered_results(self):
|
||||||
return [
|
return [
|
||||||
{"title": "Fact About Sky", "content": "The sky is blue because of Rayleigh scattering."},
|
{"title": "Fact About Sky", "content": "The sky is blue because of Rayleigh scattering."},
|
||||||
@@ -94,13 +69,12 @@ def index():
|
|||||||
]
|
]
|
||||||
result_container = MockResultContainer()
|
result_container = MockResultContainer()
|
||||||
|
|
||||||
# 2. Run the Plugin's post_search hook
|
search = MockSearch()
|
||||||
results = plugin.post_search(None, MockSearch())
|
plugin.post_search(None, search)
|
||||||
|
|
||||||
# 3. Extract the injected HTML (if any)
|
|
||||||
injection_html = ""
|
injection_html = ""
|
||||||
if results._results:
|
if search.result_container.answers:
|
||||||
injection_html = results._results[0]
|
injection_html = list(search.result_container.answers)[0]
|
||||||
|
|
||||||
return f"""
|
return f"""
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
@@ -110,7 +84,6 @@ def index():
|
|||||||
<title>Plugin Test</title>
|
<title>Plugin Test</title>
|
||||||
<style>
|
<style>
|
||||||
body {{ font-family: sans-serif; padding: 2rem; max-width: 800px; margin: 0 auto; }}
|
body {{ font-family: sans-serif; padding: 2rem; max-width: 800px; margin: 0 auto; }}
|
||||||
/* Mimic SearXNG variables for the injection styles to work */
|
|
||||||
:root {{
|
:root {{
|
||||||
--color-result-border: #ccc;
|
--color-result-border: #ccc;
|
||||||
--color-result-description: #333;
|
--color-result-description: #333;
|
||||||
@@ -122,16 +95,38 @@ def index():
|
|||||||
<p>Testing query: <strong>{MockSearch.search_query.query}</strong></p>
|
<p>Testing query: <strong>{MockSearch.search_query.query}</strong></p>
|
||||||
<p><a href="/?q=tell me a joke">Try: "tell me a joke"</a> | <a href="/?q=explain quantum physics">Try: "explain quantum physics"</a></p>
|
<p><a href="/?q=tell me a joke">Try: "tell me a joke"</a> | <a href="/?q=explain quantum physics">Try: "explain quantum physics"</a></p>
|
||||||
<hr>
|
<hr>
|
||||||
|
|
||||||
<!-- The Plugin Injection -->
|
|
||||||
{injection_html}
|
{injection_html}
|
||||||
|
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
class PluginTestCase(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.app = app.test_client()
|
||||||
|
self.app.testing = True
|
||||||
|
|
||||||
|
def test_html_injection(self):
|
||||||
|
response = self.app.get('/')
|
||||||
|
content = response.data.decode('utf-8')
|
||||||
|
self.assertIn('<div id="ai-shell"', content)
|
||||||
|
self.assertIn('const q = "why is the sky blue";', content)
|
||||||
|
self.assertIn('/gemini-stream', content)
|
||||||
|
|
||||||
|
def test_stream_endpoint(self):
|
||||||
|
if not os.getenv("GEMINI_API_KEY"):
|
||||||
|
self.skipTest("GEMINI_API_KEY not set")
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"q": "why is the sky blue",
|
||||||
|
"context": "The sky is blue because of Rayleigh scattering."
|
||||||
|
}
|
||||||
|
|
||||||
|
response = self.app.post('/gemini-stream', json=payload)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
data = response.data.decode('utf-8')
|
||||||
|
self.assertTrue(len(data) > 0)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("\n--- TEST SERVER RUNNING ---")
|
unittest.main()
|
||||||
print("1. Ensure GEMINI_API_KEY is set in your terminal.")
|
|
||||||
print("2. Open http://localhost:5000 in your browser.")
|
|
||||||
app.run(host='0.0.0.0', port=5000, debug=False)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user