from __future__ import annotations import asyncio import time import uuid from typing import Any, Dict, List from fastapi import APIRouter, Request, Depends from fastapi.responses import JSONResponse, StreamingResponse from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry from rolemesh_gateway.upstream import UpstreamClient, UpstreamError from rolemesh_gateway.auth import require_client_auth, require_node_auth router = APIRouter() def _openai_error(message: str, code: str = "upstream_error", status_code: int = 502) -> JSONResponse: return JSONResponse( status_code=status_code, content={ "error": { "message": message, "type": "gateway_error", "param": None, "code": code, } }, ) async def _model_status( alias: str, entry: ProxyModel | DiscoveredModel, registry: Registry, upstream: UpstreamClient, ) -> Dict[str, Any]: if isinstance(entry, ProxyModel): base_url = str(entry.proxy_url).rstrip("/") try: await upstream.get_models(base_url) return {"alias": alias, "available": True, "base_url": base_url} except UpstreamError as exc: return {"alias": alias, "available": False, "base_url": base_url, "error": str(exc)} matching_nodes = registry.nodes_for_role(entry.role, include_stale=False) if not matching_nodes: stale_nodes = registry.nodes_for_role(entry.role, include_stale=True) error = "no_registered_nodes" if stale_nodes: error = "no_fresh_registered_nodes" return {"alias": alias, "available": False, "role": entry.role, "error": error} for node in matching_nodes: base_url = str(node.base_url).rstrip("/") try: await upstream.get_models(base_url) return { "alias": alias, "available": True, "role": entry.role, "base_url": base_url, "node_id": node.node_id, } except UpstreamError: continue return { "alias": alias, "available": False, "role": entry.role, "error": "no_healthy_registered_nodes", } async def _collect_model_statuses( cfg: Config, registry: Registry, upstream: UpstreamClient, ) -> Dict[str, Dict[str, Any]]: statuses = await asyncio.gather( *[ _model_status(alias, entry, registry, upstream) for alias, entry in cfg.models.items() ] ) return {status["alias"]: status for status in statuses} @router.get("/v1/models") async def list_models(request: Request, _=Depends(require_client_auth)) -> Dict[str, Any]: cfg: Config = request.app.state.cfg registry: Registry = request.app.state.registry upstream: UpstreamClient = request.app.state.upstream statuses = await _collect_model_statuses(cfg, registry, upstream) data = [] for name, entry in cfg.models.items(): status = statuses[name] if not status["available"]: continue item = { "id": entry.openai_model_name, "object": "model", "owned_by": "local", } if isinstance(entry, ProxyModel): item["rolemesh"] = {"type": "proxy", "proxy_url": str(entry.proxy_url), "available": True} else: item["rolemesh"] = { "type": "discovered", "role": entry.role, "strategy": entry.strategy, "available": True, } data.append(item) # Expose currently registered nodes (informational) nodes = [ n.model_dump(mode="json") | {"stale": registry.is_stale(n)} for n in registry.list_nodes(include_stale=True) ] unavailable = [status for status in statuses.values() if not status["available"]] return { "object": "list", "data": data, "rolemesh": { "registered_nodes": nodes, "unavailable_models": unavailable, }, } @router.post("/v1/chat/completions") async def chat_completions(request: Request, _=Depends(require_client_auth)) -> Any: cfg: Config = request.app.state.cfg upstream: UpstreamClient = request.app.state.upstream registry: Registry = request.app.state.registry req_id = str(uuid.uuid4()) started = time.time() body = await request.json() model = body.get("model") or cfg.default_model stream = bool(body.get("stream", False)) if model not in cfg.models: return _openai_error(f"Unknown model '{model}'. Check GET /v1/models.", code="unknown_model", status_code=400) entry = cfg.models[model] # Apply per-model defaults if request didn't specify those keys. defaults = entry.defaults if hasattr(entry, "defaults") else {} for k, v in defaults.items(): body.setdefault(k, v) # Resolve upstream base URL if isinstance(entry, ProxyModel): base_url = str(entry.proxy_url).rstrip("/") elif isinstance(entry, DiscoveredModel): node = registry.pick_node_for_role(entry.role, strategy=entry.strategy) if not node: return _openai_error( f"No registered nodes available for role '{entry.role}'. " f"Register a node via POST /v1/nodes/register, or use proxy mode.", code="no_upstream", status_code=503, ) base_url = str(node.base_url).rstrip("/") else: return _openai_error("Invalid model configuration.", code="bad_config", status_code=500) # Proxy request try: if not stream: out = await upstream.chat_completions(base_url, body) request.app.logger.info( "chat_completions ok", extra={"req_id": req_id, "model": model, "upstream": base_url, "ms": int(1000*(time.time()-started))}, ) return out # streaming: passthrough bytes (SSE) async def gen(): async for chunk in upstream.chat_completions_stream(base_url, body): yield chunk return StreamingResponse(gen(), media_type="text/event-stream") except UpstreamError as e: request.app.logger.warning( "chat_completions upstream error", extra={"req_id": req_id, "model": model, "upstream": base_url, "err": str(e)}, ) return _openai_error(str(e), code="upstream_error", status_code=e.status_code or 502) @router.get("/health") async def health() -> Dict[str, str]: return {"status": "ok"} @router.get("/ready") async def ready(request: Request) -> JSONResponse: cfg: Config = request.app.state.cfg registry: Registry = request.app.state.registry upstream: UpstreamClient = request.app.state.upstream statuses = await _collect_model_statuses(cfg, registry, upstream) default_status = statuses.get(cfg.default_model) available_aliases = [alias for alias, status in statuses.items() if status["available"]] if default_status and default_status["available"]: return JSONResponse( status_code=200, content={ "status": "ready", "default_model": cfg.default_model, "available_models": available_aliases, "unavailable_models": [s for s in statuses.values() if not s["available"]], }, ) return JSONResponse( status_code=503, content={ "status": "not_ready", "default_model": cfg.default_model, "available_models": available_aliases, "unavailable_models": [s for s in statuses.values() if not s["available"]], }, ) @router.post("/v1/nodes/register") async def register_node(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]: """ Allow a remote machine to register which roles it serves. SECURITY NOTE: This endpoint is unauthenticated in the scaffold. If the gateway is reachable by untrusted clients, add API-key gating or mTLS. """ registry: Registry = request.app.state.registry payload = await request.json() reg = NodeRegistration.model_validate(payload) node = registry.register(reg) return {"status": "ok", "node": node.model_dump(mode="json")} @router.post("/v1/nodes/heartbeat") async def nodes_heartbeat(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]: """ Allow a node agent to push status/metrics updates. SECURITY NOTE: Unauthenticated in scaffold. Add API-key gating or mTLS for real deployments. """ registry: Registry = request.app.state.registry payload = await request.json() hb = NodeHeartbeat.model_validate(payload) node = registry.heartbeat(hb) if not node: return JSONResponse(status_code=404, content={"error": "unknown_node", "node_id": hb.node_id}) return {"status": "ok", "node": node.model_dump(mode="json")}