"""Relay gateway: authenticated HTTP proxy and WebSocket fan-out per channel.

* ``/api/<channel>/<tail>`` is forwarded to the channel's configured upstream,
  authenticated with the gateway's own upstream bearer token.
* Newline-delimited JSON envelopes appended to each channel's outbox file are
  zstd-compressed and broadcast to that channel's ``/ws/<channel>`` clients.

Every path except ``/`` and ``/health`` requires an ``Authorization: Bearer``
key resolvable by the configured ``KeyStore``.
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import time
from pathlib import Path
from typing import Any

import aiohttp
import zstandard as zstd
from aiohttp import web

from relay_gateway.keys import KeyStore, Grant, can_access
from relay_gateway.channels import ChannelConfig, build_channels

logger = logging.getLogger("relay-gateway")

# Connection-scoped (hop-by-hop) headers that must not be copied from the
# upstream response to the client.
HOP_BY_HOP_HEADERS = {
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
}

# Upstream response headers dropped before relaying: hop-by-hop ones, plus
# length/encoding, because aiohttp's client decompresses the body by default
# and the body may be rewritten (see _rewrite_info_payload).
_STRIPPED_RESPONSE_HEADERS = HOP_BY_HOP_HEADERS | {"content-length", "content-encoding"}

# Paths reachable without an API key.
_PUBLIC_PATHS = frozenset({"/health", "/"})

# Shared zstd compressor (level 3) used for every WebSocket broadcast frame.
_compressor = zstd.ZstdCompressor(level=3)

# Once the relay has consumed at least this many bytes of an outbox and has
# caught up with its end, the outbox is truncated back to zero (default 100 MiB).
_TRUNCATE_BYTES = int(os.getenv("SREBOT_EXTERNAL_TRUNCATE_BYTES", str(100 * 1024 * 1024)))


def _grant(request: web.Request) -> Grant | None:
    """Resolve the request's ``Authorization: Bearer <token>`` header to a grant.

    A missing or non-Bearer header is resolved as the empty token; the result
    is whatever ``KeyStore.resolve`` returns (``None`` means unauthorized).
    """
    store: KeyStore = request.app["key_store"]
    auth = request.headers.get("Authorization", "")
    token = auth[len("Bearer "):] if auth.startswith("Bearer ") else ""
    return store.resolve(token)


@web.middleware
async def auth_middleware(request: web.Request, handler):
    """Reject requests without a valid key (401) and stash the grant on the request.

    ``/`` and ``/health`` are exempt. Handlers read the grant from ``request["grant"]``.
    """
    if request.path in _PUBLIC_PATHS:
        return await handler(request)
    grant = _grant(request)
    if grant is None:
        return web.json_response({"error": "Unauthorized"}, status=401)
    request["grant"] = grant
    return await handler(request)


async def health(_: web.Request) -> web.Response:
    """Liveness probe."""
    return web.json_response({"status": "ok", "service": "relay-gateway"})


async def root(_: web.Request) -> web.Response:
    """Unauthenticated landing endpoint pointing clients at the real routes."""
    return web.json_response(
        {"service": "relay-gateway", "message": "Use /api/<channel>/* and /ws/<channel>."}
    )


async def whoami(request: web.Request) -> web.Response:
    """Describe the grant attached to the caller's API key."""
    grant: Grant = request["grant"]
    return web.json_response(
        {"name": grant.name, "level": grant.level, "channels": list(grant.channels)}
    )


def _prefix_endpoint(entry: str, channel: str) -> str:
    """Insert the channel into an ``"METHOD /api/..."`` endpoint description.

    ``"GET /api/status"`` becomes ``"GET /api/<channel>/status"``. Entries that
    have no space, whose path is outside ``/api/``, or that already carry the
    channel prefix are returned unchanged.
    """
    method, sep, path = entry.partition(" ")
    if sep and path.startswith("/api/") and not path.startswith(f"/api/{channel}/"):
        return f"{method} /api/{channel}{path[len('/api'):]}"
    return entry


