#!/usr/bin/env python3 """External SREBOT bridge service. This PM2-managed process does two things: 1. Proxies read-only SREBOT queries on the external port. 2. Broadcasts SREBOT replay envelopes over websocket to any connected client. """ from __future__ import annotations import asyncio import json import logging import os import sys import time from dataclasses import dataclass from pathlib import Path from typing import Any import aiohttp import zstandard as zstd from aiohttp import web from dotenv import load_dotenv sys.path.insert(0, str(Path(__file__).resolve().parents[1])) load_dotenv() from BOT.receiver_bridge import EXTERNAL_OUTBOX_PATH # noqa: E402 logging.basicConfig( level=logging.INFO, format="[%(asctime)s] [%(levelname)s] [srebot-external] %(message)s", ) logger = logging.getLogger("srebot-external") def _env(name: str, default: str = "") -> str: return os.getenv(name, default).strip() @dataclass(slots=True) class ExternalSettings: host: str = _env("SREBOT_EXTERNAL_HOST", "0.0.0.0") port: int = int(_env("SREBOT_EXTERNAL_PORT", "18081")) bearer_token: str = _env("SREBOT_EXTERNAL_BEARER_TOKEN", _env("SREBOT_API_BEARER_TOKEN")) upstream_url: str = _env("SREBOT_EXTERNAL_UPSTREAM_URL", "http://127.0.0.1:6000").rstrip("/") upstream_bearer_token: str = _env("SREBOT_EXTERNAL_UPSTREAM_BEARER_TOKEN", _env("SREBOT_API_BEARER_TOKEN")) outbox_path: Path = Path(_env("SREBOT_EXTERNAL_OUTBOX_PATH", str(EXTERNAL_OUTBOX_PATH))) offset_path: Path = Path(_env("SREBOT_EXTERNAL_OFFSET_PATH", str(Path(str(EXTERNAL_OUTBOX_PATH)).with_suffix(".offset")))) poll_interval_seconds: float = float(_env("SREBOT_EXTERNAL_POLL_INTERVAL", "0.5")) reconnect_delay_seconds: float = float(_env("SREBOT_EXTERNAL_RECONNECT_DELAY", "1.0")) SETTINGS = ExternalSettings() SETTINGS.outbox_path.parent.mkdir(parents=True, exist_ok=True) SETTINGS.offset_path.parent.mkdir(parents=True, exist_ok=True) HOP_BY_HOP_HEADERS = { "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", } CONNECTED_WEBSOCKETS: set[web.WebSocketResponse] = set() CONNECTED_LOCK = asyncio.Lock() _compressor = zstd.ZstdCompressor(level=3) def _auth_ok(request: web.Request) -> bool: if not SETTINGS.bearer_token: return True return request.headers.get("Authorization", "") == f"Bearer {SETTINGS.bearer_token}" @web.middleware async def auth_middleware(request: web.Request, handler): if request.path in {"/health", "/"} or request.path.startswith("/ws/"): return await handler(request) if not _auth_ok(request): logger.warning("Unauthorized request", extra={"path": request.rel_url.path_qs}) return web.json_response({"error": "Unauthorized"}, status=401) return await handler(request) def _upstream_headers() -> dict[str, str]: headers = {"Accept": "application/json"} if SETTINGS.upstream_bearer_token: headers["Authorization"] = f"Bearer {SETTINGS.upstream_bearer_token}" return headers def _read_offset() -> int: try: return int(SETTINGS.offset_path.read_text(encoding="utf-8").strip()) except Exception: return 0 def _write_offset(offset: int) -> None: SETTINGS.offset_path.write_text(str(offset), encoding="utf-8") async def health(_: web.Request) -> web.Response: return web.json_response( { "status": "ok", "service": "srebot-external", "http": SETTINGS.upstream_url, "websocket": "/ws/srebot", } ) async def proxy_api(request: web.Request) -> web.StreamResponse: target = f"{SETTINGS.upstream_url}{request.rel_url.path_qs}" request_start = time.monotonic() logger.info( "AXBot query in", extra={ "method": request.method, "path": request.rel_url.path_qs, }, ) body = await request.read() if request.can_read_body else b"" async with request.app["http_session"].request( request.method, target, headers=_upstream_headers(), data=body if body else None, ) as upstream: payload = await upstream.read() duration_ms = round((time.monotonic() - request_start) * 1000, 1) logger.info( "AXBot query out", extra={ "method": request.method, "path": request.rel_url.path_qs, "status": upstream.status, "bytes": len(payload), "duration_ms": duration_ms, }, ) headers = { key: value for key, value in upstream.headers.items() if key.lower() not in HOP_BY_HOP_HEADERS and key.lower() not in {"content-length", "content-encoding"} } return web.Response(body=payload, status=upstream.status, headers=headers) async def root(_: web.Request) -> web.Response: return web.json_response( { "service": "srebot-external", "message": "Use /api/* for queries and /ws/srebot for replay events.", } ) async def websocket_handler(request: web.Request) -> web.WebSocketResponse: if not _auth_ok(request): logger.warning("Unauthorized websocket", extra={"path": request.rel_url.path_qs}) ws = web.WebSocketResponse() await ws.prepare(request) await ws.close(code=1008, message=b"Unauthorized") return ws ws = web.WebSocketResponse(heartbeat=20) await ws.prepare(request) async with CONNECTED_LOCK: CONNECTED_WEBSOCKETS.add(ws) logger.info("Websocket connected", extra={"clients": len(CONNECTED_WEBSOCKETS)}) try: async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: logger.info("Websocket recv", extra={"bytes": len(msg.data)}) elif msg.type == aiohttp.WSMsgType.ERROR: logger.warning("Websocket error", extra={"error": str(ws.exception())}) finally: async with CONNECTED_LOCK: CONNECTED_WEBSOCKETS.discard(ws) logger.info("Websocket disconnected", extra={"clients": len(CONNECTED_WEBSOCKETS)}) return ws async def _broadcast(envelope: dict[str, Any]) -> None: raw = json.dumps(envelope, ensure_ascii=False, separators=(",", ":")).encode("utf-8") payload = _compressor.compress(raw) async with CONNECTED_LOCK: targets = list(CONNECTED_WEBSOCKETS) if not targets: logger.info( "No websocket clients connected", extra={"event_type": envelope.get("type")}, ) return dead: list[web.WebSocketResponse] = [] for ws in targets: try: await ws.send_bytes(payload) except Exception as exc: logger.warning( "Failed to send websocket envelope", extra={"event_type": envelope.get("type"), "error": str(exc)}, ) dead.append(ws) if dead: async with CONNECTED_LOCK: for ws in dead: CONNECTED_WEBSOCKETS.discard(ws) logger.info( "Websocket broadcast", extra={ "event_type": envelope.get("type"), "clients": len(targets) - len(dead), "raw_bytes": len(raw), "compressed_bytes": len(payload), "payload_keys": list((envelope.get("payload") or {}).keys())[:8], }, ) # Truncate the outbox once we've consumed past this many bytes. The file is # append-only and previously grew unbounded — we observed it at 1.9 GB on disk # with all data already relayed and offset matching size. Truncating when # fully caught up keeps disk usage flat without cooperating with the writer. # Race: a writer in the BOT process may append between the size check and # the truncate call. Those envelopes would be lost, but envelopes here are # best-effort match-replay events; rare loss during a 100 MB-scale rotation # is acceptable. _OUTBOX_TRUNCATE_THRESHOLD_BYTES = int(_env("SREBOT_EXTERNAL_TRUNCATE_BYTES", str(100 * 1024 * 1024))) def _maybe_truncate_outbox(position: int) -> int: try: current_size = SETTINGS.outbox_path.stat().st_size if ( position >= _OUTBOX_TRUNCATE_THRESHOLD_BYTES and position == current_size ): with SETTINGS.outbox_path.open("r+b") as handle: handle.truncate(0) _write_offset(0) logger.info( "Outbox caught up; truncated", extra={"reclaimed_bytes": position}, ) return 0 except FileNotFoundError: pass except Exception as exc: logger.warning("Outbox truncate failed", extra={"error": str(exc)}) return position async def relay_outbox_loop(app: web.Application) -> None: reconnect_delay = SETTINGS.reconnect_delay_seconds position = _read_offset() while True: try: if not SETTINGS.outbox_path.exists(): await asyncio.sleep(1.0) continue current_size = SETTINGS.outbox_path.stat().st_size if position > current_size: logger.info( "Outbox truncated; resetting offset", extra={"old_offset": position, "current_size": current_size}, ) position = 0 _write_offset(position) with SETTINGS.outbox_path.open("r", encoding="utf-8") as handle: handle.seek(position) line = handle.readline() if not line: position = _maybe_truncate_outbox(position) await asyncio.sleep(SETTINGS.poll_interval_seconds) continue try: envelope = json.loads(line) except json.JSONDecodeError: position = handle.tell() _write_offset(position) logger.warning("Skipping malformed outbox line", extra={"offset": position}) continue position = handle.tell() _write_offset(position) await _broadcast(envelope) except asyncio.CancelledError: raise except Exception as exc: logger.warning( "Bridge loop error", extra={"error": str(exc), "retry_in_seconds": reconnect_delay}, ) await asyncio.sleep(reconnect_delay) reconnect_delay = min(reconnect_delay * 2, 30.0) async def create_http_session(app: web.Application): app["http_session"] = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) try: yield finally: await app["http_session"].close() async def start_relay_task(app: web.Application): task = asyncio.create_task(relay_outbox_loop(app)) app["relay_task"] = task try: yield finally: task.cancel() try: await task except asyncio.CancelledError: pass def create_app() -> web.Application: app = web.Application(middlewares=[auth_middleware]) app.router.add_get("/", root) app.router.add_get("/health", health) app.router.add_get("/ws/srebot", websocket_handler) app.router.add_route("*", "/api/{tail:.*}", proxy_api) app.cleanup_ctx.append(create_http_session) app.cleanup_ctx.append(start_relay_task) return app def main() -> None: web.run_app(create_app(), host=SETTINGS.host, port=SETTINGS.port) if __name__ == "__main__": main()