from __future__ import annotations import hashlib import os from dataclasses import dataclass from typing import Mapping from sqlalchemy import select from ecospecies_api.db import SessionLocal, create_db_engine from ecospecies_api.models import Base, ContributorAccount ROLE_ORDER = { "viewer": 1, "contributor": 2, "editor": 3, "admin": 4, } @dataclass(frozen=True) class AuthSession: token: str username: str role: str def _normalize_role(role: str) -> str: normalized = role.strip().lower() if normalized not in ROLE_ORDER: raise ValueError(f"Unsupported role: {role}") return normalized def _parse_token_entry(entry: str) -> tuple[str, AuthSession]: parts = [part.strip() for part in entry.split(":")] if len(parts) != 3: raise ValueError( "ECOSPECIES_AUTH_TOKENS entries must use the format token:username:role" ) token, username, role = parts if not token or not username: raise ValueError("Auth token and username must be non-empty") return token, AuthSession(token=token, username=username, role=_normalize_role(role)) def get_token_registry() -> dict[str, AuthSession]: registry: dict[str, AuthSession] = {} configured = os.environ.get("ECOSPECIES_AUTH_TOKENS", "").strip() if configured: for raw_entry in configured.split(","): entry = raw_entry.strip() if not entry: continue token, session = _parse_token_entry(entry) registry[token] = session engine = create_db_engine() Base.metadata.create_all(engine) with SessionLocal() as session: for account in session.scalars( select(ContributorAccount).where(ContributorAccount.is_active.is_(True)) ): registry[account.token_hash] = AuthSession( token=account.token_hash, username=account.email, role="contributor", ) return registry def get_bearer_token(headers: Mapping[str, str]) -> str | None: auth_header = headers.get("Authorization", "").strip() if auth_header.lower().startswith("bearer "): token = auth_header[7:].strip() return token or None token = headers.get("X-EcoSpecies-Token", "").strip() return token or None def resolve_auth_session(headers: Mapping[str, str]) -> AuthSession | None: registry = get_token_registry() token = get_bearer_token(headers) if not token: return None direct = registry.get(token) if direct is not None: return direct token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() return registry.get(token_hash) def auth_is_configured() -> bool: return bool(get_token_registry()) def role_satisfies(role: str, required_role: str) -> bool: return ROLE_ORDER[_normalize_role(role)] >= ROLE_ORDER[_normalize_role(required_role)]