2145 lines
78 KiB
Python
2145 lines
78 KiB
Python
"""
|
||
utils.py
|
||
|
||
Shared utilities, constants, and permission helpers used across the bot.
|
||
"""
|
||
|
||
# Standard Library Imports
|
||
import asyncio
|
||
import base64
|
||
import gzip
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
import sys
|
||
import time
|
||
import unicodedata
|
||
from datetime import datetime, time as dt_time, timedelta, timezone
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Iterable, List, Literal, Optional, TypedDict
|
||
|
||
# Third-Party Library Imports
|
||
import aiofiles
|
||
import aiosqlite
|
||
import discord
|
||
import zstandard
|
||
from discord import app_commands
|
||
from discord.ext import commands
|
||
from discord.utils import escape_markdown, escape_mentions
|
||
from dotenv import load_dotenv
|
||
from wcwidth import wcswidth
|
||
|
||
# Local Module Imports
|
||
# BOT/__init__.py has already put BOTS/SHARED on sys.path; re-export it
|
||
# under a public name so peer modules can use it for asset paths.
|
||
from . import SHARED_DIR # noqa: F401 — re-exported for siblings
|
||
from shared_store import check_user_blacklist, check_guild_blacklist, blacklisted_guilds
|
||
from data_parser import (
|
||
LangTableReader,
|
||
UnitTags,
|
||
apply_vehicle_name_filters,
|
||
normalize_name,
|
||
)
|
||
|
||
# Populate os.environ from a local .env file (python-dotenv) before the
# module-level environment lookups below run (STORAGE_VOL_PATH via
# require_storage_dir(), DISCORD_KEY, DISCORD_SKU_ID_*).
load_dotenv()
|
||
|
||
|
||
def require_storage_dir() -> Path:
    """Resolve the storage root from the ``STORAGE_VOL_PATH`` environment variable.

    Surrounding whitespace is stripped. An unset or blank variable is treated
    as a fatal misconfiguration.

    Raises:
        RuntimeError: if ``STORAGE_VOL_PATH`` is missing or blank.
    """
    configured = (os.environ.get("STORAGE_VOL_PATH") or "").strip()
    if configured:
        return Path(configured)
    raise RuntimeError("STORAGE_VOL_PATH must be set")
|
||
|
||
|
||
def esc(text: str) -> str:
    """Return *text* (coerced to ``str``) with markdown and mentions escaped.

    Markdown is escaped first and mentions second, so the result is safe to
    drop into an embed or message body.
    """
    markdown_safe = escape_markdown(str(text))
    return escape_mentions(markdown_safe)
|
||
|
||
# ============================================================================
# CONSTANTS
# ============================================================================

# Base storage paths
# Resolved at import time: importing this module raises RuntimeError when
# STORAGE_VOL_PATH is unset/blank (see require_storage_dir).
STORAGE_DIR = require_storage_dir()
# Icon assets live under the shared directory re-exported by the package.
ICONS_DIR = SHARED_DIR / "ICONS"

# Cache and Auth directories
# NOTE: only STORAGE_DIR itself is created in this section; these
# subdirectories are just path definitions here.
CACHE_DIR = STORAGE_DIR / "CACHE"
AUTH_DIR = STORAGE_DIR / "AUTH"
STACKS_DIR = STORAGE_DIR / "STACKS"
# Root for per-session replay directories (see replay_session_dir).
REPLAYS_DIR = STORAGE_DIR / "REPLAYS" / "SRE"
# Import-time side effect: make sure the storage root exists.
STORAGE_DIR.mkdir(parents=True, exist_ok=True)

# Databases (file paths under the storage root; ENTITLEMENTS_DB_PATH is opened
# with aiosqlite further down in this module)
SQ_BATTLES_DB_PATH = STORAGE_DIR / "sq_battles.db"
SQUADRONS_DB_PATH = STORAGE_DIR / "squadrons.db"
WL_DB_PATH = STORAGE_DIR / "wl.db"
POINTS_DB_PATH = STORAGE_DIR / "points.db"
ENTITLEMENTS_DB_PATH = STORAGE_DIR / "entitlements.db"
COMMAND_DATA_DB_PATH = STORAGE_DIR / "COMMAND_DATA.db"
||
|
||
|
||
def replay_session_dir(session_id: str | int) -> Path:
    """Map a hex session ID to its replay directory under REPLAYS_DIR.

    The ID is stringified, trimmed and lower-cased, and one leading ``0x``
    prefix is removed, so ``"0xABC"``, ``" abc "`` and ``"abc"`` all resolve
    to the same directory.

    NOTE(review): an ``int`` argument is rendered in decimal by ``str()``,
    not in hex. Confirm that callers only pass ints whose decimal form is the
    intended directory name.
    """
    directory_name = str(session_id).strip().lower().removeprefix("0x")
    return REPLAYS_DIR / directory_name
|
||
|
||
|
||
def replay_data_path(session_id: str | int) -> Path:
    """Return the path of ``replay_data.json.gz`` for *session_id*.

    The file sits inside the directory chosen by ``replay_session_dir``.
    """
    session_dir = replay_session_dir(session_id)
    return session_dir.joinpath("replay_data.json.gz")
|
||
|
||
# Dev team Discord user IDs (bot owner + trusted devs)
# Membership here passes is_dev_team().
# NOTE(review): is_admin() bypasses its administrator check for a separate
# hard-coded ID (809619070639013888) that is not listed here — confirm intent.
DEV_DISCORD_IDS: set[int] = {
    1357793112277127290,
    396729572814749706,
    383992974649982976,
    621961120617594880,
    386224084129939456,
}

# Default Strings
# Cat-glyph string; presumably the default embed footer text (per its name).
DEFAULT_FOOTER_CAT: str = "ᓚᘏᗢ"

# Discord Token
# Read from the DISCORD_KEY environment variable; None when unset.
TOKEN = os.environ.get('DISCORD_KEY')
|
||
|
||
# ============================================================================
|
||
# JSON COMPRESSION HELPERS
|
||
# ============================================================================
|
||
|
||
def compress_json(obj, **kwargs) -> bytes:
    """Encode *obj* as UTF-8 JSON and gzip it (level 6) for BLOB storage.

    Extra keyword arguments are forwarded unchanged to ``json.dumps``.
    """
    payload = json.dumps(obj, **kwargs).encode("utf-8")
    return gzip.compress(payload, compresslevel=6)
|
||
|
||
|
||
def decompress_json(data):
    """Parse a JSON column value that may or may not be gzip-compressed.

    Accepted inputs:
      * ``str``: plain JSON TEXT.
      * ``bytes`` / ``bytearray`` / ``memoryview``: gzip-compressed JSON (as
        written by ``compress_json``) or uncompressed JSON bytes. The gzip
        magic number (``1f 8b``) decides which, so a plain JSON value that
        SQLite hands back as a BLOB no longer raises ``BadGzipFile``.

    Raises:
        json.JSONDecodeError: if the (decompressed) payload is not valid JSON.
        gzip.BadGzipFile / EOFError: if a gzip-prefixed payload is corrupt.
    """
    if isinstance(data, (bytes, bytearray, memoryview)):
        raw = bytes(data)
        if raw[:2] == b"\x1f\x8b":  # gzip magic number
            raw = gzip.decompress(raw)
        # json.loads accepts bytes directly (UTF-8/16/32 auto-detected).
        return json.loads(raw)
    return json.loads(data)
|
||
|
||
|
||
# ============================================================================
# BLACKLISTS
# ============================================================================

# Blacklisted users, squadrons, and guilds live in the shared,
# version-controlled BOTS/SHARED/BLACKLIST.json (read via shared_store) so both
# bots share one source of truth. Use check_user_blacklist() /
# check_guild_blacklist() / shared_store.blacklisted_squadrons() /
# shared_store.blacklisted_guilds().

# ── Premium / Entitlements ────────────────────────────────────────────────────
PREMIUM_ACTIVATION_TS: int = 1775459700 # Unix timestamp when premium gating activates
COMP_FREE_UNTIL_TS: int = 1777620600 # Unix timestamp at which the free /comp period ends

# ── Tier enforcement (activation timestamp updated per policy) ───────────────
# Unix timestamp compared against time.time() by tier_enforcement_active().
TIER_ENFORCEMENT_TS: int = 1778107232

# Type alias listing the valid tier names.
Tier = Literal["standard", "pro", "max"]

# Per-tier cap on the number of enabled preference entries per notification
# type; None = unlimited. Read by tier_cap(), which returns 0 for tiers/types
# missing here (except the types it treats as always uncapped).
TIER_NOTIF_CAPS: Dict[str, Dict[str, Optional[int]]] = {
    "standard": {"Logs": 10, "Points": 10, "Leave": 10},
    "pro": {"Logs": 25, "Points": 25, "Leave": 25},
    "max": {"Logs": None, "Points": None, "Leave": None},
}

# Whether a tier may use wildcard squadron keys (see WILDCARD_KEYS).
TIER_ALLOWS_WILDCARDS: Dict[str, bool] = {
    "standard": False,
    "pro": True,
    "max": True,
}

TIER_ORDER: List[str] = ["standard", "pro", "max"] # ascending; used for "highest wins"
# Preference keys treated as wildcards; callers compare key.lower() against this set.
WILDCARD_KEYS: set[str] = {"*", "all", "everything"}

# Env-var SKU IDs — populated once the Whop/Discord products exist.
# Standard SKU is kept as an int since discord.ui.Button(sku_id=...) needs an int;
# it has a built-in default and int() raises ValueError at import time if the
# env var is not an integer. Pro/Max stay strings (None when unset) and are
# compared against str(sku_id) in sku_id_to_tier().
DISCORD_SKU_ID_STANDARD: int = int(os.environ.get('DISCORD_SKU_ID_STANDARD', '1478970400158384220'))
DISCORD_SKU_ID_PRO: Optional[str] = os.environ.get("DISCORD_SKU_ID_PRO")
DISCORD_SKU_ID_MAX: Optional[str] = os.environ.get("DISCORD_SKU_ID_MAX")
|
||
|
||
|
||
def tier_cap(tier: Optional[str], notif_type: str) -> Optional[int]:
    """Return how many *notif_type* entries a *tier* may have enabled.

    A result of ``None`` means "no limit". A missing tier (``None``) is capped
    at 0 for every notification type. For any other tier, "Leaderboard",
    "Global" and "WeeklyBR" are unlimited. Everything else is looked up in
    TIER_NOTIF_CAPS, defaulting to 0 for unknown tiers or types.
    """
    if tier is None:
        return 0
    always_uncapped = notif_type in ("Leaderboard", "Global", "WeeklyBR")
    return None if always_uncapped else TIER_NOTIF_CAPS.get(tier, {}).get(notif_type, 0)
|
||
|
||
|
||
def tier_allows_wildcard(tier: Optional[str]) -> bool:
    """True when *tier* is set and TIER_ALLOWS_WILDCARDS permits wildcard keys for it."""
    if not tier:
        return False
    return TIER_ALLOWS_WILDCARDS.get(tier, False)
|
||
|
||
|
||
def tier_enforcement_active(now: Optional[float] = None) -> bool:
    """Report whether tier caps apply at *now* (Unix seconds; defaults to the current time)."""
    reference = time.time() if now is None else now
    return reference >= TIER_ENFORCEMENT_TS
|
||
|
||
|
||
def sku_id_to_tier(sku_id: Optional[str]) -> Optional[str]:
    """Translate a Discord SKU ID into its tier name.

    Falsy input and unrecognised IDs yield ``None``. The ID is compared as a
    string against the Max, then Pro, then Standard SKU.
    """
    if not sku_id:
        return None
    candidate = str(sku_id)
    for configured, tier_name in (
        (DISCORD_SKU_ID_MAX, "max"),
        (DISCORD_SKU_ID_PRO, "pro"),
        (str(DISCORD_SKU_ID_STANDARD), "standard"),
    ):
        if candidate == configured:
            return tier_name
    return None
|
||
|
||
|
||
def _tier_rank(tier: Optional[str]) -> int:
    """Index of *tier* in TIER_ORDER; -1 for None, empty or unknown tiers."""
    if not tier or tier not in TIER_ORDER:
        return -1
    return TIER_ORDER.index(tier)
|
||
|
||
|
||
def higher_tier(a: Optional[str], b: Optional[str]) -> Optional[str]:
    """Pick the better of two tiers by TIER_ORDER rank.

    None and unknown names rank below every real tier. If neither argument is
    a known tier the result is None. Ties go to *a*.
    """
    rank_a = _tier_rank(a)
    rank_b = _tier_rank(b)
    if max(rank_a, rank_b) < 0:
        return None
    return b if rank_b > rank_a else a
|
||
# Free-tier /comp caps per timeslot.
# The server-wide cap counts every invocation in a non-premium guild during the
# window. The per-user cap counts each user's invocations across ALL non-premium
# guilds in the window. Premium-guild usage never counts toward either cap,
# so subscribers (and their members) bypass both checks entirely.
COMP_LIMIT_PER_TIMESLOT: int = 25
COMP_LIMIT_PER_USER_PER_TIMESLOT: int = 15

# ── SQB schedule (UTC, DST-immune) ───────────────────────────────────────────
# Edit SQB_SLOTS_POSTED and the margin constants when Gaijin changes the
# season schedule; every downstream time derives from these.
# Each entry is (region label, posted start, posted end). A slot whose end is
# earlier than its start would span midnight; _shift_time wraps around 24h.
SQB_SLOTS_POSTED: List[tuple[str, dt_time, dt_time]] = [
    ("EU", dt_time(14, 0), dt_time(22, 0)),
    ("NA", dt_time(1, 0), dt_time(7, 0)),
]
# Hourly squadron-points snapshot window (consumed by squadron_stats_tracker_task in tasks.py)
SQB_STATS_TRACKER_PRE_MIN: int = 5 # snapshot window opens this many minutes before posted start
SQB_STATS_TRACKER_POST_MIN: int = 10 # snapshot window closes this many minutes after posted end

