Initial commit: SHARED library with LFS for binary assets
This commit is contained in:
@@ -0,0 +1,302 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import zstandard as zstd
|
||||
from aiohttp import web
|
||||
|
||||
from relay_gateway.keys import KeyStore, Grant, can_access
|
||||
from relay_gateway.channels import ChannelConfig, build_channels
|
||||
|
||||
logger = logging.getLogger("relay-gateway")
|
||||
|
||||
HOP_BY_HOP_HEADERS = {
|
||||
"connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
|
||||
"te", "trailers", "transfer-encoding", "upgrade",
|
||||
}
|
||||
_compressor = zstd.ZstdCompressor(level=3)
|
||||
_TRUNCATE_BYTES = int(os.getenv("SREBOT_EXTERNAL_TRUNCATE_BYTES", str(100 * 1024 * 1024)))
|
||||
|
||||
|
||||
def _grant(request: web.Request) -> Grant | None:
|
||||
store: KeyStore = request.app["key_store"]
|
||||
auth = request.headers.get("Authorization", "")
|
||||
token = auth[7:] if auth.startswith("Bearer ") else ""
|
||||
return store.resolve(token)
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def auth_middleware(request: web.Request, handler):
|
||||
if request.path in {"/health", "/"}:
|
||||
return await handler(request)
|
||||
grant = _grant(request)
|
||||
if grant is None:
|
||||
return web.json_response({"error": "Unauthorized"}, status=401)
|
||||
request["grant"] = grant
|
||||
return await handler(request)
|
||||
|
||||
|
||||
async def health(_: web.Request) -> web.Response:
|
||||
return web.json_response({"status": "ok", "service": "relay-gateway"})
|
||||
|
||||
|
||||
async def root(_: web.Request) -> web.Response:
|
||||
return web.json_response({"service": "relay-gateway",
|
||||
"message": "Use /api/<sre|tss>/* and /ws/<sre|tss>."})
|
||||
|
||||
|
||||
async def whoami(request: web.Request) -> web.Response:
|
||||
grant: Grant = request["grant"]
|
||||
return web.json_response(
|
||||
{"name": grant.name, "level": grant.level, "channels": list(grant.channels)}
|
||||
)
|
||||
|
||||
|
||||
async def proxy_api(request: web.Request) -> web.StreamResponse:
|
||||
channel = request.match_info["channel"]
|
||||
channels: dict[str, ChannelConfig] = request.app["channels"]
|
||||
grant: Grant = request["grant"]
|
||||
if channel not in channels:
|
||||
if channel == "sqb":
|
||||
return web.json_response(
|
||||
{
|
||||
"error": f"unknown channel '{channel}'",
|
||||
"detail": "The 'sqb' channel was renamed to 'sre'. Use /api/sre/ instead, and register new API keys with the 'sre' level (legacy 'sqb'-level keys still work as an alias).",
|
||||
},
|
||||
status=404,
|
||||
)
|
||||
return web.json_response({"error": f"unknown channel {channel}"}, status=404)
|
||||
if not can_access(grant, channel):
|
||||
return web.json_response({"error": "Forbidden for this key"}, status=403)
|
||||
cfg = channels[channel]
|
||||
if cfg.upstream_url is None:
|
||||
return web.json_response(
|
||||
{"error": f"/api/{channel}/* not implemented yet"}, status=501
|
||||
)
|
||||
tail = request.match_info["tail"]
|
||||
target = f"{cfg.upstream_url}{cfg.upstream_path}{tail}"
|
||||
if request.query_string:
|
||||
target = f"{target}?{request.query_string}"
|
||||
body = await request.read() if request.can_read_body else b""
|
||||
headers = {"Accept": "application/json"}
|
||||
upstream_token = os.getenv("SREBOT_EXTERNAL_UPSTREAM_BEARER_TOKEN",
|
||||
os.getenv("SREBOT_API_BEARER_TOKEN", "")).strip()
|
||||
if upstream_token:
|
||||
headers["Authorization"] = f"Bearer {upstream_token}"
|
||||
try:
|
||||
async with request.app["http_session"].request(
|
||||
request.method, target, headers=headers, data=body or None
|
||||
) as upstream:
|
||||
payload = await upstream.read()
|
||||
out_headers = {
|
||||
k: v for k, v in upstream.headers.items()
|
||||
if k.lower() not in HOP_BY_HOP_HEADERS
|
||||
and k.lower() not in {"content-length", "content-encoding"}
|
||||
}
|
||||
# Rewrite info responses so endpoint paths include the channel prefix.
|
||||
if tail == "info":
|
||||
try:
|
||||
body_obj = json.loads(payload)
|
||||
prefix = f"/api/{channel}"
|
||||
if "endpoints" in body_obj:
|
||||
rewritten = {}
|
||||
for method_path, desc in body_obj["endpoints"].items():
|
||||
parts = method_path.split(" ", 1)
|
||||
if len(parts) == 2 and parts[1].startswith("/api/") and not parts[1].startswith(f"/api/{channel}/"):
|
||||
rewritten[f"{parts[0]} /api/{channel}{parts[1][4:]}"] = desc
|
||||
else:
|
||||
rewritten[method_path] = desc
|
||||
body_obj["endpoints"] = rewritten
|
||||
if "availableEndpoints" in body_obj:
|
||||
rewritten = []
|
||||
for ep in body_obj["availableEndpoints"]:
|
||||
parts = ep.split(" ", 1)
|
||||
if len(parts) == 2 and parts[1].startswith("/api/") and not parts[1].startswith(f"/api/{channel}/"):
|
||||
rewritten.append(f"{parts[0]} /api/{channel}{parts[1][4:]}")
|
||||
elif not parts[1:]:
|
||||
rewritten.append(ep)
|
||||
else:
|
||||
rewritten.append(ep)
|
||||
body_obj["availableEndpoints"] = rewritten
|
||||
payload = json.dumps(body_obj, indent=2).encode()
|
||||
except Exception:
|
||||
pass
|
||||
return web.Response(body=payload, status=upstream.status, headers=out_headers)
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.warning("upstream error channel=%s: %s", channel, exc)
|
||||
return web.json_response({"error": "upstream unavailable"}, status=502)
|
||||
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
channel = request.match_info["channel"]
|
||||
channels: dict[str, ChannelConfig] = request.app["channels"]
|
||||
grant = _grant(request)
|
||||
ws = web.WebSocketResponse(heartbeat=20)
|
||||
if channel not in channels or grant is None or not can_access(grant, channel):
|
||||
await ws.prepare(request)
|
||||
await ws.close(code=1008, message=b"Unauthorized")
|
||||
return ws
|
||||
await ws.prepare(request)
|
||||
clients: set[web.WebSocketResponse] = request.app["clients"][channel]
|
||||
clients.add(ws)
|
||||
logger.info("ws connected channel=%s clients=%d", channel, len(clients))
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
logger.warning("ws error: %s", ws.exception())
|
||||
finally:
|
||||
clients.discard(ws)
|
||||
logger.info("ws disconnected channel=%s clients=%d", channel, len(clients))
|
||||
return ws
|
||||
|
||||
|
||||
async def _broadcast(app: web.Application, channel: str, envelope: dict[str, Any]) -> None:
|
||||
raw = json.dumps(envelope, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
|
||||
payload = _compressor.compress(raw)
|
||||
targets = list(app["clients"][channel])
|
||||
if not targets:
|
||||
return
|
||||
dead = []
|
||||
for ws in targets:
|
||||
try:
|
||||
await ws.send_bytes(payload)
|
||||
except Exception:
|
||||
dead.append(ws)
|
||||
for ws in dead:
|
||||
app["clients"][channel].discard(ws)
|
||||
logger.info("broadcast channel=%s clients=%d raw=%d zstd=%d",
|
||||
channel, len(targets) - len(dead), len(raw), len(payload))
|
||||
|
||||
|
||||
def _read_offset(p: Path) -> int:
|
||||
try:
|
||||
return int(p.read_text(encoding="utf-8").strip())
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _write_offset(p: Path, offset: int) -> None:
|
||||
p.write_text(str(offset), encoding="utf-8")
|
||||
|
||||
|
||||
def _maybe_truncate(outbox: Path, position: int) -> int:
|
||||
try:
|
||||
size = outbox.stat().st_size
|
||||
if position >= _TRUNCATE_BYTES and position == size:
|
||||
with outbox.open("r+b") as handle:
|
||||
handle.truncate(0)
|
||||
_write_offset(outbox.with_suffix(".offset"), 0)
|
||||
return 0
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("truncate failed: %s", exc)
|
||||
return position
|
||||
|
||||
|
||||
async def relay_loop(app: web.Application, cfg: ChannelConfig) -> None:
|
||||
outbox = cfg.outbox_path
|
||||
offset_path = outbox.with_suffix(".offset")
|
||||
position = _read_offset(offset_path)
|
||||
delay = 1.0
|
||||
while True:
|
||||
try:
|
||||
if not outbox.exists():
|
||||
await asyncio.sleep(1.0)
|
||||
continue
|
||||
size = outbox.stat().st_size
|
||||
if position > size:
|
||||
position = 0
|
||||
_write_offset(offset_path, position)
|
||||
with outbox.open("r", encoding="utf-8") as handle:
|
||||
handle.seek(position)
|
||||
line = handle.readline()
|
||||
if not line:
|
||||
position = _maybe_truncate(outbox, position)
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
try:
|
||||
envelope = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
position = handle.tell()
|
||||
_write_offset(offset_path, position)
|
||||
continue
|
||||
position = handle.tell()
|
||||
_write_offset(offset_path, position)
|
||||
await _broadcast(app, cfg.name, envelope)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning("relay loop error channel=%s: %s", cfg.name, exc)
|
||||
await asyncio.sleep(delay)
|
||||
delay = min(delay * 2, 30.0)
|
||||
|
||||
|
||||
async def _http_session_ctx(app: web.Application):
|
||||
app["http_session"] = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await app["http_session"].close()
|
||||
|
||||
|
||||
async def _relay_tasks_ctx(app: web.Application):
|
||||
tasks = [asyncio.create_task(relay_loop(app, cfg))
|
||||
for cfg in app["channels"].values()]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
for t in tasks:
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
def create_app(*, key_store: KeyStore, channels: dict[str, ChannelConfig]) -> web.Application:
|
||||
app = web.Application(middlewares=[auth_middleware])
|
||||
app["key_store"] = key_store
|
||||
app["channels"] = channels
|
||||
app["clients"] = {name: set() for name in channels}
|
||||
app.router.add_get("/", root)
|
||||
app.router.add_get("/health", health)
|
||||
app.router.add_get("/api/whoami", whoami)
|
||||
app.router.add_get("/ws/{channel}", websocket_handler)
|
||||
app.router.add_route("*", "/api/{channel}/{tail:.*}", proxy_api)
|
||||
app.cleanup_ctx.append(_http_session_ctx)
|
||||
app.cleanup_ctx.append(_relay_tasks_ctx)
|
||||
return app
|
||||
|
||||
|
||||
def main() -> None:
|
||||
from dotenv import load_dotenv
|
||||
# The gateway runs from BOTS/SHARED; load SHARED/.env for gateway config.
|
||||
load_dotenv(Path(__file__).resolve().parents[1] / ".env")
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="[%(asctime)s] [%(levelname)s] [relay-gateway] %(message)s")
|
||||
storage_root = Path(os.environ["STORAGE_VOL_PATH"])
|
||||
key_path = Path(os.getenv("RELAY_KEYS_PATH", str(storage_root / "relay_keys.json")))
|
||||
channels = build_channels(
|
||||
storage_root,
|
||||
sre_upstream=os.getenv("SREBOT_EXTERNAL_UPSTREAM_URL", "http://127.0.0.1:6000"),
|
||||
tss_upstream=os.getenv("TSS_EXTERNAL_UPSTREAM_URL") or None,
|
||||
)
|
||||
app = create_app(key_store=KeyStore(key_path), channels=channels)
|
||||
web.run_app(
|
||||
app,
|
||||
host=os.getenv("SREBOT_EXTERNAL_HOST", "0.0.0.0"),
|
||||
port=int(os.getenv("SREBOT_EXTERNAL_PORT", "18081")),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user