def _rewrite_info_payload(payload: bytes, channel: str) -> bytes:
    """Rewrite an upstream ``info`` response so listed endpoints include the channel prefix.

    Handles an ``"endpoints"`` mapping (keys rewritten) and an
    ``"availableEndpoints"`` list of strings. The payload is returned untouched
    when it is not a JSON object or neither field has the expected shape.
    """
    try:
        body_obj = json.loads(payload)
    except ValueError:  # also covers UnicodeDecodeError
        return payload
    if not isinstance(body_obj, dict):
        return payload

    changed = False
    endpoints = body_obj.get("endpoints")
    if isinstance(endpoints, dict):
        body_obj["endpoints"] = {
            _prefix_endpoint(method_path, channel): desc for method_path, desc in endpoints.items()
        }
        changed = True
    available = body_obj.get("availableEndpoints")
    if isinstance(available, list) and all(isinstance(ep, str) for ep in available):
        body_obj["availableEndpoints"] = [_prefix_endpoint(ep, channel) for ep in available]
        changed = True

    return json.dumps(body_obj, indent=2).encode() if changed else payload


def _upstream_headers(request: web.Request, *, has_body: bool) -> dict[str, str]:
    """Build the headers sent upstream.

    The headers are the JSON Accept header, the caller's Content-Type (only
    when a body is forwarded) and the gateway's own upstream bearer token. The
    caller's key is never forwarded.
    """
    headers = {"Accept": "application/json"}
    content_type = request.headers.get("Content-Type")
    if has_body and content_type:
        # Without this, aiohttp labels forwarded raw bytes application/octet-stream.
        headers["Content-Type"] = content_type
    upstream_token = os.getenv(
        "SREBOT_EXTERNAL_UPSTREAM_BEARER_TOKEN", os.getenv("SREBOT_API_BEARER_TOKEN", "")
    ).strip()
    if upstream_token:
        headers["Authorization"] = f"Bearer {upstream_token}"
    return headers


async def proxy_api(request: web.Request) -> web.StreamResponse:
    """Forward ``/api/<channel>/<tail>`` to the channel's upstream service.

    Returns:
        404 for an unknown channel (with a migration hint for the legacy ``sqb``),
        403 when the key may not access the channel, 501 when the channel has
        no upstream, 400 when the tail contains a ``..`` segment, 502/504 when
        the upstream is unreachable/times out; otherwise the upstream status,
        filtered headers and body (``info`` responses get channel-prefixed
        endpoint paths).
    """
    channel = request.match_info["channel"]
    channels: dict[str, ChannelConfig] = request.app["channels"]
    grant: Grant = request["grant"]

    if channel not in channels:
        if channel == "sqb":
            return web.json_response(
                {
                    "error": f"unknown channel '{channel}'",
                    "detail": "The 'sqb' channel was renamed to 'sre'. Use /api/sre/ instead, and register new API keys with the 'sre' level (legacy 'sqb'-level keys still work as an alias).",
                },
                status=404,
            )
        return web.json_response({"error": f"unknown channel {channel}"}, status=404)
    if not can_access(grant, channel):
        return web.json_response({"error": "Forbidden for this key"}, status=403)

    cfg = channels[channel]
    if cfg.upstream_url is None:
        return web.json_response({"error": f"/api/{channel}/* not implemented yet"}, status=501)

    tail = request.match_info["tail"]
    if ".." in tail.split("/"):
        # Dot-segments may be collapsed by URL normalization, letting the caller
        # escape cfg.upstream_path while carrying the gateway's upstream token.
        return web.json_response({"error": "invalid path"}, status=400)

    target = f"{cfg.upstream_url}{cfg.upstream_path}{tail}"
    # raw_query_string keeps percent-escapes (%26, %2B, ...) intact; the decoded
    # query_string would change how the upstream splits/decodes parameters.
    raw_query = request.rel_url.raw_query_string
    if raw_query:
        target = f"{target}?{raw_query}"

    body = await request.read() if request.can_read_body else b""
    headers = _upstream_headers(request, has_body=bool(body))
    session: aiohttp.ClientSession = request.app["http_session"]

    try:
        async with session.request(
            request.method, target, headers=headers, data=body or None
        ) as upstream:
            payload = await upstream.read()
            out_headers = {
                k: v
                for k, v in upstream.headers.items()
                if k.lower() not in _STRIPPED_RESPONSE_HEADERS
            }
            # Rewrite info responses so endpoint paths include the channel prefix.
            if tail == "info":
                payload = _rewrite_info_payload(payload, channel)
            return web.Response(body=payload, status=upstream.status, headers=out_headers)
    except asyncio.TimeoutError:
        # The session-wide ClientTimeout(total=30) raises asyncio.TimeoutError,
        # which is not necessarily an aiohttp.ClientError.
        logger.warning("upstream timeout channel=%s", channel)
        return web.json_response({"error": "upstream timeout"}, status=504)
    except aiohttp.ClientError as exc:
        logger.warning("upstream error channel=%s: %s", channel, exc)
        return web.json_response({"error": "upstream unavailable"}, status=502)


