Initial commit: SHARED library with LFS for binary assets

This commit is contained in:
clxud
2026-07-02 02:00:46 +00:00
commit db5de3ac7d
9356 changed files with 47608 additions and 0 deletions
+302
View File
@@ -0,0 +1,302 @@
from __future__ import annotations
import asyncio
import json
import logging
import os
import time
from pathlib import Path
from typing import Any
import aiohttp
import zstandard as zstd
from aiohttp import web
from relay_gateway.keys import KeyStore, Grant, can_access
from relay_gateway.channels import ChannelConfig, build_channels
logger = logging.getLogger("relay-gateway")
HOP_BY_HOP_HEADERS = {
"connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
"te", "trailers", "transfer-encoding", "upgrade",
}
_compressor = zstd.ZstdCompressor(level=3)
_TRUNCATE_BYTES = int(os.getenv("SREBOT_EXTERNAL_TRUNCATE_BYTES", str(100 * 1024 * 1024)))
def _grant(request: web.Request) -> Grant | None:
store: KeyStore = request.app["key_store"]
auth = request.headers.get("Authorization", "")
token = auth[7:] if auth.startswith("Bearer ") else ""
return store.resolve(token)
@web.middleware
async def auth_middleware(request: web.Request, handler):
if request.path in {"/health", "/"}:
return await handler(request)
grant = _grant(request)
if grant is None:
return web.json_response({"error": "Unauthorized"}, status=401)
request["grant"] = grant
return await handler(request)
async def health(_: web.Request) -> web.Response:
return web.json_response({"status": "ok", "service": "relay-gateway"})
async def root(_: web.Request) -> web.Response:
return web.json_response({"service": "relay-gateway",
"message": "Use /api/<sre|tss>/* and /ws/<sre|tss>."})
async def whoami(request: web.Request) -> web.Response:
grant: Grant = request["grant"]
return web.json_response(
{"name": grant.name, "level": grant.level, "channels": list(grant.channels)}
)
async def proxy_api(request: web.Request) -> web.StreamResponse:
channel = request.match_info["channel"]
channels: dict[str, ChannelConfig] = request.app["channels"]
grant: Grant = request["grant"]
if channel not in channels:
if channel == "sqb":
return web.json_response(
{
"error": f"unknown channel '{channel}'",
"detail": "The 'sqb' channel was renamed to 'sre'. Use /api/sre/ instead, and register new API keys with the 'sre' level (legacy 'sqb'-level keys still work as an alias).",
},
status=404,
)
return web.json_response({"error": f"unknown channel {channel}"}, status=404)
if not can_access(grant, channel):
return web.json_response({"error": "Forbidden for this key"}, status=403)
cfg = channels[channel]
if cfg.upstream_url is None:
return web.json_response(
{"error": f"/api/{channel}/* not implemented yet"}, status=501
)
tail = request.match_info["tail"]
target = f"{cfg.upstream_url}{cfg.upstream_path}{tail}"
if request.query_string:
target = f"{target}?{request.query_string}"
body = await request.read() if request.can_read_body else b""
headers = {"Accept": "application/json"}
upstream_token = os.getenv("SREBOT_EXTERNAL_UPSTREAM_BEARER_TOKEN",
os.getenv("SREBOT_API_BEARER_TOKEN", "")).strip()
if upstream_token:
headers["Authorization"] = f"Bearer {upstream_token}"
try:
async with request.app["http_session"].request(
request.method, target, headers=headers, data=body or None
) as upstream:
payload = await upstream.read()
out_headers = {
k: v for k, v in upstream.headers.items()
if k.lower() not in HOP_BY_HOP_HEADERS
and k.lower() not in {"content-length", "content-encoding"}
}
# Rewrite info responses so endpoint paths include the channel prefix.
if tail == "info":
try:
body_obj = json.loads(payload)
prefix = f"/api/{channel}"
if "endpoints" in body_obj:
rewritten = {}
for method_path, desc in body_obj["endpoints"].items():
parts = method_path.split(" ", 1)
if len(parts) == 2 and parts[1].startswith("/api/") and not parts[1].startswith(f"/api/{channel}/"):
rewritten[f"{parts[0]} /api/{channel}{parts[1][4:]}"] = desc
else:
rewritten[method_path] = desc
body_obj["endpoints"] = rewritten
if "availableEndpoints" in body_obj:
rewritten = []
for ep in body_obj["availableEndpoints"]:
parts = ep.split(" ", 1)
if len(parts) == 2 and parts[1].startswith("/api/") and not parts[1].startswith(f"/api/{channel}/"):
rewritten.append(f"{parts[0]} /api/{channel}{parts[1][4:]}")
elif not parts[1:]:
rewritten.append(ep)
else:
rewritten.append(ep)
body_obj["availableEndpoints"] = rewritten
payload = json.dumps(body_obj, indent=2).encode()
except Exception:
pass
return web.Response(body=payload, status=upstream.status, headers=out_headers)
except aiohttp.ClientError as exc:
logger.warning("upstream error channel=%s: %s", channel, exc)
return web.json_response({"error": "upstream unavailable"}, status=502)
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
channel = request.match_info["channel"]
channels: dict[str, ChannelConfig] = request.app["channels"]
grant = _grant(request)
ws = web.WebSocketResponse(heartbeat=20)
if channel not in channels or grant is None or not can_access(grant, channel):
await ws.prepare(request)
await ws.close(code=1008, message=b"Unauthorized")
return ws
await ws.prepare(request)
clients: set[web.WebSocketResponse] = request.app["clients"][channel]
clients.add(ws)
logger.info("ws connected channel=%s clients=%d", channel, len(clients))
try:
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
logger.warning("ws error: %s", ws.exception())
finally:
clients.discard(ws)
logger.info("ws disconnected channel=%s clients=%d", channel, len(clients))
return ws
async def _broadcast(app: web.Application, channel: str, envelope: dict[str, Any]) -> None:
raw = json.dumps(envelope, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
payload = _compressor.compress(raw)
targets = list(app["clients"][channel])
if not targets:
return
dead = []
for ws in targets:
try:
await ws.send_bytes(payload)
except Exception:
dead.append(ws)
for ws in dead:
app["clients"][channel].discard(ws)
logger.info("broadcast channel=%s clients=%d raw=%d zstd=%d",
channel, len(targets) - len(dead), len(raw), len(payload))
def _read_offset(p: Path) -> int:
try:
return int(p.read_text(encoding="utf-8").strip())
except Exception:
return 0
def _write_offset(p: Path, offset: int) -> None:
p.write_text(str(offset), encoding="utf-8")
def _maybe_truncate(outbox: Path, position: int) -> int:
try:
size = outbox.stat().st_size
if position >= _TRUNCATE_BYTES and position == size:
with outbox.open("r+b") as handle:
handle.truncate(0)
_write_offset(outbox.with_suffix(".offset"), 0)
return 0
except FileNotFoundError:
pass
except Exception as exc:
logger.warning("truncate failed: %s", exc)
return position
async def relay_loop(app: web.Application, cfg: ChannelConfig) -> None:
outbox = cfg.outbox_path
offset_path = outbox.with_suffix(".offset")
position = _read_offset(offset_path)
delay = 1.0
while True:
try:
if not outbox.exists():
await asyncio.sleep(1.0)
continue
size = outbox.stat().st_size
if position > size:
position = 0
_write_offset(offset_path, position)
with outbox.open("r", encoding="utf-8") as handle:
handle.seek(position)
line = handle.readline()
if not line:
position = _maybe_truncate(outbox, position)
await asyncio.sleep(0.5)
continue
try:
envelope = json.loads(line)
except json.JSONDecodeError:
position = handle.tell()
_write_offset(offset_path, position)
continue
position = handle.tell()
_write_offset(offset_path, position)
await _broadcast(app, cfg.name, envelope)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.warning("relay loop error channel=%s: %s", cfg.name, exc)
await asyncio.sleep(delay)
delay = min(delay * 2, 30.0)
async def _http_session_ctx(app: web.Application):
app["http_session"] = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30))
try:
yield
finally:
await app["http_session"].close()
async def _relay_tasks_ctx(app: web.Application):
tasks = [asyncio.create_task(relay_loop(app, cfg))
for cfg in app["channels"].values()]
try:
yield
finally:
for t in tasks:
t.cancel()
for t in tasks:
try:
await t
except asyncio.CancelledError:
pass
def create_app(*, key_store: KeyStore, channels: dict[str, ChannelConfig]) -> web.Application:
app = web.Application(middlewares=[auth_middleware])
app["key_store"] = key_store
app["channels"] = channels
app["clients"] = {name: set() for name in channels}
app.router.add_get("/", root)
app.router.add_get("/health", health)
app.router.add_get("/api/whoami", whoami)
app.router.add_get("/ws/{channel}", websocket_handler)
app.router.add_route("*", "/api/{channel}/{tail:.*}", proxy_api)
app.cleanup_ctx.append(_http_session_ctx)
app.cleanup_ctx.append(_relay_tasks_ctx)
return app
def main() -> None:
from dotenv import load_dotenv
# The gateway runs from BOTS/SHARED; load SHARED/.env for gateway config.
load_dotenv(Path(__file__).resolve().parents[1] / ".env")
logging.basicConfig(level=logging.INFO,
format="[%(asctime)s] [%(levelname)s] [relay-gateway] %(message)s")
storage_root = Path(os.environ["STORAGE_VOL_PATH"])
key_path = Path(os.getenv("RELAY_KEYS_PATH", str(storage_root / "relay_keys.json")))
channels = build_channels(
storage_root,
sre_upstream=os.getenv("SREBOT_EXTERNAL_UPSTREAM_URL", "http://127.0.0.1:6000"),
tss_upstream=os.getenv("TSS_EXTERNAL_UPSTREAM_URL") or None,
)
app = create_app(key_store=KeyStore(key_path), channels=channels)
web.run_app(
app,
host=os.getenv("SREBOT_EXTERNAL_HOST", "0.0.0.0"),
port=int(os.getenv("SREBOT_EXTERNAL_PORT", "18081")),
)
if __name__ == "__main__":
main()