# Per-minute boundary snapshot ticks (consumed by squadron_stats_boundary_task in tasks.py)
SQB_BOUNDARY_PRE_MIN: int = 10 # boundary tick this many minutes before posted start
SQB_BOUNDARY_POST_MIN: int = 30 # boundary tick this many minutes after posted end

# /comp rate-limit window (consumed by get_current_timeslot_start_ts, enforced in botscript.py)
SQB_COMP_LIMIT_PRE_MIN: int = 0 # /comp limit window opens this many minutes before posted start
SQB_COMP_LIMIT_POST_MIN: int = 20 # /comp limit window closes this many minutes after posted end
|
||
|
||
def _shift_time(t: dt_time, minutes: int) -> dt_time:
|
||
total = (t.hour * 60 + t.minute + minutes) % (24 * 60)
|
||
return dt_time(total // 60, total % 60)
|
||
|
||
# Per-region snapshot windows: (region, posted start - SQB_STATS_TRACKER_PRE_MIN,
# posted end + SQB_STATS_TRACKER_POST_MIN), wrapped around midnight by _shift_time.
SQB_STATS_TRACKER_WINDOWS: List[tuple[str, dt_time, dt_time]] = [
    (name, _shift_time(s, -SQB_STATS_TRACKER_PRE_MIN), _shift_time(e, SQB_STATS_TRACKER_POST_MIN))
    for name, s, e in SQB_SLOTS_POSTED
]

# Flat list of boundary tick times. For each slot in SQB_SLOTS_POSTED order:
# start - SQB_BOUNDARY_PRE_MIN, then end + SQB_BOUNDARY_POST_MIN.
SQB_BOUNDARY_TIMES: List[dt_time] = [
    edge
    for _, posted_start, posted_end in SQB_SLOTS_POSTED
    for edge in (
        _shift_time(posted_start, -SQB_BOUNDARY_PRE_MIN),
        _shift_time(posted_end, SQB_BOUNDARY_POST_MIN),
    )
]

# Cached guild_id → tier (highest rank among active sources); rebuilt with a TTL guard
# Replaced wholesale by refresh_entitled_guilds(); read by get_guild_tier().
# An empty dict never counts as a cache hit in refresh_entitled_guilds().
_entitled_tiers: Dict[int, str] = {}
# time.monotonic() value of the last rebuild; reset to 0.0 by
# invalidate_entitled_guilds_cache() to expire the TTL.
_entitled_guilds_ts: float = 0.0
_ENTITLEMENT_CACHE_TTL: float = 60.0 # seconds
|
||
|
||
|
||
def _merge_tier(acc: Dict[int, str], guild_id: int, tier: Optional[str]) -> None:
    """Record *tier* for *guild_id* in *acc* unless a higher tier is already stored.

    A falsy tier (NULL/empty DB value) is treated as "standard".
    """
    candidate = tier or "standard"
    existing = acc.get(guild_id)
    acc[guild_id] = candidate if existing is None else (higher_tier(existing, candidate) or existing)
|
||
|
||
|
||
async def _refresh_merge_db_rows(
    result: Dict[int, str],
    label: str,
    query: str,
    params: Iterable[Any] = (),
) -> None:
    """Merge ``(guild_id, tier)`` rows from one entitlements-DB query into *result*.

    A failure of this source is logged (with *label* in the message) and
    otherwise ignored, so the remaining sources still load. Rows whose
    guild_id cannot be converted to ``int`` are skipped.
    """
    try:
        async with aiosqlite.connect(ENTITLEMENTS_DB_PATH) as db:
            cur = await db.execute(query, list(params))
            rows = await cur.fetchall()
    except Exception as e:
        logging.error(f"[PREMIUM] Failed to load {label} entitlements: {e}")
        return
    for gid, tier in rows:
        try:
            _merge_tier(result, int(gid), tier)
        except (ValueError, TypeError):
            pass  # malformed guild_id: skip this row


async def _refresh_fetch_discord_entitlements() -> tuple[Dict[str, tuple[str, str]], bool]:
    """Collect active Discord SKU entitlements, collapsed to one best tier per guild.

    Returns ``(best_per_guild, failed)``. ``best_per_guild`` maps the guild ID
    (as ``str``) to ``(sku_id, tier)``. ``failed`` is True if fetching raised;
    in that case whatever was collected before the error is still returned.

    A guild can hold several active entitlements (e.g. during a Pro→Max upgrade
    or after buying two SKUs), and the web-side ``discord_entitlements`` table
    keys on guild_id, so only the highest tier is kept.
    """
    best_per_guild: Dict[str, tuple[str, str]] = {}
    try:
        bot = get_bot()
        async for ent in bot.entitlements(exclude_ended=True):
            if not ent.guild_id:
                continue  # entitlement not attached to a guild
            gid = str(ent.guild_id)
            sku_id = str(getattr(ent, "sku_id", "") or "") or str(DISCORD_SKU_ID_STANDARD)
            tier = sku_id_to_tier(sku_id) or "standard"
            existing = best_per_guild.get(gid)
            if existing is None or _tier_rank(tier) > _tier_rank(existing[1]):
                best_per_guild[gid] = (sku_id, tier)
    except Exception as e:
        logging.error(f"[PREMIUM] Failed to load Discord entitlements: {e}")
        return best_per_guild, True
    return best_per_guild, False


async def _refresh_sync_discord_table(
    result: Dict[int, str],
    best_per_guild: Dict[str, tuple[str, str]],
    fetch_failed: bool,
) -> None:
    """Reconcile the ``discord_entitlements`` table with the live Discord data.

    Resilience guard: if the live result looks degraded, the existing table is
    kept and its rows are merged into *result*. This prevents Discord outages
    from silently demoting paying guilds. Degraded means any of:
      - the API fetch failed;
      - it returned nothing where rows existed before;
      - it returned fewer than half of a prior set of at least 4 rows.
    Otherwise the table is replaced with *best_per_guild* (used by the web API).
    Errors are logged, not raised.
    """
    try:
        async with aiosqlite.connect(ENTITLEMENTS_DB_PATH) as db:
            await db.execute(
                "CREATE TABLE IF NOT EXISTS discord_entitlements "
                "(guild_id TEXT PRIMARY KEY, sku_id TEXT, tier TEXT, "
                "updated_at INTEGER DEFAULT (strftime('%s','now')))"
            )
            cur = await db.execute("SELECT COUNT(*) FROM discord_entitlements")
            row = await cur.fetchone()
            prior_count = int(row[0]) if row else 0
            new_count = len(best_per_guild)
            degraded = (
                fetch_failed
                or (prior_count > 0 and new_count == 0)
                or (prior_count >= 4 and new_count < prior_count // 2)
            )
            if degraded:
                logging.warning(
                    f"[PREMIUM] Discord entitlements result looks degraded "
                    f"(api_failed={fetch_failed}, new={new_count}, prior={prior_count}); "
                    f"keeping existing discord_entitlements table"
                )
                cached = await db.execute("SELECT guild_id, tier FROM discord_entitlements")
                async for (gid_str, tier_str) in cached:
                    try:
                        _merge_tier(result, int(gid_str), tier_str)
                    except (ValueError, TypeError):
                        pass  # malformed cached guild_id: skip this row
                return

            await db.execute("DELETE FROM discord_entitlements")
            await db.executemany(
                "INSERT INTO discord_entitlements (guild_id, sku_id, tier) VALUES (?, ?, ?)",
                [(gid_str, sku_id, tier_str) for gid_str, (sku_id, tier_str) in best_per_guild.items()],
            )
            await db.commit()
    except Exception as e:
        logging.error(f"[PREMIUM] Failed to sync discord_entitlements table: {e}")


async def refresh_entitled_guilds(*, force: bool = False) -> None:
    """Rebuild the entitled-tiers cache from all three sources.

    Sources, merged per guild with the highest tier winning:
      1. Whop: ``guild_entitlements`` rows with ``status='active'``.
      2. Manual grants: ``manual_entitlements`` rows not yet expired.
      3. Discord native SKU entitlements. These are also synced to the
         ``discord_entitlements`` table for the web API, and that table is
         used as a fallback when the live Discord result looks degraded.

    Live Discord entitlements are merged before, and independently of, the
    table sync. A database failure during the sync therefore can no longer
    drop Discord-paying guilds from the cache.

    Skips the refresh if the cache is non-empty and was populated less than
    _ENTITLEMENT_CACHE_TTL seconds ago, unless *force* is True.
    """
    global _entitled_tiers, _entitled_guilds_ts

    if not force and _entitled_tiers and (time.monotonic() - _entitled_guilds_ts) < _ENTITLEMENT_CACHE_TTL:
        return

    result: Dict[int, str] = {}

    # 1. Whop (guild_entitlements)
    await _refresh_merge_db_rows(
        result,
        "Whop",
        "SELECT guild_id, tier FROM guild_entitlements WHERE status='active'",
    )

    # 2. Manual entitlements
    await _refresh_merge_db_rows(
        result,
        "manual",
        "SELECT guild_id, tier FROM manual_entitlements WHERE expires_at > ?",
        [int(time.time())],
    )

    # 3. Discord native SKU entitlements. Live results always count; the
    #    table sync may add previously cached guilds on top when degraded.
    best_per_guild, discord_fetch_failed = await _refresh_fetch_discord_entitlements()
    for gid_str, (_sku_id, tier_str) in best_per_guild.items():
        _merge_tier(result, int(gid_str), tier_str)
    await _refresh_sync_discord_table(result, best_per_guild, discord_fetch_failed)

    _entitled_tiers = result
    _entitled_guilds_ts = time.monotonic()
    logging.info(f"[PREMIUM] Refreshed entitlement cache: {len(result)} guilds entitled")
|
||
|
||
|
||
def invalidate_entitled_guilds_cache() -> None:
    """Expire the entitlement-cache TTL.

    The next refresh_entitled_guilds() call no longer treats the cache as
    fresh and rebuilds it. The cached tiers themselves are left in place
    until then.
    """
    global _entitled_guilds_ts
    _entitled_guilds_ts = 0.0
|
||
|
||
|
||
async def get_guild_tier(guild_id: int) -> Optional[str]:
    """Return 'standard' | 'pro' | 'max' for *guild_id*, or None if not entitled.

    Answered from the cache built by refresh_entitled_guilds() whenever that
    cache is non-empty. Otherwise a direct lookup runs, consistent with the
    refresh logic. It considers:
      - every active Whop row;
      - every unexpired manual row;
      - every active Discord SKU entitlement for the guild.
    The highest tier wins, and NULL/unknown tiers count as "standard".

    Each source is queried independently, so a failure in one does not hide
    the others. Failures are logged, not raised.
    """
    if _entitled_tiers:
        return _entitled_tiers.get(guild_id)

    best: Optional[str] = None

    # Fallback: direct lookup. Read all rows, not just the first, because a
    # guild may have several qualifying rows (refresh merges them all).
    db_lookups = (
        (
            "SELECT tier FROM guild_entitlements WHERE guild_id=? AND status='active'",
            [str(guild_id)],
        ),
        (
            "SELECT tier FROM manual_entitlements WHERE guild_id=? AND expires_at > ?",
            [str(guild_id), int(time.time())],
        ),
    )
    for query, params in db_lookups:
        try:
            async with aiosqlite.connect(ENTITLEMENTS_DB_PATH) as db:
                cur = await db.execute(query, params)
                rows = await cur.fetchall()
        except Exception as e:
            logging.error(f"[PREMIUM] Failed to check entitlement for {guild_id}: {e}")
            continue
        for (tier,) in rows:
            best = higher_tier(best, tier or "standard")

    try:
        bot = get_bot()
        async for ent in bot.entitlements(guild=discord.Object(id=guild_id), exclude_ended=True):
            if ent.guild_id == guild_id:
                sku_id = str(getattr(ent, "sku_id", "") or "")
                best = higher_tier(best, sku_id_to_tier(sku_id) or "standard")
    except Exception as e:
        logging.error(f"[PREMIUM] Failed to check Discord entitlement for {guild_id}: {e}")

    return best
|
||
|
||
|
||
async def is_guild_entitled(guild_id: int) -> bool:
    """Return True when get_guild_tier() reports any tier for *guild_id*."""
    tier = await get_guild_tier(guild_id)
    return tier is not None
|
||
# ─────────────────────────────────────────────────────────────────────────────

# ============================================================================
# BOT INSTANCE HOLDER
# ============================================================================

# Module-wide bot reference: assigned by set_bot(), read by get_bot().
_bot_instance: Optional[commands.Bot] = None
|
||
|
||
|
||
def set_bot(bot: commands.Bot) -> None:
    """Store *bot* as the module-wide instance that get_bot() returns.

    Calling this again replaces the previously stored instance.
    """
    global _bot_instance
    _bot_instance = bot
|
||
|
||
|
||
def get_bot() -> commands.Bot:
    """Return the instance registered via set_bot().

    Raises:
        RuntimeError: if set_bot() has not been called yet.
    """
    instance = _bot_instance
    if instance is not None:
        return instance
    raise RuntimeError("Bot instance not set. Call set_bot() first.")
|
||
|
||
|
||
# ============================================================================
|
||
# EXCEPTION CLASSES
|
||
# ============================================================================
|
||
|
||
class AdminCheckFailure(app_commands.CheckFailure):
    """Check failure for users without administrator rights, or outside a guild.

    The user-facing message is the first constructor argument; permission_fail()
    shows it in the "access denied" embed.
    """
|
||
|
||
|
||
class BlacklistCheckFailure(app_commands.CheckFailure):
    """Check failure raised when a blacklisted guild or invoking user runs a command.

    Attributes:
        reason: optional reason text from BLACKLIST.json (None when absent).
        is_guild: True when the guild, not the user, is blacklisted, so the
            error handler can show the server-specific message.
    """

    def __init__(self, reason: Optional[str] = None, *, is_guild: bool = False):
        message = reason if reason else ""
        super().__init__(message)
        self.is_guild = is_guild
        self.reason = reason
|
||
|
||
|
||
class TierGateFailure(app_commands.CheckFailure):
    """Check failure for guilds whose tier is below a command's minimum.

    Attributes:
        required_tier: minimum tier the command needs; permission_fail()
            uses it to pick the matching localized description.
        current_tier: the guild's tier at check time (None if not entitled).
    """

    def __init__(self, required_tier: str, current_tier: Optional[str] = None):
        self.required_tier = required_tier
        self.current_tier = current_tier
        super().__init__(f"This command requires the '{required_tier}' tier or higher.")
|
||
|
||
|
||
# ============================================================================
|
||
# PERMISSION DECORATORS
|
||
# ============================================================================
|
||
|
||
def is_admin():
    """Build an app-command check that admits guild administrators and the bot owner.

    The invocation must come from a guild member. The hard-coded owner
    account skips the administrator-permission test.

    Raises:
        AdminCheckFailure: when used outside a guild, or when the member lacks
            the administrator permission.
    """
    owner_id = 809619070639013888  # bot owner; bypasses the administrator check

    async def predicate(interaction: discord.Interaction):
        member = interaction.user
        if interaction.guild is None or not isinstance(member, discord.Member):
            raise AdminCheckFailure("You must be in a guild to run this command.")
        if member.id != owner_id and not member.guild_permissions.administrator:
            raise AdminCheckFailure("You must be an administrator to run this command.")
        return True

    return app_commands.check(predicate)
|
||
|
||
|
||
async def is_dev_team(interaction: discord.Interaction) -> bool:
    """Return True if the invoking user is listed in DEV_DISCORD_IDS or owns the bot.

    The ownership check goes through get_bot(), which raises RuntimeError if
    no bot has been registered yet.
    """
    user = interaction.user
    if user.id in DEV_DISCORD_IDS:
        return True
    bot = get_bot()
    if not bot:
        return False
    return bool(await bot.is_owner(user))
|
||
|
||
|
||
def is_blacklisted():
    """Build an app-command check that blocks blacklisted guilds and users.

    When the command runs in a guild, check_guild_blacklist() is consulted
    first. The invoking user is then checked with check_user_blacklist().
    Both read the shared BLACKLIST.json through shared_store, whose entries
    may carry an internal comment/name and an optional reason.

    Raises:
        BlacklistCheckFailure: ``is_guild=True`` for a blocked guild, the
            default for a blocked user; the reason is attached when present.
    """
    async def predicate(interaction: discord.Interaction):
        current_guild = interaction.guild
        if current_guild is not None:
            guild_blocked, guild_reason = check_guild_blacklist(current_guild.id)
            if guild_blocked:
                raise BlacklistCheckFailure(guild_reason, is_guild=True)

        user_blocked, user_reason = check_user_blacklist(interaction.user.id)
        if user_blocked:
            raise BlacklistCheckFailure(user_reason)
        return True

    return app_commands.check(predicate)
|
||
|
||
|
||
def gate_entitle(required_tier: str):
    """Build an app-command check requiring at least *required_tier*.

    *required_tier* must be one of TIER_ORDER ('standard', 'pro', 'max').
    Invocations outside a guild, from guilds with no active entitlement, or
    from guilds whose tier ranks below *required_tier* fail with a
    ``TierGateFailure`` carrying the required tier for the error handler.

    Raises:
        ValueError: immediately, if *required_tier* is not a known tier.
    """
    if required_tier not in TIER_ORDER:
        raise ValueError(f"Unknown tier {required_tier!r}; expected one of {TIER_ORDER}")
    needed_rank = _tier_rank(required_tier)

    async def predicate(interaction: discord.Interaction):
        guild_id = interaction.guild_id
        if guild_id is None:
            raise TierGateFailure(required_tier, None)
        current = await get_guild_tier(guild_id)
        if _tier_rank(current) >= needed_rank:
            return True
        raise TierGateFailure(required_tier, current)

    return app_commands.check(predicate)
|
||
|
||
|
||
# ============================================================================
|
||
# PERMISSION ERROR HANDLER
|
||
# ============================================================================
|
||
|
||
# Discord rejects embeds whose description exceeds 4096 characters.
_EMBED_DESCRIPTION_LIMIT = 4096


def _build_permission_fail_embed(lang: str, error) -> discord.Embed:
    """Build the localized embed that permission_fail() shows for *error*."""
    if isinstance(error, BlacklistCheckFailure):
        reason = getattr(error, "reason", None)
        if getattr(error, "is_guild", False):
            if reason:
                desc = t(lang, "common.server_blacklisted_reason", reason=reason)
            else:
                desc = t(lang, "common.access_denied_desc")
        else:
            desc = t(lang, "permission.blacklisted_desc")
            if reason:
                desc += "\n" + t(lang, "permission.reason_line", reason=reason)
        return discord.Embed(
            title=t(lang, "permission.blacklisted_title"),
            description=desc,
            color=discord.Color.red(),
        )
    if isinstance(error, AdminCheckFailure):
        # str(error) equals args[0] for the one-message form used in this
        # module, and unlike args[0] it does not raise when no message is given.
        return discord.Embed(
            title=t(lang, "permission.access_denied_title"),
            description=str(error),
            color=discord.Color.orange(),
        )
    if isinstance(error, TierGateFailure):
        return discord.Embed(
            title=t(lang, "permission.tier_gate_title"),
            description=t(lang, f"permission.tier_gate_{error.required_tier}_desc"),
            color=discord.Color.gold(),
        )
    if isinstance(error, app_commands.CheckFailure):
        return discord.Embed(
            title=t(lang, "permission.access_denied_title"),
            description=t(lang, "permission.no_permission_desc"),
            color=discord.Color.orange(),
        )
    # Fallback for unexpected errors.
    return discord.Embed(
        title=t(lang, "permission.unexpected_error_title"),
        description=str(error),
        color=discord.Color.dark_red(),
    )


async def permission_fail(interaction: discord.Interaction, error):
    """Handle permission-related errors by replying with an ephemeral embed.

    Blacklist hits are logged at WARNING level. The reply uses the guild's
    language, or "en" outside a guild. It is sent as the initial response or,
    if the interaction was already responded to, as a followup. Descriptions
    longer than Discord's embed limit are truncated, so an oversized error
    message cannot make the reply itself fail.
    """
    lang = await guild_lang(interaction.guild_id) if interaction.guild_id else "en"

    if isinstance(error, BlacklistCheckFailure):
        logging.warning(
            "Blacklisted command attempt blocked: user_id=%s guild_id=%s command=%s is_guild=%s reason=%r",
            getattr(interaction.user, "id", None),
            interaction.guild_id,
            interaction.command.qualified_name if interaction.command else None,
            getattr(error, "is_guild", False),
            getattr(error, "reason", None),
        )

    embed = _build_permission_fail_embed(lang, error)
    description = embed.description
    if description and len(description) > _EMBED_DESCRIPTION_LIMIT:
        embed.description = description[: _EMBED_DESCRIPTION_LIMIT - 1] + "…"

    try:
        await interaction.response.send_message(embed=embed, ephemeral=True)
    except discord.InteractionResponded:
        await interaction.followup.send(embed=embed, ephemeral=True)
|
||
|
||
|
||
# ============================================================================
|
||
# STRING UTILITIES
|
||
# ============================================================================
|
||
|
||
def norm(s: str) -> str:
    """Canonicalise *s* for loose comparison.

    Falsy input (None, "") becomes "". Otherwise surrounding whitespace is
    trimmed and the text is lower-cased.
    """
    if not s:
        return ""
    return s.strip().lower()
|
||
|
||
|
||
def discord_len(s: str) -> int:
    """Length of *s*, counting East Asian Wide ("W") and Fullwidth ("F") characters as 2.

    Used to keep chunks under the 1024-character embed-field limit, on the
    premise (from the original design) that Discord counts such characters
    double.
    """
    return sum(2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1 for ch in s)
|
||
|
||
|
||
def pad_display_width(text: str, target_width: int) -> str:
    """Pad or truncate *text* so it occupies *target_width* display columns.

    Widths come from wcswidth; the whole-string width falls back to len()
    when wcswidth reports a negative (non-printable) width. Text narrower
    than the target is right-padded with spaces. Otherwise characters are
    kept while they fit, and the remainder is space-padded. A wide character
    that would straddle the edge is therefore dropped, and per-character
    negative widths count as 1.
    """
    width = wcswidth(text)
    if width < 0:
        width = len(text)
    if width < target_width:
        return text + " " * (target_width - width)

    kept = []
    used = 0
    for ch in text:
        ch_width = wcswidth(ch)
        if ch_width < 0:
            ch_width = 1
        if used + ch_width > target_width:
            break
        kept.append(ch)
        used += ch_width
    return "".join(kept) + " " * (target_width - used)
|
||
|
||
|
||
def _row_to_dict(row, cursor) -> Dict[str, Any]:
|
||
"""Convert a sqlite row to a dictionary."""
|
||
if row is None:
|
||
return {}
|
||
return {desc[0]: row[i] for i, desc in enumerate(cursor.description)}
|
||
|
||
|
||
# ============================================================================
|
||
# JSON UTILITIES
|
||
# ============================================================================
|
||
|
||
async def load_json(path: Path, default: Any = None) -> Any:
    """Read and parse the JSON file at *path* asynchronously.

    A missing file, or any read/parse error, yields *default*, or a fresh
    empty dict when *default* is None. Errors other than a missing file are
    logged.
    """
    fallback = {} if default is None else default
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as f:
            return json.loads(await f.read())
    except FileNotFoundError:
        return fallback
    except Exception as e:
        logging.error(f"Error loading JSON from {path}: {e}")
        return fallback
|
||
|
||
|
||
async def write_json(path: Path, data: Any, indent: int = 4) -> bool:
    """Atomically write *data* as UTF-8 JSON to *path*; return True on success.

    Steps:
      - The data is serialised first, so a non-serialisable object never
        leaves a temp file behind.
      - The payload is written to a uniquely named temporary file next to
        *path*. The unique name stops concurrent writers of the same path
        from clobbering each other's temp file.
      - The temp file is moved over *path* with os.replace(), so readers never
        see a half-written file.

    On any failure the error is logged, the temp file is removed, and False
    is returned.
    """
    import uuid

    tmp = path.with_name(f"{path.name}.{uuid.uuid4().hex}.tmp")
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(data, indent=indent, ensure_ascii=False)
        async with aiofiles.open(tmp, "w", encoding="utf-8") as f:
            await f.write(payload)
        os.replace(tmp, path)
        return True
    except Exception as e:
        logging.error(f"Error writing JSON to {path}: {e}")
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            pass  # best-effort cleanup; the original error is already logged
        return False
|
||
|
||
|
||
# ============================================================================
|
||
# PREFERENCES
|
||
# ============================================================================
|
||
|
||
async def load_guild_preferences(guild_id: int) -> Dict[str, Any]:
    """Return the stored preferences for *guild_id*, or {} if absent or unreadable.

    Reads ``STORAGE_DIR/PREFERENCES/<guild_id>-preferences.json`` via load_json.
    """
    pref_file = STORAGE_DIR / "PREFERENCES" / f"{guild_id}-preferences.json"
    return await load_json(pref_file, {})
|
||
|
||
|
||
async def save_guild_preferences(guild_id: int, preferences: Dict[str, Any]) -> bool:
    """Persist *preferences* for *guild_id* and return write_json's success flag.

    The target is ``STORAGE_DIR/PREFERENCES/<guild_id>-preferences.json``; the
    directory is created first if it does not exist.
    """
    target_dir = STORAGE_DIR / "PREFERENCES"
    target_dir.mkdir(parents=True, exist_ok=True)
    return await write_json(target_dir / f"{guild_id}-preferences.json", preferences)
|
||
|
||
|
||
async def remove_guild_pref_notification(
    guild_id: int,
    pref_key: str,
    notif_type: str,
    *,
    preferences: Optional[Dict[str, Any]] = None,
) -> bool:
    """Delete the *notif_type* route from the *pref_key* entry, then save.

    If *preferences* is given it is edited in place and saved; otherwise the
    guild's preferences are loaded first. When the entry is left with no keys
    other than "Short"/"Long", the whole entry is removed.

    Returns False without saving when the entry is missing, is not a dict, or
    has no *notif_type* key. Otherwise returns save_guild_preferences()'s
    result.
    """
    prefs = await load_guild_preferences(guild_id) if preferences is None else preferences
    entry = prefs.get(pref_key)
    if not (isinstance(entry, dict) and notif_type in entry):
        return False

    del entry[notif_type]
    if all(key in {"Short", "Long"} for key in entry):
        prefs.pop(pref_key, None)

    return await save_guild_preferences(guild_id, prefs)
|
||
|
||
|
||
# ── Tier-aware preference helpers ────────────────────────────────────────────
|
||
|
||
def is_notif_enabled(entry: Any, notif_type: str) -> bool:
    """Decide whether *entry* routes *notif_type* to a concrete channel ID.

    The stored value is stringified. It counts as enabled only when it is
    non-empty, does not contain "DISABLED" (any case), and contains a run of
    17–19 digits.
    """
    if not isinstance(entry, dict):
        return False
    value = str(entry.get(notif_type, ""))
    if value and "DISABLED" not in value.upper():
        return re.search(r"\d{17,19}", value) is not None
    return False
|
||
|
||
|
||
def is_foreign_pref_entry(entry: Any) -> bool:
    """True when a preferences entry belongs to TSSBOT and SRE should ignore it.

    Both bots share ``STORAGE/PREFERENCES/<guild>-preferences.json``. TSSBOT
    entries carry a ``Type`` such as ``tss-team``/``tss-player``, matched
    case-insensitively by its "tss" prefix; SRE entries have no such Type.
    """
    if not isinstance(entry, dict):
        return False
    entry_type = str(entry.get("Type", ""))
    return entry_type.lower().startswith("tss")
|
||
|
||
|
||
def enabled_pref_keys_for(prefs: Dict[str, Any], notif_type: str) -> List[str]:
    """List the SRE squadron keys (JSON insertion order) with *notif_type* enabled.

    Entries owned by the other bot (see is_foreign_pref_entry) are skipped.
    """
    keys: List[str] = []
    for key, entry in prefs.items():
        if is_foreign_pref_entry(entry):
            continue
        if is_notif_enabled(entry, notif_type):
            keys.append(key)
    return keys
|
||
|
||
|
||
def allowed_pref_keys_for(prefs: Dict[str, Any], tier: Optional[str], notif_type: str) -> set[str]:
    """Enabled keys for *notif_type* that survive the tier cap (first-N slice).

    Keys come from enabled_pref_keys_for() in JSON insertion order.

    - Before tier enforcement activates (tier_enforcement_active()), every
      enabled key is allowed regardless of tier.
    - A missing tier (``None``) has a cap of 0 for every type, including
      Leaderboard/Global/WeeklyBR, and allows no wildcards. After activation
      it therefore gets no keys at all.
    - For any set tier, "Leaderboard", "Global" and "WeeklyBR" are uncapped
      (see tier_cap).
    - Wildcards (*, all, everything; case-insensitive) are dropped entirely on
      tiers that don't allow them. On tiers that do, they do NOT count against
      the cap.
    - Other keys: the first ``tier_cap(tier, notif_type)`` non-wildcard keys.
    """
    keys = enabled_pref_keys_for(prefs, notif_type)

    if not tier_enforcement_active():
        return set(keys)

    wildcards: List[str] = []
    regular: List[str] = []
    for key in keys:
        if key.lower() in WILDCARD_KEYS:
            wildcards.append(key)
        else:
            regular.append(key)

    allowed = set(wildcards) if tier_allows_wildcard(tier) else set()
    cap = tier_cap(tier, notif_type)
    allowed.update(regular if cap is None else regular[:cap])
    return allowed
|
||
|
||
|
||
def enabled_non_wildcard_keys_for(prefs: Dict[str, Any], notif_type: str) -> List[str]:
    """Enabled keys for *notif_type* minus wildcards, for cap accounting.

    Wildcards are excluded because they never count against a tier cap.
    """
    enabled = enabled_pref_keys_for(prefs, notif_type)
    return [key for key in enabled if key.lower() not in WILDCARD_KEYS]
|
||
|
||
|
||
async def load_features(guild_id: int) -> Dict[str, Any]:
    """Load the feature settings stored for ``guild_id``.

    Source: ``STORAGE_DIR / "FEATURES" / "<guild_id>-features.json"``.

    * File present and parseable: its JSON content is returned.
    * File missing: ``{"Translate": "False", "Language": "<English>"}`` is
      written to that path (a write failure is only logged) and returned.
    * Any other read/parse error: logged, and the same defaults are returned
      without touching the file.
    """
    def _defaults() -> Dict[str, Any]:
        # A fresh dict per call, so callers may mutate the result safely.
        return {"Translate": "False", "Language": "<English>"}

    target_dir = STORAGE_DIR / "FEATURES"
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / f"{guild_id}-features.json"

    try:
        async with aiofiles.open(target, "r", encoding="utf-8") as handle:
            text = await handle.read()
        return json.loads(text)
    except FileNotFoundError:
        pass
    except Exception as e:
        logging.error(f"Error loading features for guild {guild_id}: {e}")
        return _defaults()

    # First access for this guild: seed the file with defaults.
    seeded = _defaults()
    try:
        async with aiofiles.open(target, "w", encoding="utf-8") as handle:
            await handle.write(json.dumps(seeded, indent=4))
    except Exception as e:
        logging.error(f"Error writing default features for guild {guild_id}: {e}")
    return seeded
|
||
|
||
|
||
async def save_features(guild_id: int, features: Dict[str, Any]) -> bool:
    """Persist ``features`` to ``STORAGE_DIR / "FEATURES" / "<guild_id>-features.json"``.

    The JSON is serialised first, then written to a uniquely named temporary
    file in the same directory and moved over the target with ``os.replace``.
    This means:

    * a non-JSON-serialisable value can no longer truncate the existing file
      (opening the target with ``"w"`` used to empty it before ``json.dumps``
      raised);
    * a concurrent ``load_features`` or a crash mid-write never observes a
      partially written file.

    Returns True on success, or False if anything failed, including creating
    the directory. The error is logged and the temporary file is cleaned up.
    """
    import uuid

    features_dir = STORAGE_DIR / "FEATURES"
    features_path = features_dir / f"{guild_id}-features.json"
    # Unique temp name so concurrent saves for the same guild don't collide.
    tmp_path = features_dir / f".{features_path.name}.{uuid.uuid4().hex}.tmp"

    try:
        features_dir.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(features, indent=4)
        async with aiofiles.open(tmp_path, "w", encoding="utf-8") as f:
            await f.write(payload)
        os.replace(tmp_path, features_path)
        return True
    except Exception as e:
        logging.error(f"Error saving features for guild {guild_id}: {e}")
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass
        return False
|
||
|
||
|
||
# ============================================================================
|
||
# SQUADRON RESOLUTION
|
||
# ============================================================================
|
||
|
||
async def resolve_clan(short: Optional[str] = None, tag: Optional[str] = None, long: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """Look up one squadron in ``squadrons_data`` by short name, tag or long name.

    Only the first truthy argument is used, in priority order
    ``short`` -> ``tag`` -> ``long``; the comparison is case-insensitive.

    Returns:
        None when no argument was given. Otherwise a dict with
        ``short_name``, ``tag_name``, ``long_name`` and ``clan_id`` (int or
        None). If no row matches, an unresolved placeholder is returned with
        ``clan_id=None`` and ``long_name`` set to the given long name or
        ``"<unresolved>"``.
    """
    lookups = (
        (short, "SELECT short_name, tag_name, long_name, clan_id FROM squadrons_data WHERE LOWER(short_name) = ? LIMIT 1"),
        (tag, "SELECT short_name, tag_name, long_name, clan_id FROM squadrons_data WHERE LOWER(tag_name) = ? LIMIT 1"),
        (long, "SELECT short_name, tag_name, long_name, clan_id FROM squadrons_data WHERE LOWER(long_name) = ? LIMIT 1"),
    )

    async with aiosqlite.connect(SQUADRONS_DB_PATH) as db:
        for needle, sql in lookups:
            if needle:
                cursor = await db.execute(sql, (needle.lower(),))
                break
        else:
            return None
        found = await cursor.fetchone()

    if found:
        short_name, tag_name, long_name, clan_id = found
        return {
            "short_name": short_name,
            "tag_name": tag_name,
            "long_name": long_name,
            "clan_id": None if clan_id is None else int(clan_id),
        }

    # Nothing matched: hand back an unresolved placeholder.
    return {
        "short_name": short or tag or "",
        "tag_name": tag or short or "",
        "long_name": long or "<unresolved>",
        "clan_id": None,
    }
|
||
|
||
|
||
async def resolve_clan_id(long_name: str) -> Optional[int]:
    """Return the numeric clan_id of the squadron whose long name matches.

    The match against ``squadrons_data.long_name`` is case-insensitive. Used
    by commands that call website endpoints keyed by clan_id (recap cards
    etc.). Returns None when the squadron is unknown, its clan_id is NULL, or
    the stored value is not convertible to int.
    """
    async with aiosqlite.connect(SQUADRONS_DB_PATH) as db:
        cursor = await db.execute(
            "SELECT clan_id FROM squadrons_data WHERE LOWER(long_name) = ? LIMIT 1",
            (long_name.lower(),),
        )
        record = await cursor.fetchone()

    if not record or record[0] is None:
        return None
    try:
        return int(record[0])
    except (TypeError, ValueError):
        return None
|
||
|
||
|
||
async def resolve_pref_key(
    key: str, entry: Optional[Dict[str, Any]] = None
) -> Optional[Dict[str, Any]]:
    """Resolve a preferences-dict key to {clan_id, long_name, short_name, tag_name}.

    Preferences keys may be:
      - A numeric clan_id (post-migration shape).
      - A long_name (legacy / orphaned entry the migration couldn't resolve).
      - A short_name (rare, but the migration tolerates it).

    Only plain ASCII digit strings are treated as clan_ids. ``str.isdigit()``
    alone also accepts characters such as ``"²"`` that ``int()`` rejects,
    which used to raise ``ValueError`` out of this coroutine.

    If the key is not found in ``squadrons_data``, falls back to the
    display-only ``Long`` / ``Short`` fields the migration stashed on
    *entry*. Returns None when neither source can resolve the key.

    Wildcard keys ("Global", "everything", "all", "*") are NOT this
    function's concern; the caller must handle them upstream.
    """
    if not key:
        return None
    entry = entry or {}
    key_str = str(key).strip()
    key_lc = key_str.lower()
    is_clan_id = key_str.isascii() and key_str.isdigit()

    async with aiosqlite.connect(SQUADRONS_DB_PATH) as db:
        if is_clan_id:
            cur = await db.execute(
                "SELECT clan_id, long_name, short_name, tag_name "
                "FROM squadrons_data WHERE clan_id = ? LIMIT 1",
                (int(key_str),),
            )
        else:
            cur = await db.execute(
                "SELECT clan_id, long_name, short_name, tag_name FROM squadrons_data "
                "WHERE LOWER(long_name) = ? OR LOWER(short_name) = ? LIMIT 1",
                (key_lc, key_lc),
            )
        row = await cur.fetchone()

    if row:
        return {
            "clan_id": int(row[0]) if row[0] is not None else None,
            "long_name": row[1],
            "short_name": row[2],
            "tag_name": row[3],
        }

    # Fallback to display-only fields stashed by the migration on the entry itself.
    long_fallback = entry.get("Long") if isinstance(entry, dict) else None
    short_fallback = entry.get("Short") if isinstance(entry, dict) else None
    if long_fallback or short_fallback:
        return {
            "clan_id": int(key_str) if is_clan_id else None,
            "long_name": long_fallback or key_str,
            "short_name": short_fallback or key_str,
            "tag_name": short_fallback or key_str,
        }
    return None
|
||
|
||
|
||
async def resolve_clans(shorts: Optional[List[str]] = None, tags: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Resolve several squadrons by short names and/or tags.

    Shorts are tried first; tags are then tried for anything not already
    covered. Only successful resolutions (``clan_id`` present) are kept.
    Unresolved placeholders must not block the tag pass, because some
    replays store the tagged form (e.g. ``"-DSPLA-"``) in the squadron
    field rather than the bare short name (``"DSPLA"``). In that case the
    short lookup misses even though the tag lookup succeeds.

    Each squadron appears at most once (deduplicated by ``clan_id``). The old
    short-name check alone let the tagged form of an already-resolved
    squadron, or a repeated short, produce duplicate entries.

    Returns:
        List of dicts with ``short_name``, ``tag_name``, ``long_name`` and
        ``clan_id``, in first-resolved order.
    """
    results: List[Dict[str, Any]] = []
    seen_ids: set = set()

    async def _try(**lookup: str) -> None:
        # Append the resolution only if it succeeded and is not a repeat.
        result = await resolve_clan(**lookup)
        clan_id = result.get("clan_id") if result else None
        if clan_id is None or clan_id in seen_ids:
            return
        seen_ids.add(clan_id)
        results.append(result)

    for short in shorts or []:
        if short:
            await _try(short=short)

    # Skip tags that literally equal an already-resolved short name (saves a
    # DB round-trip); the clan_id check above catches any remaining repeats.
    resolved_shorts = {(r.get("short_name") or "").lower() for r in results}
    for tag in tags or []:
        if tag and tag.lower() not in resolved_shorts:
            await _try(tag=tag)

    return results
|
||
|
||
|
||
async def get_guild_squadron(
    guild_id: int | str | None,
    user_input: str = "",
) -> Dict[str, str]:
    """Pick the squadron a command should act on in a guild.

    An explicit ``user_input`` wins and is resolved as a short name via
    ``resolve_clan``. Without it, the guild's default is read from
    ``STORAGE_DIR / "SQUADRONS.json"`` (``SQ_ShortHand_Name`` /
    ``SQ_LongHandName``); in that case the short name doubles as the tag.

    Returns:
        Dict with ``short_name``, ``long_name`` and ``tag_name``. A dict
        resolved from ``user_input`` also carries ``clan_id``.

    Raises:
        ValueError: with a user-facing message when the explicit squadron is
        unknown, or when the guild has no default configured.
    """
    if user_input:
        resolved = await resolve_clan(short=user_input)
        if resolved and resolved["long_name"] != "<unresolved>":
            return resolved
        raise ValueError(f"Squadron `{user_input}` not found.")

    not_configured = "No squadron set for this server. Use `/set-squadron` first."
    all_squadrons = await load_json(STORAGE_DIR / "SQUADRONS.json", {})
    configured = all_squadrons.get(str(guild_id), {})
    long_name = configured.get("SQ_LongHandName", "") if configured else ""
    if not long_name:
        raise ValueError(not_configured)

    short_name = configured.get("SQ_ShortHand_Name", "")
    return {"short_name": short_name, "long_name": long_name, "tag_name": short_name}
|
||
|
||
|
||
# ============================================================================
|
||
# TIME UTILITIES
|
||
# ============================================================================
|
||
|
||
def get_current_timeslot_start_ts() -> Optional[int]:
    """Return the start (epoch seconds) of the /comp rate-limit window we are in.

    For each posted SQB slot, the limit window runs from
    ``posted_start - SQB_COMP_LIMIT_PRE_MIN`` to
    ``posted_end + SQB_COMP_LIMIT_POST_MIN`` (UTC). If the current time falls
    inside one, that window's opening timestamp is returned so usage counting
    covers the whole window, including the post-close grace period, with a
    single quota. Outside every window, returns None and the caller applies
    no limit.

    A window whose shifted start is later than its shifted end wraps past
    midnight (e.g. 23:30 -> 00:45). Such windows are matched on both sides of
    midnight; after midnight, the opening timestamp is dated the previous day.
    The previous plain ``start <= now <= end`` check never matched them.
    """
    now_utc = datetime.now(timezone.utc)
    now_t = now_utc.time()
    for _name, posted_start, posted_end in SQB_SLOTS_POSTED:
        limit_start = _shift_time(posted_start, -SQB_COMP_LIMIT_PRE_MIN)
        limit_end = _shift_time(posted_end, SQB_COMP_LIMIT_POST_MIN)

        if limit_start <= limit_end:
            inside = limit_start <= now_t <= limit_end
            opened_yesterday = False
        else:
            # Window wraps past midnight.
            inside = now_t >= limit_start or now_t <= limit_end
            opened_yesterday = now_t <= limit_end

        if not inside:
            continue

        window_open = now_utc.replace(
            hour=limit_start.hour,
            minute=limit_start.minute,
            second=0,
            microsecond=0,
        )
        if opened_yesterday:
            window_open -= timedelta(days=1)
        return int(window_open.timestamp())
    return None
|
||
|
||
|
||
def get_most_recent_posted_timeslot_window(
    region: str,
    now: Optional[datetime] = None,
    end_grace_minutes: int = 0,
    *,
    slots: Optional[Iterable[tuple[str, dt_time, dt_time]]] = None,
) -> Optional[tuple[int, int]]:
    """Return ``(start_ts, end_ts)`` of the latest completed posted SQB slot for ``region``.

    Slots come from ``SQB_SLOTS_POSTED`` unless ``slots`` is supplied; each
    slot is ``(region, start_time, end_time)`` in UTC. A slot whose end time
    is not after its start time is treated as running past midnight.
    ``end_grace_minutes`` extends only the slot end. Region matching is
    case-insensitive; the first matching slot is used. Timestamps are UTC
    epoch seconds. Returns None when no slot matches ``region``.

    Today's occurrence is tried first, then the function steps back one day
    at a time until the occurrence's (grace-extended) end is at or before
    ``now``. A single step back is not enough for slots crossing midnight:
    at 01:00 the 22:00-02:00 slot that began the previous evening is still
    in progress.
    """
    now_utc = now.astimezone(timezone.utc) if now else datetime.now(timezone.utc)
    target = str(region or "").upper()
    schedule = SQB_SLOTS_POSTED if slots is None else slots
    one_day = timedelta(days=1)

    for slot_region, posted_start, posted_end in schedule:
        if slot_region.upper() != target:
            continue

        start_dt = now_utc.replace(
            hour=posted_start.hour,
            minute=posted_start.minute,
            second=0,
            microsecond=0,
        )
        end_dt = now_utc.replace(
            hour=posted_end.hour,
            minute=posted_end.minute,
            second=0,
            microsecond=0,
        )

        if end_dt <= start_dt:
            end_dt += one_day
        if end_grace_minutes:
            end_dt += timedelta(minutes=end_grace_minutes)

        # Walk back until this occurrence has fully ended.
        while now_utc < end_dt:
            start_dt -= one_day
            end_dt -= one_day

        return int(start_dt.timestamp()), int(end_dt.timestamp())

    return None
|
||
|
||
|
||
async def get_comp_usage_in_timeslot(guild_id: int, since_ts: int) -> int:
    """Return how many /comp commands ``guild_id`` has run at or after ``since_ts``.

    Counts rows in ``command_usage`` (COMMAND_DATA.db). Any failure, such as a
    missing table or a locked database, is logged at DEBUG level and reported
    as 0 (no recorded usage).
    """
    query = (
        "SELECT COUNT(*) FROM command_usage "
        "WHERE command_name='comp' AND guild_id=? AND timestamp >= ?"
    )
    try:
        async with aiosqlite.connect(COMMAND_DATA_DB_PATH, timeout=5.0) as db:
            await db.execute("PRAGMA busy_timeout=5000;")
            cursor = await db.execute(query, (str(guild_id), since_ts))
            count_row = await cursor.fetchone()
            return count_row[0] if count_row else 0
    except Exception:
        logging.debug("get_comp_usage_in_timeslot query failed", exc_info=True)
        return 0
|
||
|
||
|
||
async def get_comp_usage_in_timeslot_by_user(
    user_id: int, since_ts: int, *, exclude_guild_ids: Iterable[int] = ()
) -> int:
    """Return how many /comp commands ``user_id`` has run at or after ``since_ts``.

    Rows whose ``guild_id`` is in ``exclude_guild_ids`` are not counted; rows
    with a NULL ``guild_id`` (no guild) always are. The exclusion list is
    meant for premium/entitled guilds, so usage there does not count toward
    the free per-user cap. Errors are logged at DEBUG level and yield 0.
    """
    try:
        skip_ids = [str(gid) for gid in exclude_guild_ids]
        query = (
            "SELECT COUNT(*) FROM command_usage "
            "WHERE command_name='comp' AND user_id=? AND timestamp >= ?"
        )
        bindings: list[Any] = [str(user_id), since_ts, *skip_ids]
        if skip_ids:
            marks = ",".join(["?"] * len(skip_ids))
            query += f" AND (guild_id IS NULL OR guild_id NOT IN ({marks}))"
        async with aiosqlite.connect(COMMAND_DATA_DB_PATH, timeout=5.0) as db:
            await db.execute("PRAGMA busy_timeout=5000;")
            cursor = await db.execute(query, bindings)
            hit = await cursor.fetchone()
            return hit[0] if hit else 0
    except Exception:
        logging.debug("get_comp_usage_in_timeslot_by_user query failed", exc_info=True)
        return 0
|
||
|
||
|
||
def get_entitled_guild_ids() -> list[int]:
    """Return a new list of the guild IDs currently cached as entitled (premium).

    The list is a snapshot: later changes to ``_entitled_tiers`` do not
    affect it.
    """
    return [*_entitled_tiers]
|
||
|
||
|
||
def minutes_ago(unix_timestamp: int) -> str:
    """Render a unix timestamp as a compact relative age such as ``"5m ago"``.

    Uses the largest fitting unit among seconds, minutes, hours and days,
    with integer division. A falsy timestamp (0/None) yields ``"unknown"``.
    Timestamps slightly in the future (clock skew between the data source and
    this host) are clamped to ``"0s ago"``; they used to render as
    ``"-5s ago"``. Float input is truncated to whole seconds, so it no longer
    renders as ``"12.5s ago"``.
    """
    if not unix_timestamp:
        return "unknown"

    diff = max(0, int(time.time()) - int(unix_timestamp))

    if diff < 60:
        return f"{diff}s ago"
    if diff < 3600:
        return f"{diff // 60}m ago"
    if diff < 86400:
        return f"{diff // 3600}h ago"
    return f"{diff // 86400}d ago"
|
||
|
||
|
||
def parse_channel_id(channel_str: str) -> Optional[int]:
    """Extract a numeric channel ID from stored or user-supplied text.

    Accepts a bare ID (``"123"``), a channel mention (``"<#123>"``), and an
    ID carrying the ``DISABLED-`` marker (also inside a mention, e.g.
    ``"<#DISABLED-123>"``). Surrounding whitespace is ignored. Returns None
    for empty or non-numeric input.
    """
    if not channel_str:
        return None

    text = channel_str.strip()
    if text.startswith("<#") and text.endswith(">"):
        text = text[2:-1]
    text = text.removeprefix("DISABLED-")

    try:
        return int(text)
    except ValueError:
        return None
|
||
|
||
|
||
# ============================================================================
|
||
# REPLAY / LOCAL FORMAT
|
||
# ============================================================================
|
||
|
||
def _replay_strip_wings(raw: Optional[str]) -> str:
    """Return a squadron tag without its decorative "wing" characters.

    ``"-ABC-"`` and ``"=ABC="`` both become ``"ABC"``. The wings are removed
    only when the whitespace-stripped tag is at least three characters long
    and both ends are non-alphanumeric; otherwise the stripped tag is
    returned unchanged. If nothing is left, the raw input (or ``""``) is
    returned.
    """
    s = (raw or "").strip()
    if len(s) >= 3 and not s[0].isalnum() and not s[-1].isalnum():
        s = s[1:-1]
    return s or raw or ""


def _replay_ms(value: Any) -> int:
    """Coerce a replay millisecond offset to a non-negative int.

    Accepts ``None``, floats and numeric strings; anything unparseable
    becomes 0. This keeps the ``:02d`` formatting in ``_replay_fmt_time`` and
    the event sort from raising on a single odd value.
    """
    try:
        return max(0, int(value or 0))
    except (TypeError, ValueError):
        return 0


def _replay_fmt_time(ms: Any) -> str:
    """Format a millisecond offset as ``MM:SS``."""
    total_s = _replay_ms(ms) // 1000
    return f"{total_s // 60:02d}:{total_s % 60:02d}"


def _replay_decode_events(raw_events: Any) -> Dict[str, Any]:
    """Return the replay's events as a dict.

    A string value is treated as base85-encoded, zstd-compressed JSON and
    decoded; on failure the error is logged and ``{}`` is used. Any other
    non-dict value (e.g. ``None``) also becomes ``{}``.
    """
    if isinstance(raw_events, str):
        try:
            compressed = base64.b85decode(raw_events)
            raw_events = json.loads(zstandard.decompress(compressed).decode("utf-8"))
        except Exception as e:
            logging.error(f"Failed to decompress events: {e}")
            raw_events = {}
    return raw_events if isinstance(raw_events, dict) else {}


def _replay_uid(event: Dict[str, Any], key: str) -> Optional[str]:
    """Return ``event[key]`` as a string, or None when it is missing/None."""
    value = event.get(key)
    return str(value) if value is not None else None


def _replay_merge_events(raw_events: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten kill and damage events into one list ordered by time (awards ignored)."""
    merged: List[Dict[str, Any]] = [
        {
            "kind": "kill",
            "time": _replay_ms(kill.get("time", 0)),
            "offender_uid": _replay_uid(kill, "offender_uid"),
            "offender_unit": kill.get("offender_unit"),
            "offended_uid": _replay_uid(kill, "offended_uid"),
            "offended_unit": kill.get("offended_unit"),
            "crashed": kill.get("crashed", False),
            "weapon": kill.get("used_weapon", "") or kill.get("weapon", ""),
            "afire": False,
        }
        for kill in raw_events.get("kills") or []
    ]
    merged.extend(
        {
            "kind": "damage",
            "time": _replay_ms(dmg.get("time", 0)),
            "offender_uid": _replay_uid(dmg, "offender_uid"),
            "offender_unit": dmg.get("offender_unit"),
            "offended_uid": _replay_uid(dmg, "offended_uid"),
            "offended_unit": dmg.get("offended_unit"),
            "crashed": False,
            "weapon": "",
            "afire": dmg.get("afire", False),
        }
        for dmg in raw_events.get("damage") or []
    )
    # Stable sort: at equal times, kills keep their place ahead of damage.
    merged.sort(key=lambda ev: ev["time"])
    return merged


def _replay_battle_log(
    merged: List[Dict[str, Any]],
    uid_lookup: Dict[str, Dict[str, str]],
    winner_squadron: str,
    loser_squadron: str,
) -> List[str]:
    """Render merged kill/damage events as pre-formatted display lines.

    Each line starts with ``+`` (acting squadron won), ``-`` (lost) or a space,
    followed by ``[MM:SS]`` and the squadron tag in brackets, left-aligned to
    7 columns. Vehicle CDKs are translated through ``LangTableReader("English")``
    when it can be constructed; otherwise the raw CDK is shown.
    """
    try:
        translate = LangTableReader("English")
    except Exception:
        translate = None

    def _vehicle(unit_cdk: Any) -> str:
        # Translate internal vehicle CDK to a human-readable name.
        if not unit_cdk:
            return "Unknown"
        if translate:
            translated = translate.get_translate(unit_cdk)
            if translated:
                return apply_vehicle_name_filters(translated)
        return unit_cdk

    def _player(uid: Any) -> tuple:
        # (name, stripped clan tag) for a UID.
        if uid is None:
            return "Unknown", ""
        info = uid_lookup.get(str(uid))
        if info:
            return info["name"], info["tag_stripped"]
        return f"Player#{uid}", ""

    def _prefix(sq: str) -> str:
        if sq == winner_squadron:
            return "+"
        if sq == loser_squadron:
            return "-"
        return " "

    lines: List[str] = []
    for ev in merged:
        time_str = _replay_fmt_time(ev["time"])
        kind = ev["kind"]

        if kind == "kill":
            attacker_uid = ev["offender_uid"]
            victim_name, victim_sq = _player(ev["offended_uid"])
            victim_vehicle = _vehicle(ev["offended_unit"])

            if attacker_uid is None or ev["crashed"]:
                sq_tag = f"[{victim_sq}]"
                lines.append(
                    f"{_prefix(victim_sq)}[{time_str}] {sq_tag:<7} {victim_name} ({victim_vehicle}) crashed"
                )
            else:
                name, sq = _player(attacker_uid)
                vehicle = _vehicle(ev["offender_unit"])
                sq_tag = f"[{sq}]"
                lines.append(
                    f"{_prefix(sq)}[{time_str}] {sq_tag:<7} {name} ({vehicle}) destroyed {victim_name} ({victim_vehicle})"
                )

        elif kind == "damage":
            attacker_uid = ev["offender_uid"]
            if attacker_uid is None:
                continue
            name, sq = _player(attacker_uid)
            vehicle = _vehicle(ev["offender_unit"])
            victim_name, _ = _player(ev["offended_uid"])
            victim_vehicle = _vehicle(ev["offended_unit"])
            afire = "(FIRE) " if ev["afire"] else ""
            sq_tag = f"[{sq}]"
            lines.append(
                f"{_prefix(sq)}[{time_str}] {sq_tag:<7} {name} ({vehicle}) damaged {afire}{victim_name} ({victim_vehicle})"
            )

    return lines


def transform_to_local_format(api_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Convert a Spectra API replay into the local replay format used throughout the bot.

    Expects ``{"completed": [<replay_dict>, ...]}`` and converts only the
    first replay. For WebSocket messages shaped as
    ``{"data": {...}, "type": "..."}``, wrap first:
    ``transform_to_local_format({"completed": [msg["data"]]})``.

    Besides copying fields, this builds:
      * ``chat_log``: ``"[MM:SS] [scope] [TAG] `name`: message"`` strings;
      * ``battle_log``: kill/damage lines ordered by time (see
        ``_replay_battle_log``);
      * team player lists, split by comparing each player's tag against the
        winner/loser squadron tags.

    Robustness: float or null event/chat times, and null ``players`` /
    ``chat`` / ``events`` fields, no longer make the whole replay fail. A
    player entry that cannot be parsed is skipped with a warning.

    Returns:
        The local replay dict, or None (logged, with traceback) when the
        payload holds no completed replay or conversion fails unexpectedly.
    """
    try:
        if not api_data or "completed" not in api_data or not api_data["completed"]:
            logging.error("Invalid API data structure")
            return None

        replay = api_data["completed"][0]

        winner_winged = str(replay.get("winner") or "")
        loser_winged = str(replay.get("loser") or "")
        winner_squadron = _replay_strip_wings(winner_winged)
        loser_squadron = _replay_strip_wings(loser_winged)
        is_draw = replay.get("draw", False)

        # UID (string) -> {"name", "tag_stripped"}, shared by chat and battle log.
        uid_lookup: Dict[str, Dict[str, str]] = {}
        winner_players: List[Dict[str, Any]] = []
        loser_players: List[Dict[str, Any]] = []

        for uid_str, pdata in (replay.get("players") or {}).items():
            try:
                tag = pdata.get("tag", "")
                name = pdata.get("name", "")
                # Strip wings exactly like the winner/loser tags, so the
                # team-prefix comparison in the battle log agrees with them.
                uid_lookup[uid_str] = {"name": name, "tag_stripped": _replay_strip_wings(tag)}

                # First unit flagged as used; "DISCONNECTED" if none was.
                vehicle = next(
                    (
                        unit_entry.get("unit", "DISCONNECTED")
                        for unit_entry in pdata.get("units") or []
                        if unit_entry.get("used")
                    ),
                    "DISCONNECTED",
                )

                player_entry = {
                    "uid": int(uid_str),
                    "nick": name,
                    "fake_nick": None,
                    "index": int(uid_str),
                    "vehicle": vehicle,
                    "vehicle_new": "",
                    "air_kills": pdata.get("air_kills", 0),
                    "ground_kills": pdata.get("ground_kills", 0),
                    "assists": pdata.get("assists", 0),
                    "deaths": pdata.get("deaths", 0),
                    "captures": pdata.get("captures", 0),
                    "score": pdata.get("score", 0),
                }

                # Assign to winner or loser by comparing the raw (winged) tag.
                if tag == winner_winged:
                    winner_players.append(player_entry)
                elif tag == loser_winged:
                    loser_players.append(player_entry)
            except (AttributeError, TypeError, ValueError) as e:
                logging.warning(f"Skipping bad player UID {uid_str}: {e}")

        # Chat entries -> formatted strings for downstream regex parsing.
        chat_log: List[str] = []
        for chat_entry in replay.get("chat") or []:
            uid = str(chat_entry.get("uid", ""))
            scope = chat_entry.get("type", "ALL")
            message = chat_entry.get("message", "")
            time_str = _replay_fmt_time(chat_entry.get("time", 0))
            info = uid_lookup.get(uid, {"name": "Unknown", "tag_stripped": "???"})
            chat_log.append(
                f"[{time_str}] [{scope}] [{info['tag_stripped']}] `{info['name']}`: {message}"
            )

        raw_events = _replay_decode_events(replay.get("events", {}))
        battle_log = _replay_battle_log(
            _replay_merge_events(raw_events), uid_lookup, winner_squadron, loser_squadron
        )

        raw_id = replay.get("_id")
        if raw_id is None:
            raw_id = replay.get("id")
        start_ts = int(replay.get("start_ts") or 0)
        end_ts = int(replay.get("end_ts") or 0)

        mission_name = str(replay.get("mission_name") or "").strip()
        if not mission_name:
            mission_name = str(replay.get("level_path") or "").strip()

        mission_mode = str(replay.get("mission_mode") or "").strip()
        if not mission_mode:
            mission_mode = str(replay.get("difficulty") or "").strip()

        duration = replay.get("duration")
        if duration is None:
            duration = max(0, end_ts - start_ts)

        session_id_dec = str(raw_id) if raw_id is not None else ""
        try:
            session_id_hex = hex(int(raw_id)).replace("0x", "") if raw_id is not None else ""
        except (ValueError, TypeError):
            session_id_hex = ""

        return {
            "winning_team_squadron": winner_squadron,
            "losing_team_squadron": loser_squadron,
            "squadrons": [loser_squadron, winner_squadron],
            "squadrons_tagged": [loser_winged, winner_winged],
            "session_id_dec": session_id_dec,
            "session_id_hex": session_id_hex,
            "timestamp": end_ts,
            "start_ts": start_ts,
            "end_ts": end_ts,
            "map": mission_name,
            "mode": mission_mode,
            "duration": duration,
            "draw": is_draw,
            "teams": [
                {
                    "team_index": 0,
                    "clan_id": "",
                    "squadron": winner_squadron,
                    "squadron_tagged": winner_winged,
                    "squadron_long": "",
                    "players": winner_players,
                },
                {
                    "team_index": 1,
                    "clan_id": "",
                    "squadron": loser_squadron,
                    "squadron_tagged": loser_winged,
                    "squadron_long": "",
                    "players": loser_players,
                },
            ],
            "chat_log": chat_log,
            "battle_log": battle_log,
            "events": raw_events,
            "entities": replay.get("entities", []),
            "level_path": replay.get("level_path"),
            "mission_path": replay.get("mission_path"),
            "difficulty": replay.get("difficulty"),
            "type": replay.get("type", ""),
        }

    except Exception as e:
        logging.exception(f"Failed to transform data: {e}")
        return None
|
||
|
||
|
||
# ============================================================================
|
||
# (VEHICLE CACHE)
|
||
# ============================================================================
|
||
|
||
# Global caches for vehicle data (populated on demand)
|
||
# Rows built by init_game_cache(), restricted to vehicles with an icon on disk.
# Each row is [cdk, human_name, icon_filename, tags_dict]; None until built.
game_data_cache: Optional[List] = None

# Same row shape as game_data_cache, but covering every vehicle name from
# UnitTags; populated by init_game_cache_all().
game_data_cache_all: Optional[List] = None
|
||
|
||
def _augmented_tags(unit_tags: UnitTags, cdk: str) -> dict:
    """Copy the tags recorded for *cdk* and fold in tags derived from its ``type``.

    ``unittags.blk`` keeps some classification on the entry's top-level
    ``type`` field rather than inside ``tags``. Helicopters are the notable
    case: per-class tags such as ``type_attack_helicopter`` exist, but the
    generic ``type_helicopter`` does not. ``data_parser._get_tags`` derives
    these at lookup time, and this mirrors it, so the on-disk cache consumed
    by the website classifies vehicles the same way.

    The ``type`` value itself is added as a tag, plus ``type_helicopter`` for
    helicopters. Tags already present are never overwritten. Returns a new
    dict; the cached data is not mutated.
    """
    record = unit_tags.raw.get(cdk, {}) or {}
    merged = dict(record.get("tags") or {})
    kind = record.get("type")
    if not kind:
        return merged

    merged.setdefault(kind, True)
    if kind == "helicopter":
        merged.setdefault("type_helicopter", True)
    return merged
|
||
|
||
|
||
async def init_game_cache():
    """Build the vehicle cache for vehicles that have an icon on disk.

    For every CDK in ``UnitTags`` whose ``<cdk>.png`` exists under
    ``SHARED_DIR/ICONS/VEHICLES`` (matched case-insensitively), stores a row
    ``[cdk, english_name, icon_filename, tags]``. The rows are written to
    ``CACHE_DIR/vehicle_data_cache.json``, stored in the module-level
    ``game_data_cache``, and returned.

    The JSON is written to a temporary file and moved into place with
    ``os.replace``, so external readers of the cache (the website) never see
    a half-written file while it is being regenerated.
    """
    global game_data_cache

    unit_tags = UnitTags.get()
    all_names = unit_tags.all_names
    logging.info(f"[GAMES] TOTAL VEHICLES: {len(all_names)}")

    icons_dir = SHARED_DIR / "ICONS" / "VEHICLES"
    # Case-insensitive lookup: unittags.blk uses CDKs like "ussr_su_122P" while the
    # icon file on disk is "ussr_su_122p.png". On case-sensitive filesystems an exact
    # match silently drops these vehicles from the cache.
    icons_on_disk = {p.name.lower(): p.name for p in icons_dir.iterdir() if p.suffix == ".png"}
    present = [
        (cdk, icon)
        for cdk in all_names
        if (icon := icons_on_disk.get(f"{cdk}.png".lower())) is not None
    ]
    logging.info(f"[GAMES] ICON PAIRS FOUND: {len(present)}")

    translate = LangTableReader("English")

    cache = []
    for cdk, icon in present:
        raw = translate.get_translate(cdk)
        human = normalize_name(raw) if raw else cdk
        misc_params = _augmented_tags(unit_tags, cdk)
        cache.append([cdk, human, icon, misc_params])

    out_path = CACHE_DIR / "vehicle_data_cache.json"
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(cache, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, out_path)

    game_data_cache = cache
    return game_data_cache
|
||
|
||
|
||
async def init_game_cache_all():
    """Build the vehicle cache covering every vehicle in ``UnitTags``.

    Each row is ``[cdk, english_name, icon_filename, tags]``. The icon name is
    the on-disk file matched case-insensitively when the icon directory
    exists, otherwise the default ``"<cdk>.png"``. A missing translation falls
    back to the CDK. The rows are written to
    ``CACHE_DIR/vehicle_data_cache_all.json``, stored in the module-level
    ``game_data_cache_all``, and returned.

    The JSON is written to a temporary file and moved into place with
    ``os.replace``, so external readers never observe a partially written
    cache.
    """
    global game_data_cache_all

    unit_tags = UnitTags.get()
    all_names = unit_tags.all_names
    logging.info(f"[GAMES] TOTAL VEHICLES (ALL): {len(all_names)}")

    icons_dir = SHARED_DIR / "ICONS" / "VEHICLES"
    icons_on_disk = {p.name.lower(): p.name for p in icons_dir.iterdir() if p.suffix == ".png"} if icons_dir.is_dir() else {}

    translate = LangTableReader("English")

    cache = []
    for cdk in all_names:
        raw = translate.get_translate(cdk)
        if raw is None:
            raw = cdk

        human = normalize_name(raw)
        icon = icons_on_disk.get(f"{cdk}.png".lower(), f"{cdk}.png")
        misc_params = _augmented_tags(unit_tags, cdk)
        cache.append([cdk, human, icon, misc_params])

    out_path = CACHE_DIR / "vehicle_data_cache_all.json"
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(cache, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, out_path)

    game_data_cache_all = cache
    return game_data_cache_all
|
||
|
||
|
||
# Languages exposed by the website (matches web/locales/*.json). Maps the site
|
||
# locale code to the column name used inside lang/units.csv. The columns are
|
||
# stored with literal angle brackets (e.g. ``<English>``) — without them
|
||
# ``update_language()`` silently fails its ``in header_info`` check and every
|
||
# reader falls back to column 0 (which happens to be English, masking the bug).
|
||
# Site locale code -> lang-table column name. Iterated by
# init_vehicle_translation_cache(); the site codes become the per-vehicle keys
# in vehicle_translations.json, so they must match the website's locale codes.
_WEB_LANG_TO_LANG_COLUMN: Dict[str, str] = {
    "en": "<English>",
    "ru": "<Russian>",
    "fr": "<French>",
    "it": "<Italian>",
    "uk": "<Ukrainian>",
    "de": "<German>",
    "es": "<Spanish>",
    "pl": "<Polish>",
    "cs": "<Czech>",
    "zh-CN": "<Chinese>",
}
|
||
|
||
|
||
async def init_vehicle_translation_cache():
    """Generate ``vehicle_translations.json`` with localized names per vehicle.

    Output schema (served by the Node API at ``/api/i18n/vehicles``)::

        { "ussr_t_34_85": { "en": "T-34-85", "ru": "Т-34-85", ... }, ... }

    Uses the same lang.vromfs.bin source the Discord bot reads when it
    translates scoreboards on /language switch. Language columns listed in
    ``_WEB_LANG_TO_LANG_COLUMN`` that the lang table lacks are skipped with a
    warning. Vehicles with no translation in any language are omitted.

    The file is written to a temporary path and moved into place with
    ``os.replace``, so the API never serves a truncated document while the
    cache is being regenerated. Returns the mapping that was written.
    """
    unit_tags = UnitTags.get()
    all_names = unit_tags.all_names

    available_columns = list(LangTableReader.header_info)
    logging.info(f"[I18N] Vehicle translations: {len(all_names)} vehicles, lang columns available: {available_columns}")

    translators: dict[str, LangTableReader] = {}
    for site_code, lang_column in _WEB_LANG_TO_LANG_COLUMN.items():
        reader = LangTableReader(lang_column)
        if not reader.update_language(lang_column):
            logging.warning(
                f"[I18N] Vehicle translations: lang column '{lang_column}' not found for site lang '{site_code}', "
                f"available columns: {available_columns}"
            )
            continue
        translators[site_code] = reader

    if not translators:
        logging.error(
            "[I18N] Vehicle translations: no usable language columns matched. "
            f"Wanted {list(_WEB_LANG_TO_LANG_COLUMN.values())}, got {available_columns}. "
            "Output will be an empty {} until the column-name map is corrected."
        )

    out: dict[str, dict[str, str]] = {}
    for cdk in all_names:
        names: dict[str, str] = {}
        for site_code, reader in translators.items():
            raw = reader.get_translate(cdk)
            if raw is None:
                continue
            # Keep visible decoration glyphs (▄ ◢ ◊ etc.) so country / event
            # indicators survive to the website. Discord-side renderers still
            # call apply_vehicle_name_filters() with default strip_decorations=True.
            cleaned = apply_vehicle_name_filters(raw, strip_decorations=False)
            if cleaned:
                names[site_code] = cleaned
        if names:
            out[cdk] = names

    out_path = CACHE_DIR / "vehicle_translations.json"
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False)
    os.replace(tmp_path, out_path)

    logging.info(f"[I18N] Wrote vehicle translations for {len(out)} vehicles in {len(translators)} languages → {out_path}")
    return out
|
||
|
||
|
||
# ============================================================================
|
||
# COMMAND STATS
|
||
# ============================================================================
|
||
|
||
# Set to True by init_command_stats_db() once the schema has been created, so
# later calls in this process return without opening the database.
_CMD_STATS_INIT_DONE = False
|
||
|
||
|
||
async def init_command_stats_db() -> None:
    """Ensure COMMAND_DATA.db has the ``command_usage`` table and its indexes.

    Runs its database work at most once per process: after a successful pass
    the module-level ``_CMD_STATS_INIT_DONE`` flag makes later calls return
    immediately. It also switches the database to WAL journaling with
    ``synchronous=NORMAL`` and a 5 s busy timeout.
    """
    global _CMD_STATS_INIT_DONE
    if _CMD_STATS_INIT_DONE:
        return

    pragmas = (
        "PRAGMA journal_mode=WAL;",
        "PRAGMA synchronous=NORMAL;",
        "PRAGMA busy_timeout=5000;",
    )
    create_table = """
        CREATE TABLE IF NOT EXISTS command_usage (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            command_name TEXT NOT NULL,
            user_id TEXT NOT NULL,
            guild_id TEXT,
            channel_id TEXT,
            timestamp INTEGER NOT NULL
        )
    """
    create_indexes = (
        "CREATE INDEX IF NOT EXISTS idx_cmd_usage_command ON command_usage(command_name)",
        "CREATE INDEX IF NOT EXISTS idx_cmd_usage_timestamp ON command_usage(timestamp)",
        "CREATE INDEX IF NOT EXISTS idx_cmd_usage_guild ON command_usage(guild_id)",
    )

    async with aiosqlite.connect(COMMAND_DATA_DB_PATH, timeout=10.0) as db:
        for statement in (*pragmas, create_table, *create_indexes):
            await db.execute(statement)
        await db.commit()

    _CMD_STATS_INIT_DONE = True
    logging.info("Initialized COMMAND_DATA.db")
|
||
|
||
|
||
async def collect_command_stats(interaction) -> None:
    """Append one row describing *interaction* to ``command_usage``.

    Best-effort telemetry: any exception (e.g. a locked or uninitialised
    database) is logged at DEBUG level with its traceback and swallowed, so
    this coroutine never raises into the command handler.
    """
    try:
        command = interaction.command
        name = command.name if command else "unknown"

        async with aiosqlite.connect(COMMAND_DATA_DB_PATH, timeout=5.0) as conn:
            await conn.execute("PRAGMA busy_timeout=5000;")

            guild_id = interaction.guild_id
            channel_id = interaction.channel_id
            row = (
                name,
                str(interaction.user.id),
                None if not guild_id else str(guild_id),
                None if not channel_id else str(channel_id),
                int(time.time()),
            )
            await conn.execute(
                "INSERT INTO command_usage (command_name, user_id, guild_id, channel_id, timestamp) "
                "VALUES (?, ?, ?, ?, ?)",
                row,
            )
            await conn.commit()
    except Exception:
        logging.debug("command_stats insert failed", exc_info=True)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# i18n — locale lookups backed by BOT/locales/*.json
# ---------------------------------------------------------------------------

# Folder scanned by _load_locales(); every "*.json" file in it becomes one
# entry of _locales, keyed by the file's stem.
_LOCALES_DIR = Path(__file__).parent / "locales"

# Guild "Language" feature labels (as stored in the features file) → locale
# codes understood by t().
_LANG_MAP = dict(
    [
        ("<English>", "en"),
        ("<French>", "fr"),
        ("<Italian>", "it"),
        ("<German>", "de"),
        ("<Spanish>", "es"),
        ("<Russian>", "ru"),
        ("<Polish>", "pl"),
        ("<Czech>", "cs"),
        ("<Chinese>", "zh-CN"),
        ("<Portuguese>", "pt"),
        ("<Ukrainian>", "uk"),
    ]
)

# Parsed locale trees, {code: nested dict}; filled by _load_locales().
_locales: Dict[str, dict] = {}
|
||
|
||
|
||
def _load_locales() -> None:
    """(Re)populate ``_locales`` from every ``*.json`` file in ``_LOCALES_DIR``.

    The dict is cleared first and then refilled in place, one entry per file
    keyed by the file stem (e.g. ``fr.json`` → ``"fr"``). A malformed JSON
    file raises ``json.JSONDecodeError``.
    """
    _locales.clear()
    for locale_file in _LOCALES_DIR.glob("*.json"):
        _locales[locale_file.stem] = json.loads(locale_file.read_text(encoding="utf-8"))
|
||
|
||
|
||
def _resolve_locale_key(tree: Any, key: str) -> Any:
    """Walk *tree* along dotted *key*; return the leaf value or None if absent."""
    node = tree
    for part in key.split("."):
        if not isinstance(node, dict):
            return None
        node = node.get(part)
    # A key that stops on a whole section (nested dict) is not a translation;
    # returning it would leak a dict repr into user-facing text.
    return None if isinstance(node, dict) else node


def t(lang: str, key: str, **kwargs) -> str:
    """Look up a translation by dotted key, with English fallback.

    Args:
        lang: A locale code (``"fr"``) or a stored guild language label
            (``"<French>"``). Codes with no loaded locale fall back to English.
        key: Dotted path into the locale tree (``"section.name"``).
        **kwargs: Optional ``str.format`` substitutions. If formatting fails
            (missing field, bad index, or a malformed template such as a stray
            brace), the unformatted text is returned instead of raising.

    Returns:
        The translated string; the English text when the locale lacks the
        key; or *key* itself when neither locale has a string at that path.
    """
    code = _LANG_MAP.get(lang, lang) if lang.startswith("<") else lang
    if code not in _locales:
        code = "en"

    val = _resolve_locale_key(_locales.get(code), key)
    if val is None and code != "en":
        val = _resolve_locale_key(_locales.get("en"), key)
    if val is None:
        return key

    text = str(val)
    if not kwargs:
        return text
    try:
        return text.format(**kwargs)
    except (KeyError, IndexError, ValueError):
        # ValueError covers malformed templates coming from translators.
        return text
|
||
|
||
|
||
# Discord client locale (``str(discord.Locale)``) → locale code used by t().
# Locales without an entry (e.g. en-US / en-GB) get no translation from
# LocaleJsonTranslator, which returns None for them.
_DISCORD_LOCALE_TO_LANG = {
    "cs": "cs",
    "de": "de",
    "es-ES": "es",
    # Latin-American Spanish clients share the Spanish locale file.
    "es-419": "es",
    "fr": "fr",
    "it": "it",
    "pl": "pl",
    "pt-BR": "pt",
    "ru": "ru",
    "uk": "uk",
    "zh-CN": "zh-CN",
}
|
||
|
||
|
||
def command_locale(default: str, key: str) -> app_commands.locale_str:
    """Wrap *default* as a ``locale_str`` tagged with translation *key*.

    The ``key`` extra is what LocaleJsonTranslator looks up via t(); without
    it the string is never translated.
    """
    localized = app_commands.locale_str(default, key=key)
    return localized
|
||
|
||
|
||
class LocaleJsonTranslator(app_commands.Translator):
    """Serve app-command translations from the BOT/locales/*.json tables.

    Only ``locale_str`` values carrying a ``key`` extra (as attached by
    command_locale) are translated; everything else keeps its default text.
    """

    async def translate(
        self,
        string: app_commands.locale_str,
        locale: discord.Locale,
        context: app_commands.TranslationContext,
    ) -> Optional[str]:
        """Return *string* translated for *locale*, or ``None``.

        ``None`` (keep the untranslated text) is returned when the string has
        no ``key`` extra, the Discord locale has no mapped locale code, or the
        lookup yields just the key or the default message unchanged.
        Translations are cut to at most 100 characters.
        """
        lookup_key = string.extras.get("key")
        target = _DISCORD_LOCALE_TO_LANG.get(str(locale)) if lookup_key else None
        if not target:
            return None

        text = t(target, lookup_key)
        if text in (lookup_key, string.message):
            return None
        return text[:100]
|
||
|
||
|
||
def lang_from_features(features: dict) -> str:
    """Map a guild features dict's ``"Language"`` label to a locale code.

    A missing label is treated as ``"<English>"``; any label not in _LANG_MAP
    also yields ``"en"``.
    """
    label = features.get("Language", "<English>")
    return _LANG_MAP.get(label, "en")
|
||
|
||
|
||
async def guild_lang(guild_id: int) -> str:
    """Return the locale code configured for *guild_id* (``"en"`` by default).

    Reads the guild's features via load_features() and maps them with
    lang_from_features().
    """
    return lang_from_features(await load_features(guild_id))
|
||
|
||
|
||
# ============================================================================
# SEASON SCHEDULE (shared with web/utils/seasons.js)
# ============================================================================

# Extension-less plain-text schedule under <repo>/web/constants/, read by the
# web frontend as well as by this module.
SEASONS_FILE = Path(__file__).resolve().parents[1].joinpath("web", "constants", "seasons")
|
||
|
||
# Roman-numeral season index within a year (I–X), shared by both season-name
# patterns below so they cannot drift apart.
_SEASON_NUMERAL = r"(I{1,3}|IV|VI{0,3}|IX|X)"

# Section header such as "2024-IV" (surrounding whitespace tolerated);
# group 1 = year, group 2 = numeral.
_SEASON_HEADER_RE = re.compile(r"^\s*(\d{4})-" + _SEASON_NUMERAL + r"\s*$")

# "week N (dd.mm — dd.mm) <t:TS:R>"; group 1 = the week's timestamp.
_WEEK_LINE_RE = re.compile(
    r"^week\s+\d+\s+\(\d{2}\.\d{2}\s+[—-]\s+\d{2}\.\d{2}\)\s+<t:(\d+):R>"
)

# "until eos (dd.mm — dd.mm) <t:TS:R>"; groups 1/2 = day/month of the last day.
_EOS_LINE_RE = re.compile(
    r"^until\s+eos\s+\(\d{2}\.\d{2}\s+[—-]\s+(\d{2})\.(\d{2})\)\s+<t:(\d+):R>"
)

# Validates a bare season name such as "2024-IV".
SEASON_NAME_RE = re.compile(r"^\d{4}-" + _SEASON_NUMERAL + r"$")


class SeasonRange(TypedDict):
    start: int  # Unix timestamp of the season's first "week" line
    end: int  # Unix timestamp of 23:59:59 UTC on the "until eos" end date
    status: str  # "in_progress" or "completed"


def _end_of_day_utc(year: int, day: int, month: int) -> int:
    """Return the Unix timestamp of 23:59:59 UTC on *day*.*month*.*year*.

    Note the day-before-month argument order.
    """
    return int(datetime(year, month, day, 23, 59, 59, tzinfo=timezone.utc).timestamp())


def _parse_seasons_file(content: str) -> Dict[str, Dict[str, int]]:
    """Parse the schedule text into ``{name: {"start": ts, "end": ts}}`` (no status).

    Expected layout of one season section::

        2024-VI
        week 1 (01.11 — 07.11) <t:1730419200:R>
        ...
        until eos (01.11 — 05.01) <t:1736121599:R>

    ``start`` is the timestamp of the first ``week`` line. ``end`` is
    23:59:59 UTC on the end date of the ``until eos`` line (its own ``<t:…>``
    token is ignored). That date's year comes from the header, except when
    this would put the end before ``start``. That happens for a season
    running past 31 December, and the end is then placed in the following year.
    Sections missing a week line or an ``until eos`` line are omitted.
    """
    seasons: Dict[str, Dict[str, int]] = {}
    current: Optional[str] = None
    first_ts: Optional[int] = None
    eos_day: Optional[int] = None
    eos_month: Optional[int] = None

    def commit() -> None:
        """Store the section being parsed (if complete) and reset the state."""
        nonlocal current, first_ts, eos_day, eos_month
        if current and first_ts is not None and eos_day is not None and eos_month is not None:
            year = int(current.split("-")[0])
            end = _end_of_day_utc(year, eos_day, eos_month)
            if end < first_ts:
                # The end date precedes the start within the header year, so
                # the season crosses New Year and ends in year + 1.
                end = _end_of_day_utc(year + 1, eos_day, eos_month)
            seasons[current] = {"start": first_ts, "end": end}
        current = None
        first_ts = None
        eos_day = None
        eos_month = None

    for line in content.split("\n"):
        header = _SEASON_HEADER_RE.match(line)
        if header:
            commit()
            current = f"{header.group(1)}-{header.group(2)}"
            continue
        if not current:
            continue
        week = _WEEK_LINE_RE.match(line)
        if week and first_ts is None:
            first_ts = int(week.group(1))
            continue
        eos = _EOS_LINE_RE.match(line)
        if eos:
            eos_day = int(eos.group(1))
            eos_month = int(eos.group(2))
    commit()
    return seasons
|
||
|
||
|
||
# Raw text of SEASONS_FILE and its parsed form, filled lazily on first use.
# Nothing in this section invalidates them, so a changed schedule file is
# presumably only picked up after a restart.
_seasons_cached_content: Optional[str] = None
_seasons_cached_parsed: Optional[Dict[str, Dict[str, int]]] = None


def _load_seasons() -> tuple[str, Dict[str, Dict[str, int]]]:
    """Return ``(raw_text, parsed)`` for SEASONS_FILE, reading it at most once.

    The file is (re)read whenever either cached value is still ``None``, so a
    failed parse is retried on the next call.
    """
    global _seasons_cached_content, _seasons_cached_parsed

    have_cache = _seasons_cached_content is not None and _seasons_cached_parsed is not None
    if not have_cache:
        raw_text = SEASONS_FILE.read_text(encoding="utf-8")
        _seasons_cached_content = raw_text
        _seasons_cached_parsed = _parse_seasons_file(raw_text)
    return _seasons_cached_content, _seasons_cached_parsed
|
||
|
||
|
||
def get_seasons() -> Dict[str, SeasonRange]:
    """Return every scheduled season with a status computed against now.

    A season is ``"completed"`` once its end timestamp is strictly in the
    past, otherwise ``"in_progress"`` (only the end is compared, so future
    seasons also report ``"in_progress"``).
    """
    _, schedule = _load_seasons()
    now = int(time.time())
    return {
        season_name: {
            "start": bounds["start"],
            "end": bounds["end"],
            "status": "completed" if bounds["end"] < now else "in_progress",
        }
        for season_name, bounds in schedule.items()
    }
|
||
|
||
|
||
def get_season_range(name: str) -> Optional[SeasonRange]:
    """Return the range and status of season *name*, or ``None`` if unknown."""
    seasons = get_seasons()
    return seasons.get(name)
|
||
|
||
|
||
# Any Discord relative-timestamp token "<t:TS:R>"; group 1 = TS.
_SEASON_TS_TOKEN_RE = re.compile(r"<t:(\d+):R>")


def get_week_boundaries(name: str) -> List[int]:
    """Return, in file order, the ``<t:TS:R>`` timestamps of season *name*.

    Every line of the season's section is scanned, not only ``week N``
    lines. The ``until eos`` line, which also carries a ``<t:…:R>`` token, is
    therefore included as the last entry.
    NOTE(review): confirm render_recap.py expects that trailing EOS value.
    Returns an empty list for an unknown season.
    """
    content, _ = _load_seasons()
    boundaries: List[int] = []
    inside = False

    for line in content.split("\n"):
        header = _SEASON_HEADER_RE.match(line)
        if header:
            inside = f"{header.group(1)}-{header.group(2)}" == name
        elif inside and (token := _SEASON_TS_TOKEN_RE.search(line)):
            boundaries.append(int(token.group(1)))
    return boundaries
|
||
|
||
|
||
# ============================================================================
# SEASON RECAP CARDS (shared cache with web/server.js)
# ============================================================================

_REPO_ROOT = Path(__file__).resolve().parents[1]

# Rendered cards live under STORAGE_DIR/RECAPS/, one subfolder per kind.
_RECAP_CACHE_ROOT = STORAGE_DIR / "RECAPS"
SQUADRON_RECAP_CACHE_DIR = _RECAP_CACHE_ROOT / "squadrons"
PLAYER_RECAP_CACHE_DIR = _RECAP_CACHE_ROOT / "players"

# The renderer runs as a child process using the shared virtualenv interpreter.
_RECAP_PYTHON_BIN = _REPO_ROOT.parent.joinpath("SHARED", ".venv", "bin", "python")
_RECAP_SCRIPT = _REPO_ROOT.joinpath("BOT", "render_recap.py")

RECAP_TTL_SECONDS = 60 * 60 * 24  # in-progress season TTL (matches web): one day
RECAP_RENDER_TIMEOUT_SECONDS = 30  # wall-clock limit for one render subprocess

RECAP_THEMES = {"dark", "light"}
RECAP_DEFAULT_THEME = "dark"

# Languages the recap renderer supports (must match web/server.js RECAP_LANGS).
# NOTE(review): "pt" is absent, so Portuguese guilds get English cards.
RECAP_LANGS = {"cs", "de", "en", "es", "fr", "it", "pl", "ru", "uk", "zh-CN"}
RECAP_DEFAULT_LANG = "en"
|
||
|
||
|
||
class RecapError(Exception):
    """A season recap card could not be produced.

    Raised for an unknown season, a missing squadron/player id, or a render
    subprocess that timed out or failed.
    """
|
||
|
||
|
||
def normalize_recap_theme(theme: str) -> str:
    """Return *theme* if it is a supported recap theme, else the default theme."""
    if theme not in RECAP_THEMES:
        return RECAP_DEFAULT_THEME
    return theme
|
||
|
||
|
||
def normalize_recap_lang(lang: str) -> str:
    """Return *lang* if the recap renderer supports it, else the default language."""
    if lang not in RECAP_LANGS:
        return RECAP_DEFAULT_LANG
    return lang
|
||
|
||
|
||
def _recap_cleanup_tmp(out_path: Path) -> None:
    """Best-effort removal of the ``<out_path>.tmp`` sibling file.

    Used after a failed render; presumably the renderer writes its output
    there before moving it into place (TODO confirm in render_recap.py).
    A missing file is not an error. Filesystem or path errors are logged at
    DEBUG level and otherwise ignored, so cleanup never masks the original
    failure.
    """
    try:
        tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
        tmp_path.unlink(missing_ok=True)
    except (OSError, ValueError):
        # ValueError: with_suffix() on a path with an empty name.
        logging.debug("(RECAP) could not remove tmp file for %s", out_path, exc_info=True)
|
||
|
||
|
||
def _kill_recap_proc(proc: asyncio.subprocess.Process) -> None:
    """Kill *proc* if it is still running, tolerating an exit that races kill()."""
    if proc.returncode is not None:
        return
    try:
        proc.kill()
    except ProcessLookupError:
        # The renderer exited between the returncode check and kill().
        pass


async def _spawn_recap_render(
    mode: Literal["squadron", "player"],
    *,
    clan_id: Optional[int],
    uid: Optional[int],
    season: str,
    season_start: int,
    season_end: int,
    week_boundaries: List[int],
    out_path: Path,
    theme: str,
    lang: str,
) -> None:
    """Render one recap card by running render_recap.py in a subprocess.

    The script runs under _RECAP_PYTHON_BIN with the repository root as cwd
    and is told (``--out``) to write the card to *out_path*. After a
    timeout, non-zero exit, or cancellation the ``<out_path>.tmp`` sibling is
    removed via _recap_cleanup_tmp().

    Args:
        mode: ``"squadron"`` (requires *clan_id*) or ``"player"`` (requires *uid*).
        clan_id / uid: Subject of the card; only the one matching *mode* is used.
        season, season_start, season_end, week_boundaries: Season metadata
            forwarded as CLI arguments (boundaries comma-joined).
        out_path: Destination PNG path.
        theme, lang: Already-normalised theme and language.

    Raises:
        RecapError: The id required by *mode* is missing, the interpreter
            could not be started, the render exceeded
            RECAP_RENDER_TIMEOUT_SECONDS, or the renderer exited non-zero.
        asyncio.CancelledError: Re-raised after the renderer is killed when
            the awaiting task is cancelled.
    """
    args: List[str] = [
        str(_RECAP_SCRIPT),
        "--mode", mode,
        "--season", season,
        "--season-start", str(season_start),
        "--season-end", str(season_end),
        "--week-boundaries", ",".join(str(ts) for ts in week_boundaries),
        "--theme", theme,
        "--lang", lang,
        "--out", str(out_path),
    ]
    if mode == "squadron":
        if clan_id is None:
            raise RecapError("clan_id required for squadron recap")
        args += ["--clan-id", str(clan_id)]
    else:
        if uid is None:
            raise RecapError("uid required for player recap")
        args += ["--uid", str(uid)]

    try:
        proc = await asyncio.create_subprocess_exec(
            str(_RECAP_PYTHON_BIN),
            *args,
            cwd=str(_REPO_ROOT),
            # stdout is never inspected; discard it instead of buffering it.
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.PIPE,
        )
    except OSError as exc:
        # e.g. the shared venv interpreter is missing or not executable.
        logging.error("(RECAP) could not start renderer %s: %s", _RECAP_PYTHON_BIN, exc)
        raise RecapError("recap renderer could not be started") from exc

    try:
        _stdout, stderr = await asyncio.wait_for(
            proc.communicate(), timeout=RECAP_RENDER_TIMEOUT_SECONDS
        )
    except asyncio.TimeoutError:
        _kill_recap_proc(proc)
        await proc.wait()
        _recap_cleanup_tmp(out_path)
        raise RecapError(f"recap render timed out after {RECAP_RENDER_TIMEOUT_SECONDS}s")
    except asyncio.CancelledError:
        # The caller gave up; don't leave an orphaned renderer running.
        _kill_recap_proc(proc)
        _recap_cleanup_tmp(out_path)
        raise

    if proc.returncode != 0:
        err = (stderr or b"").decode("utf-8", errors="replace")[:2000]
        _recap_cleanup_tmp(out_path)
        logging.error(
            "(RECAP) render failed mode=%s season=%s theme=%s lang=%s rc=%s stderr=%s",
            mode, season, theme, lang, proc.returncode, err,
        )
        raise RecapError("recap render failed")
|
||
|
||
|
||
def _recap_cache_is_fresh(cache_path: Path, season_range: SeasonRange) -> bool:
    """Return whether *cache_path* exists and may be served without re-rendering.

    * Completed season: only if the card was written after the season's end.
      A card rendered while the season was still running holds partial data
      and would otherwise be served forever.
    * In-progress season: only if younger than RECAP_TTL_SECONDS.
    """
    try:
        mtime = cache_path.stat().st_mtime
    except FileNotFoundError:
        return False
    if season_range["status"] == "completed":
        return mtime >= season_range["end"]
    return (time.time() - mtime) < RECAP_TTL_SECONDS


async def _get_recap(
    mode: Literal["squadron", "player"],
    *,
    clan_id: Optional[int],
    uid: Optional[int],
    season: str,
    theme: str,
    lang: str,
) -> Path:
    """Return the path of a recap card, rendering it first when needed.

    Cards are cached as ``<cache_dir>/<season>/<id>-<theme>-<lang>.png``.
    *theme* and *lang* are normalised to supported values before use.

    Raises:
        RecapError: Unknown *season*, the id required by *mode* is ``None``,
            the render failed, or no file exists afterwards.
    """
    theme = normalize_recap_theme(theme)
    lang = normalize_recap_lang(lang)

    season_range = get_season_range(season)
    if not season_range:
        raise RecapError(f"unknown season: {season}")

    # Explicit checks rather than `assert`: asserts vanish under `python -O`
    # and would let a "None-<theme>-<lang>.png" cache path through.
    if mode == "squadron":
        if clan_id is None:
            raise RecapError("clan_id required for squadron recap")
        cache_dir = SQUADRON_RECAP_CACHE_DIR / season
        subject_id = clan_id
    else:
        if uid is None:
            raise RecapError("uid required for player recap")
        cache_dir = PLAYER_RECAP_CACHE_DIR / season
        subject_id = uid
    cache_path = cache_dir / f"{subject_id}-{theme}-{lang}.png"

    if not _recap_cache_is_fresh(cache_path, season_range):
        cache_dir.mkdir(parents=True, exist_ok=True)
        await _spawn_recap_render(
            mode,
            clan_id=clan_id,
            uid=uid,
            season=season,
            season_start=season_range["start"],
            season_end=season_range["end"],
            week_boundaries=get_week_boundaries(season),
            out_path=cache_path,
            theme=theme,
            lang=lang,
        )

    if not cache_path.exists():
        raise RecapError("recap file missing after render")

    return cache_path
|
||
|
||
|
||
async def get_squadron_recap(
    clan_id: int, season: str, theme: str, lang: str
) -> Path:
    """Return the path of squadron *clan_id*'s recap card for *season*.

    Delegates to _get_recap() (cached card or fresh render) and propagates
    its RecapError.
    """
    options = {"season": season, "theme": theme, "lang": lang}
    return await _get_recap("squadron", clan_id=clan_id, uid=None, **options)
|
||
|
||
|
||
async def get_player_recap(
    uid: int, season: str, theme: str, lang: str
) -> Path:
    """Return the path of player *uid*'s recap card for *season*.

    Delegates to _get_recap() (cached card or fresh render) and propagates
    its RecapError.
    """
    options = {"season": season, "theme": theme, "lang": lang}
    return await _get_recap("player", clan_id=None, uid=uid, **options)
|
||
|
||
|
||
# Populate the locale tables at import time so t() and LocaleJsonTranslator
# work immediately. A malformed locale JSON file makes this import fail.
_load_locales()
|