Initial commit: SHARED library with LFS for binary assets
This commit is contained in:
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ChannelConfig:
|
||||
name: str
|
||||
outbox_path: Path
|
||||
upstream_url: str | None # None => /api/<channel>/* returns 501
|
||||
upstream_path: str = "/api/" # path prefix when proxying (channel injected via {channel})
|
||||
|
||||
|
||||
def build_channels(
|
||||
storage_root: Path,
|
||||
sre_upstream: str,
|
||||
tss_upstream: str | None,
|
||||
) -> dict[str, ChannelConfig]:
|
||||
return {
|
||||
"sre": ChannelConfig(
|
||||
name="sre",
|
||||
outbox_path=storage_root / "external_bridge_outbox.jsonl",
|
||||
upstream_url=sre_upstream.rstrip("/"),
|
||||
upstream_path="/api/",
|
||||
),
|
||||
"tss": ChannelConfig(
|
||||
name="tss",
|
||||
outbox_path=storage_root / "tss_bridge_outbox.jsonl",
|
||||
upstream_url=(tss_upstream.rstrip("/") if tss_upstream else None),
|
||||
upstream_path="/api/tss/",
|
||||
),
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import zstandard as zstd
|
||||
from aiohttp import web
|
||||
|
||||
from relay_gateway.keys import KeyStore, Grant, can_access
|
||||
from relay_gateway.channels import ChannelConfig, build_channels
|
||||
|
||||
logger = logging.getLogger("relay-gateway")
|
||||
|
||||
HOP_BY_HOP_HEADERS = {
|
||||
"connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
|
||||
"te", "trailers", "transfer-encoding", "upgrade",
|
||||
}
|
||||
_compressor = zstd.ZstdCompressor(level=3)
|
||||
_TRUNCATE_BYTES = int(os.getenv("SREBOT_EXTERNAL_TRUNCATE_BYTES", str(100 * 1024 * 1024)))
|
||||
|
||||
|
||||
def _grant(request: web.Request) -> Grant | None:
|
||||
store: KeyStore = request.app["key_store"]
|
||||
auth = request.headers.get("Authorization", "")
|
||||
token = auth[7:] if auth.startswith("Bearer ") else ""
|
||||
return store.resolve(token)
|
||||
|
||||
|
||||
@web.middleware
|
||||
async def auth_middleware(request: web.Request, handler):
|
||||
if request.path in {"/health", "/"}:
|
||||
return await handler(request)
|
||||
grant = _grant(request)
|
||||
if grant is None:
|
||||
return web.json_response({"error": "Unauthorized"}, status=401)
|
||||
request["grant"] = grant
|
||||
return await handler(request)
|
||||
|
||||
|
||||
async def health(_: web.Request) -> web.Response:
|
||||
return web.json_response({"status": "ok", "service": "relay-gateway"})
|
||||
|
||||
|
||||
async def root(_: web.Request) -> web.Response:
|
||||
return web.json_response({"service": "relay-gateway",
|
||||
"message": "Use /api/<sre|tss>/* and /ws/<sre|tss>."})
|
||||
|
||||
|
||||
async def whoami(request: web.Request) -> web.Response:
|
||||
grant: Grant = request["grant"]
|
||||
return web.json_response(
|
||||
{"name": grant.name, "level": grant.level, "channels": list(grant.channels)}
|
||||
)
|
||||
|
||||
|
||||
async def proxy_api(request: web.Request) -> web.StreamResponse:
|
||||
channel = request.match_info["channel"]
|
||||
channels: dict[str, ChannelConfig] = request.app["channels"]
|
||||
grant: Grant = request["grant"]
|
||||
if channel not in channels:
|
||||
if channel == "sqb":
|
||||
return web.json_response(
|
||||
{
|
||||
"error": f"unknown channel '{channel}'",
|
||||
"detail": "The 'sqb' channel was renamed to 'sre'. Use /api/sre/ instead, and register new API keys with the 'sre' level (legacy 'sqb'-level keys still work as an alias).",
|
||||
},
|
||||
status=404,
|
||||
)
|
||||
return web.json_response({"error": f"unknown channel {channel}"}, status=404)
|
||||
if not can_access(grant, channel):
|
||||
return web.json_response({"error": "Forbidden for this key"}, status=403)
|
||||
cfg = channels[channel]
|
||||
if cfg.upstream_url is None:
|
||||
return web.json_response(
|
||||
{"error": f"/api/{channel}/* not implemented yet"}, status=501
|
||||
)
|
||||
tail = request.match_info["tail"]
|
||||
target = f"{cfg.upstream_url}{cfg.upstream_path}{tail}"
|
||||
if request.query_string:
|
||||
target = f"{target}?{request.query_string}"
|
||||
body = await request.read() if request.can_read_body else b""
|
||||
headers = {"Accept": "application/json"}
|
||||
upstream_token = os.getenv("SREBOT_EXTERNAL_UPSTREAM_BEARER_TOKEN",
|
||||
os.getenv("SREBOT_API_BEARER_TOKEN", "")).strip()
|
||||
if upstream_token:
|
||||
headers["Authorization"] = f"Bearer {upstream_token}"
|
||||
try:
|
||||
async with request.app["http_session"].request(
|
||||
request.method, target, headers=headers, data=body or None
|
||||
) as upstream:
|
||||
payload = await upstream.read()
|
||||
out_headers = {
|
||||
k: v for k, v in upstream.headers.items()
|
||||
if k.lower() not in HOP_BY_HOP_HEADERS
|
||||
and k.lower() not in {"content-length", "content-encoding"}
|
||||
}
|
||||
# Rewrite info responses so endpoint paths include the channel prefix.
|
||||
if tail == "info":
|
||||
try:
|
||||
body_obj = json.loads(payload)
|
||||
prefix = f"/api/{channel}"
|
||||
if "endpoints" in body_obj:
|
||||
rewritten = {}
|
||||
for method_path, desc in body_obj["endpoints"].items():
|
||||
parts = method_path.split(" ", 1)
|
||||
if len(parts) == 2 and parts[1].startswith("/api/") and not parts[1].startswith(f"/api/{channel}/"):
|
||||
rewritten[f"{parts[0]} /api/{channel}{parts[1][4:]}"] = desc
|
||||
else:
|
||||
rewritten[method_path] = desc
|
||||
body_obj["endpoints"] = rewritten
|
||||
if "availableEndpoints" in body_obj:
|
||||
rewritten = []
|
||||
for ep in body_obj["availableEndpoints"]:
|
||||
parts = ep.split(" ", 1)
|
||||
if len(parts) == 2 and parts[1].startswith("/api/") and not parts[1].startswith(f"/api/{channel}/"):
|
||||
rewritten.append(f"{parts[0]} /api/{channel}{parts[1][4:]}")
|
||||
elif not parts[1:]:
|
||||
rewritten.append(ep)
|
||||
else:
|
||||
rewritten.append(ep)
|
||||
body_obj["availableEndpoints"] = rewritten
|
||||
payload = json.dumps(body_obj, indent=2).encode()
|
||||
except Exception:
|
||||
pass
|
||||
return web.Response(body=payload, status=upstream.status, headers=out_headers)
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.warning("upstream error channel=%s: %s", channel, exc)
|
||||
return web.json_response({"error": "upstream unavailable"}, status=502)
|
||||
|
||||
|
||||
async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
||||
channel = request.match_info["channel"]
|
||||
channels: dict[str, ChannelConfig] = request.app["channels"]
|
||||
grant = _grant(request)
|
||||
ws = web.WebSocketResponse(heartbeat=20)
|
||||
if channel not in channels or grant is None or not can_access(grant, channel):
|
||||
await ws.prepare(request)
|
||||
await ws.close(code=1008, message=b"Unauthorized")
|
||||
return ws
|
||||
await ws.prepare(request)
|
||||
clients: set[web.WebSocketResponse] = request.app["clients"][channel]
|
||||
clients.add(ws)
|
||||
logger.info("ws connected channel=%s clients=%d", channel, len(clients))
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.ERROR:
|
||||
logger.warning("ws error: %s", ws.exception())
|
||||
finally:
|
||||
clients.discard(ws)
|
||||
logger.info("ws disconnected channel=%s clients=%d", channel, len(clients))
|
||||
return ws
|
||||
|
||||
|
||||
async def _broadcast(app: web.Application, channel: str, envelope: dict[str, Any]) -> None:
|
||||
raw = json.dumps(envelope, ensure_ascii=False, separators=(",", ":")).encode("utf-8")
|
||||
payload = _compressor.compress(raw)
|
||||
targets = list(app["clients"][channel])
|
||||
if not targets:
|
||||
return
|
||||
dead = []
|
||||
for ws in targets:
|
||||
try:
|
||||
await ws.send_bytes(payload)
|
||||
except Exception:
|
||||
dead.append(ws)
|
||||
for ws in dead:
|
||||
app["clients"][channel].discard(ws)
|
||||
logger.info("broadcast channel=%s clients=%d raw=%d zstd=%d",
|
||||
channel, len(targets) - len(dead), len(raw), len(payload))
|
||||
|
||||
|
||||
def _read_offset(p: Path) -> int:
|
||||
try:
|
||||
return int(p.read_text(encoding="utf-8").strip())
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _write_offset(p: Path, offset: int) -> None:
|
||||
p.write_text(str(offset), encoding="utf-8")
|
||||
|
||||
|
||||
def _maybe_truncate(outbox: Path, position: int) -> int:
|
||||
try:
|
||||
size = outbox.stat().st_size
|
||||
if position >= _TRUNCATE_BYTES and position == size:
|
||||
with outbox.open("r+b") as handle:
|
||||
handle.truncate(0)
|
||||
_write_offset(outbox.with_suffix(".offset"), 0)
|
||||
return 0
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("truncate failed: %s", exc)
|
||||
return position
|
||||
|
||||
|
||||
async def relay_loop(app: web.Application, cfg: ChannelConfig) -> None:
|
||||
outbox = cfg.outbox_path
|
||||
offset_path = outbox.with_suffix(".offset")
|
||||
position = _read_offset(offset_path)
|
||||
delay = 1.0
|
||||
while True:
|
||||
try:
|
||||
if not outbox.exists():
|
||||
await asyncio.sleep(1.0)
|
||||
continue
|
||||
size = outbox.stat().st_size
|
||||
if position > size:
|
||||
position = 0
|
||||
_write_offset(offset_path, position)
|
||||
with outbox.open("r", encoding="utf-8") as handle:
|
||||
handle.seek(position)
|
||||
line = handle.readline()
|
||||
if not line:
|
||||
position = _maybe_truncate(outbox, position)
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
try:
|
||||
envelope = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
position = handle.tell()
|
||||
_write_offset(offset_path, position)
|
||||
continue
|
||||
position = handle.tell()
|
||||
_write_offset(offset_path, position)
|
||||
await _broadcast(app, cfg.name, envelope)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning("relay loop error channel=%s: %s", cfg.name, exc)
|
||||
await asyncio.sleep(delay)
|
||||
delay = min(delay * 2, 30.0)
|
||||
|
||||
|
||||
async def _http_session_ctx(app: web.Application):
|
||||
app["http_session"] = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await app["http_session"].close()
|
||||
|
||||
|
||||
async def _relay_tasks_ctx(app: web.Application):
|
||||
tasks = [asyncio.create_task(relay_loop(app, cfg))
|
||||
for cfg in app["channels"].values()]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
for t in tasks:
|
||||
try:
|
||||
await t
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
def create_app(*, key_store: KeyStore, channels: dict[str, ChannelConfig]) -> web.Application:
|
||||
app = web.Application(middlewares=[auth_middleware])
|
||||
app["key_store"] = key_store
|
||||
app["channels"] = channels
|
||||
app["clients"] = {name: set() for name in channels}
|
||||
app.router.add_get("/", root)
|
||||
app.router.add_get("/health", health)
|
||||
app.router.add_get("/api/whoami", whoami)
|
||||
app.router.add_get("/ws/{channel}", websocket_handler)
|
||||
app.router.add_route("*", "/api/{channel}/{tail:.*}", proxy_api)
|
||||
app.cleanup_ctx.append(_http_session_ctx)
|
||||
app.cleanup_ctx.append(_relay_tasks_ctx)
|
||||
return app
|
||||
|
||||
|
||||
def main() -> None:
|
||||
from dotenv import load_dotenv
|
||||
# The gateway runs from BOTS/SHARED; load SHARED/.env for gateway config.
|
||||
load_dotenv(Path(__file__).resolve().parents[1] / ".env")
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="[%(asctime)s] [%(levelname)s] [relay-gateway] %(message)s")
|
||||
storage_root = Path(os.environ["STORAGE_VOL_PATH"])
|
||||
key_path = Path(os.getenv("RELAY_KEYS_PATH", str(storage_root / "relay_keys.json")))
|
||||
channels = build_channels(
|
||||
storage_root,
|
||||
sre_upstream=os.getenv("SREBOT_EXTERNAL_UPSTREAM_URL", "http://127.0.0.1:6000"),
|
||||
tss_upstream=os.getenv("TSS_EXTERNAL_UPSTREAM_URL") or None,
|
||||
)
|
||||
app = create_app(key_store=KeyStore(key_path), channels=channels)
|
||||
web.run_app(
|
||||
app,
|
||||
host=os.getenv("SREBOT_EXTERNAL_HOST", "0.0.0.0"),
|
||||
port=int(os.getenv("SREBOT_EXTERNAL_PORT", "18081")),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger("relay-gateway.keys")
|
||||
|
||||
LEVELS = ("all", "sre", "sqb", "tss")
|
||||
CHANNELS_FOR_LEVEL: dict[str, tuple[str, ...]] = {
|
||||
"all": ("sre", "tss"),
|
||||
"sre": ("sre",),
|
||||
"sqb": ("sre",), # legacy alias
|
||||
"tss": ("tss",),
|
||||
}
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
return hashlib.sha256(token.strip().encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class Grant:
|
||||
name: str
|
||||
level: str
|
||||
channels: tuple[str, ...]
|
||||
|
||||
|
||||
def can_access(grant: Grant, channel: str) -> bool:
|
||||
return channel in grant.channels
|
||||
|
||||
|
||||
class KeyStore:
|
||||
"""Loads relay_keys.json (hashed tokens) and reloads on mtime change."""
|
||||
|
||||
def __init__(self, path: Path):
|
||||
self._path = Path(path)
|
||||
self._mtime: float | None = None
|
||||
self._by_hash: dict[str, Grant] = {}
|
||||
self._load()
|
||||
|
||||
def _load(self) -> None:
|
||||
try:
|
||||
stat = self._path.stat()
|
||||
except FileNotFoundError:
|
||||
self._by_hash = {}
|
||||
self._mtime = None
|
||||
return
|
||||
if self._mtime is not None and stat.st_mtime == self._mtime:
|
||||
return
|
||||
try:
|
||||
raw = json.loads(self._path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError) as exc:
|
||||
logger.warning("Failed to read key store: %s", exc)
|
||||
return
|
||||
parsed: dict[str, Grant] = {}
|
||||
for token_hash, meta in (raw or {}).items():
|
||||
level = str(meta.get("level", "")).strip()
|
||||
if level not in CHANNELS_FOR_LEVEL:
|
||||
logger.warning("Skipping key %s: bad level %r", meta.get("name"), level)
|
||||
continue
|
||||
parsed[token_hash] = Grant(
|
||||
name=str(meta.get("name", "unnamed")),
|
||||
level=level,
|
||||
channels=CHANNELS_FOR_LEVEL[level],
|
||||
)
|
||||
self._by_hash = parsed
|
||||
self._mtime = stat.st_mtime
|
||||
|
||||
def resolve(self, token: str) -> Grant | None:
|
||||
self._load()
|
||||
if not token:
|
||||
return None
|
||||
return self._by_hash.get(hash_token(token))
|
||||
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
from relay_gateway.keys import hash_token, CHANNELS_FOR_LEVEL
|
||||
|
||||
|
||||
def _read(path: Path) -> dict:
|
||||
try:
|
||||
return json.loads(Path(path).read_text(encoding="utf-8"))
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
|
||||
def _write(path: Path, body: dict) -> None:
|
||||
Path(path).write_text(json.dumps(body, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
def add_key(path: Path, *, name: str, level: str) -> str:
|
||||
if level not in CHANNELS_FOR_LEVEL:
|
||||
raise ValueError(f"level must be one of {sorted(CHANNELS_FOR_LEVEL)}")
|
||||
token = secrets.token_urlsafe(32)
|
||||
body = _read(path)
|
||||
body[hash_token(token)] = {"name": name, "level": level}
|
||||
_write(path, body)
|
||||
return token
|
||||
|
||||
|
||||
def list_keys(path: Path) -> list[dict]:
|
||||
return [{"hash": h, **meta} for h, meta in _read(path).items()]
|
||||
|
||||
|
||||
def revoke(path: Path, name: str) -> int:
|
||||
body = _read(path)
|
||||
to_remove = [h for h, meta in body.items() if meta.get("name") == name]
|
||||
for h in to_remove:
|
||||
del body[h]
|
||||
_write(path, body)
|
||||
return len(to_remove)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(description="Manage relay gateway keys")
|
||||
ap.add_argument("--file", required=True, type=Path)
|
||||
sub = ap.add_subparsers(dest="cmd", required=True)
|
||||
a = sub.add_parser("add"); a.add_argument("--name", required=True); a.add_argument("--level", required=True)
|
||||
sub.add_parser("list")
|
||||
r = sub.add_parser("revoke"); r.add_argument("--name", required=True)
|
||||
args = ap.parse_args()
|
||||
if args.cmd == "add":
|
||||
token = add_key(args.file, name=args.name, level=args.level)
|
||||
print(f"Token for {args.name!r} (level={args.level}) — store it now, shown once:\n{token}")
|
||||
elif args.cmd == "list":
|
||||
for e in list_keys(args.file):
|
||||
print(f"{e['name']:20s} {e['level']:4s} {e['hash'][:12]}…")
|
||||
elif args.cmd == "revoke":
|
||||
print(f"Removed {revoke(args.file, args.name)} key(s) named {args.name!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,79 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from relay_gateway.keys import KeyStore, hash_token
|
||||
from relay_gateway.channels import build_channels
|
||||
from relay_gateway.gateway import create_app
|
||||
|
||||
|
||||
def _keyfile(tmp_path: Path, tokens: dict[str, dict]) -> Path:
|
||||
kf = tmp_path / "relay_keys.json"
|
||||
kf.write_text(
|
||||
json.dumps({hash_token(t): m for t, m in tokens.items()}), encoding="utf-8"
|
||||
)
|
||||
return kf
|
||||
|
||||
|
||||
async def _client(tmp_path, tokens, *, tss_upstream=None) -> TestClient:
|
||||
kf = _keyfile(tmp_path, tokens)
|
||||
channels = build_channels(tmp_path, sre_upstream="http://127.0.0.1:1", tss_upstream=tss_upstream)
|
||||
app = create_app(key_store=KeyStore(kf), channels=channels)
|
||||
client = TestClient(TestServer(app))
|
||||
await client.start_server()
|
||||
return client
|
||||
|
||||
|
||||
async def test_health_is_open(tmp_path):
|
||||
client = await _client(tmp_path, {"k": {"name": "n", "level": "all"}})
|
||||
resp = await client.get("/health")
|
||||
assert resp.status == 200
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_missing_token_is_401(tmp_path):
|
||||
client = await _client(tmp_path, {"k": {"name": "n", "level": "sre"}})
|
||||
resp = await client.get("/api/sre/info")
|
||||
assert resp.status == 401
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_sqb_key_denied_tss_is_403(tmp_path):
|
||||
client = await _client(tmp_path, {"k": {"name": "n", "level": "sre"}})
|
||||
resp = await client.get("/api/tss/info", headers={"Authorization": "Bearer k"})
|
||||
assert resp.status == 403
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_tss_proxy_501_when_no_upstream(tmp_path):
|
||||
client = await _client(tmp_path, {"k": {"name": "n", "level": "tss"}})
|
||||
resp = await client.get("/api/tss/info", headers={"Authorization": "Bearer k"})
|
||||
assert resp.status == 501
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_tss_proxy_not_501_with_upstream(tmp_path):
|
||||
client = await _client(tmp_path, {"k": {"name": "n", "level": "tss"}},
|
||||
tss_upstream="http://127.0.0.1:6100")
|
||||
resp = await client.get("/api/tss/info", headers={"Authorization": "Bearer k"})
|
||||
assert resp.status != 501 # proxied (likely 502 no-connection in test, never 501)
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_whoami_returns_grant(tmp_path):
|
||||
client = await _client(tmp_path, {"k": {"name": "cn", "level": "sre"}})
|
||||
resp = await client.get("/api/whoami", headers={"Authorization": "Bearer k"})
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body == {"name": "cn", "level": "sre", "channels": ["sre"]}
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_ws_rejects_bad_token(tmp_path):
|
||||
client = await _client(tmp_path, {"k": {"name": "n", "level": "tss"}})
|
||||
ws = await client.ws_connect("/ws/sre", headers={"Authorization": "Bearer k"})
|
||||
msg = await ws.receive()
|
||||
assert msg.type.name in {"CLOSE", "CLOSED", "CLOSING"}
|
||||
await client.close()
|
||||
@@ -0,0 +1,68 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from relay_gateway.keys import KeyStore, hash_token, can_access, CHANNELS_FOR_LEVEL
|
||||
|
||||
|
||||
def _write(path: Path, tokens: dict[str, dict]) -> None:
|
||||
body = {hash_token(tok): meta for tok, meta in tokens.items()}
|
||||
path.write_text(json.dumps(body), encoding="utf-8")
|
||||
|
||||
|
||||
def test_resolve_returns_grant_with_derived_channels(tmp_path):
|
||||
kf = tmp_path / "relay_keys.json"
|
||||
_write(kf, {"sekret": {"name": "cn-axbot", "level": "sre"}})
|
||||
store = KeyStore(kf)
|
||||
grant = store.resolve("sekret")
|
||||
assert grant is not None
|
||||
assert grant.name == "cn-axbot"
|
||||
assert grant.level == "sre"
|
||||
assert grant.channels == ("sre",)
|
||||
|
||||
|
||||
def test_resolve_unknown_token_is_none(tmp_path):
|
||||
kf = tmp_path / "relay_keys.json"
|
||||
_write(kf, {"sekret": {"name": "x", "level": "all"}})
|
||||
store = KeyStore(kf)
|
||||
assert store.resolve("nope") is None
|
||||
|
||||
|
||||
def test_all_level_grants_both_channels(tmp_path):
|
||||
kf = tmp_path / "relay_keys.json"
|
||||
_write(kf, {"k": {"name": "internal", "level": "all"}})
|
||||
grant = KeyStore(kf).resolve("k")
|
||||
assert grant is not None
|
||||
assert set(grant.channels) == {"sre", "tss"}
|
||||
|
||||
|
||||
def test_can_access_enforces_channel(tmp_path):
|
||||
kf = tmp_path / "relay_keys.json"
|
||||
_write(kf, {"k": {"name": "t", "level": "tss"}})
|
||||
grant = KeyStore(kf).resolve("k")
|
||||
assert grant is not None
|
||||
assert can_access(grant, "tss") is True
|
||||
assert can_access(grant, "sre") is False
|
||||
|
||||
|
||||
def test_hot_reload_on_mtime_change(tmp_path):
|
||||
kf = tmp_path / "relay_keys.json"
|
||||
_write(kf, {"k": {"name": "t", "level": "sre"}})
|
||||
store = KeyStore(kf)
|
||||
assert store.resolve("k").level == "sre"
|
||||
import os, time
|
||||
time.sleep(0.01)
|
||||
_write(kf, {"k": {"name": "t", "level": "tss"}})
|
||||
os.utime(kf, None)
|
||||
assert store.resolve("k").level == "tss"
|
||||
|
||||
|
||||
def test_missing_file_resolves_none(tmp_path):
|
||||
store = KeyStore(tmp_path / "absent.json")
|
||||
assert store.resolve("anything") is None
|
||||
|
||||
|
||||
def test_bad_level_is_skipped(tmp_path):
|
||||
kf = tmp_path / "relay_keys.json"
|
||||
_write(kf, {"k": {"name": "t", "level": "bogus"}})
|
||||
assert KeyStore(kf).resolve("k") is None
|
||||
assert "bogus" not in CHANNELS_FOR_LEVEL
|
||||
@@ -0,0 +1,29 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from relay_gateway.manage_keys import add_key, list_keys, revoke
|
||||
from relay_gateway.keys import hash_token
|
||||
|
||||
|
||||
def test_add_key_writes_hashed_entry(tmp_path):
|
||||
kf = tmp_path / "relay_keys.json"
|
||||
token = add_key(kf, name="cn", level="sre")
|
||||
body = json.loads(kf.read_text())
|
||||
assert hash_token(token) in body
|
||||
assert body[hash_token(token)] == {"name": "cn", "level": "sre"}
|
||||
|
||||
|
||||
def test_add_key_rejects_bad_level(tmp_path):
|
||||
with pytest.raises(ValueError):
|
||||
add_key(tmp_path / "k.json", name="x", level="bogus")
|
||||
|
||||
|
||||
def test_revoke_removes_by_name(tmp_path):
|
||||
kf = tmp_path / "relay_keys.json"
|
||||
add_key(kf, name="cn", level="sre")
|
||||
add_key(kf, name="keep", level="tss")
|
||||
removed = revoke(kf, "cn")
|
||||
assert removed == 1
|
||||
names = {e["name"] for e in list_keys(kf)}
|
||||
assert names == {"keep"}
|
||||
Reference in New Issue
Block a user