async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """Subscribe a WebSocket client to a channel's broadcasts.

    Unknown channels and keys lacking access get the socket closed with code
    1008. Otherwise the socket is registered in ``app["clients"][channel]``
    until the client disconnects; inbound messages are ignored apart from
    logging errors.
    """
    channel = request.match_info["channel"]
    channels: dict[str, ChannelConfig] = request.app["channels"]
    # The middleware has already rejected keyless requests; the grant is
    # re-resolved here so the handler enforces channel access on its own.
    grant = _grant(request)
    ws = web.WebSocketResponse(heartbeat=20)
    if channel not in channels or grant is None or not can_access(grant, channel):
        await ws.prepare(request)
        await ws.close(code=1008, message=b"Unauthorized")
        return ws

    await ws.prepare(request)
    clients: set[web.WebSocketResponse] = request.app["clients"][channel]
    clients.add(ws)
    logger.info("ws connected channel=%s clients=%d", channel, len(clients))
    try:
        async for msg in ws:
            if msg.type == aiohttp.WSMsgType.ERROR:
                logger.warning("ws error: %s", ws.exception())
    finally:
        clients.discard(ws)
        logger.info("ws disconnected channel=%s clients=%d", channel, len(clients))
    return ws


async def _broadcast(app: web.Application, channel: str, envelope: dict[str, Any]) -> None:
    """Send *envelope* as one zstd-compressed compact-JSON binary frame to every channel client.

    Sends run concurrently so a slow client does not delay the others; clients
    whose send raises are dropped from the channel's client set.
    """
    clients: set[web.WebSocketResponse] = app["clients"][channel]
    targets = list(clients)  # snapshot: the set changes as clients (dis)connect
    if not targets:
        return
    raw = json.dumps(envelope, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
    payload = _compressor.compress(raw)

    results = await asyncio.gather(
        *(ws.send_bytes(payload) for ws in targets), return_exceptions=True
    )
    dead = [ws for ws, result in zip(targets, results) if isinstance(result, Exception)]
    for ws in dead:
        clients.discard(ws)
    logger.info(
        "broadcast channel=%s clients=%d raw=%d zstd=%d",
        channel,
        len(targets) - len(dead),
        len(raw),
        len(payload),
    )


def _read_offset(p: Path) -> int:
    """Return the non-negative byte offset stored in *p*, or 0 if missing/unreadable."""
    try:
        return max(int(p.read_text(encoding="utf-8").strip()), 0)
    except (OSError, ValueError):
        return 0


def _write_offset(p: Path, offset: int) -> None:
    """Persist *offset* to *p* atomically (temp file + rename).

    A crash mid-write therefore never leaves a truncated offset file that
    would make the relay replay the outbox from the start.
    """
    tmp = p.with_name(p.name + ".tmp")
    tmp.write_text(str(offset), encoding="utf-8")
    os.replace(tmp, p)


def _maybe_truncate(outbox: Path, position: int) -> int:
    """Truncate *outbox* to empty once fully consumed and at least ``_TRUNCATE_BYTES`` long.

    Returns the new read position: 0 after truncation, otherwise *position*.
    """
    try:
        size = outbox.stat().st_size
        if position >= _TRUNCATE_BYTES and position == size:
            # NOTE(review): bytes a writer appends between stat() and truncate()
            # are lost; writer-side locking or rename-based rotation would close
            # this window.
            with outbox.open("r+b") as handle:
                handle.truncate(0)
            _write_offset(outbox.with_suffix(".offset"), 0)
            return 0
    except FileNotFoundError:
        pass
    except OSError as exc:
        logger.warning("truncate failed: %s", exc)
    return position


async def relay_loop(app: web.Application, cfg: ChannelConfig) -> None:
    """Tail ``cfg.outbox_path`` forever, broadcasting each complete JSON line to the channel.

    The byte offset of the next unread line is persisted beside the outbox
    (``.offset`` suffix) so a restart resumes where it stopped. Only
    newline-terminated lines are consumed: a line still being written is
    retried until it is finished (so a never-terminated final line stalls the
    relay until more data arrives). Blank lines are skipped; lines that are not
    valid UTF-8 JSON are skipped with a warning. Unexpected errors are logged
    and retried with exponential backoff (1s doubling up to 30s, reset after
    the next successful read).
    """
    outbox = cfg.outbox_path
    offset_path = outbox.with_suffix(".offset")
    position = _read_offset(offset_path)
    delay = 1.0
    while True:
        try:
            if not outbox.exists():
                await asyncio.sleep(1.0)
                continue
            if position > outbox.stat().st_size:
                # The outbox shrank underneath us (truncated/replaced externally).
                position = 0
                _write_offset(offset_path, position)

            # Binary mode: offsets are exact byte counts comparable with st_size,
            # and undecodable bytes cannot wedge the reader.
            with outbox.open("rb") as handle:
                handle.seek(position)
                line = handle.readline()
            delay = 1.0

            if not line.endswith(b"\n"):
                # Nothing new, or the writer has not finished this line: don't advance.
                if not line:
                    position = _maybe_truncate(outbox, position)
                await asyncio.sleep(0.5)
                continue

            line_start = position
            position += len(line)
            _write_offset(offset_path, position)
            if not line.strip():
                continue
            try:
                envelope = json.loads(line)
            except ValueError:  # JSONDecodeError or UnicodeDecodeError
                logger.warning(
                    "skipping malformed outbox line channel=%s offset=%d", cfg.name, line_start
                )
                continue
            await _broadcast(app, cfg.name, envelope)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.warning("relay loop error channel=%s: %s", cfg.name, exc)
            await asyncio.sleep(delay)
            delay = min(delay * 2, 30.0)


async def _http_session_ctx(app: web.Application):
    """Cleanup context owning the shared upstream ``ClientSession`` (30s total timeout)."""
    app["http_session"] = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30))
    try:
        yield
    finally:
        await app["http_session"].close()


async def _relay_tasks_ctx(app: web.Application):
    """Cleanup context running one ``relay_loop`` task per channel; cancels them on shutdown."""
    tasks = [asyncio.create_task(relay_loop(app, cfg)) for cfg in app["channels"].values()]
    try:
        yield
    finally:
        for t in tasks:
            t.cancel()
        for t in tasks:
            try:
                await t
            except asyncio.CancelledError:
                pass


def create_app(*, key_store: KeyStore, channels: dict[str, ChannelConfig]) -> web.Application:
    """Build the gateway application with routes, auth middleware and background relays."""
    app = web.Application(middlewares=[auth_middleware])
    app["key_store"] = key_store
    app["channels"] = channels
    app["clients"] = {name: set() for name in channels}
    app.router.add_get("/", root)
    app.router.add_get("/health", health)
    app.router.add_get("/api/whoami", whoami)
    app.router.add_get("/ws/{channel}", websocket_handler)
    app.router.add_route("*", "/api/{channel}/{tail:.*}", proxy_api)
    app.cleanup_ctx.append(_http_session_ctx)
    app.cleanup_ctx.append(_relay_tasks_ctx)
    return app


def main() -> None:
    """Load configuration from the environment and serve the gateway."""
    from dotenv import load_dotenv

    # The gateway runs from BOTS/SHARED; load SHARED/.env for gateway config.
    load_dotenv(Path(__file__).resolve().parents[1] / ".env")
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] [%(levelname)s] [relay-gateway] %(message)s",
    )
    storage_root = Path(os.environ["STORAGE_VOL_PATH"])
    key_path = Path(os.getenv("RELAY_KEYS_PATH", str(storage_root / "relay_keys.json")))
    channels = build_channels(
        storage_root,
        sre_upstream=os.getenv("SREBOT_EXTERNAL_UPSTREAM_URL", "http://127.0.0.1:6000"),
        tss_upstream=os.getenv("TSS_EXTERNAL_UPSTREAM_URL") or None,
    )
    app = create_app(key_store=KeyStore(key_path), channels=channels)
    web.run_app(
        app,
        host=os.getenv("SREBOT_EXTERNAL_HOST", "0.0.0.0"),
        port=int(os.getenv("SREBOT_EXTERNAL_PORT", "18081")),
    )


if __name__ == "__main__":
    main()