from __future__ import annotations

import asyncio
import logging
import time
from contextlib import asynccontextmanager, suppress
from typing import Any, Dict

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from rolemesh_gateway.upstream import UpstreamClient, UpstreamError  # reuse gateway client

from .adapters.base import DeviceRef
from .adapters.cuda import CudaAdapter, ServerStartupError
from .config import NodeAgentConfig
from .inventory import discover_gguf_models

_log = logging.getLogger(__name__)


def _error(message: str, code: str = "node_error", status_code: int = 500) -> JSONResponse:
    """Build an OpenAI-style error envelope as a JSON response.

    Args:
        message: Human-readable error description.
        code: Machine-readable code placed in ``error.code``.
        status_code: HTTP status to return.
    """
    return JSONResponse(
        status_code=status_code,
        content={"error": {"message": message, "type": "node_error", "code": code}},
    )


def create_app(cfg: NodeAgentConfig) -> FastAPI:
    """Create the RoleMesh node-agent FastAPI application.

    Wires a shared ``httpx.AsyncClient``, the gateway ``UpstreamClient`` and the
    CUDA adapter into the app, and registers the health / inventory / models /
    chat-completions endpoints. If a dispatcher is configured, a background
    heartbeat task is started for the app's lifetime.

    Args:
        cfg: Node-agent configuration (node id, model list, adapter settings,
            optional dispatcher coordinates).

    Returns:
        A configured :class:`fastapi.FastAPI` instance.
    """
    # Long read timeout: generation on an upstream llama-server can take a while.
    http = httpx.AsyncClient(
        timeout=httpx.Timeout(connect=5.0, read=3600.0, write=30.0, pool=30.0)
    )
    upstream = UpstreamClient(client=http)
    cuda = CudaAdapter(
        llama_server_bin=cfg.llama_server_bin,
        startup_timeout_s=cfg.llama_server_startup_timeout_s,
        probe_interval_s=cfg.llama_server_probe_interval_s,
    )

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Start the heartbeat task on startup; tear everything down on exit."""
        heartbeat_task: asyncio.Task[None] | None = None
        if cfg.dispatcher_base_url and cfg.dispatcher_roles:
            heartbeat_task = asyncio.create_task(_heartbeat_loop(app))
        try:
            yield
        finally:
            if heartbeat_task is not None:
                heartbeat_task.cancel()
                with suppress(asyncio.CancelledError):
                    await heartbeat_task
            # Stop adapter-managed servers before closing the shared HTTP
            # client, and run both steps even if the first one raises —
            # otherwise spawned llama-server processes could be orphaned.
            try:
                await cuda.shutdown()
            finally:
                await http.aclose()

    app = FastAPI(title="RoleMesh Node Agent", version="0.1.0", lifespan=lifespan)
    app.state.cfg = cfg
    app.state.http = http
    app.state.upstream = upstream
    # Adapters
    app.state.cuda = cuda
    # State: role -> (device, model)
    # This is intentionally simple for the scaffold: pick first GPU and first matching model.
    app.state.role_bindings: Dict[str, Dict[str, str]] = {}

    @app.get("/health")
    async def health() -> Dict[str, Any]:
        """Liveness probe: node id plus current wall-clock time."""
        return {"status": "ok", "node_id": cfg.node_id, "time": time.time()}

    @app.get("/v1/node/inventory")
    async def inventory() -> Dict[str, Any]:
        """Report discovered devices, live metrics, configured and discovered models."""
        devices = await app.state.cuda.discover_devices()
        metrics = [
            m.__dict__ | {"device": m.device.__dict__}
            for m in await app.state.cuda.get_metrics()
        ]
        models = [
            {"model_id": m.model_id, "path": str(m.path), "roles": m.roles}
            for m in cfg.models
        ]
        discovered = discover_gguf_models(cfg.model_roots)
        return {
            "node_id": cfg.node_id,
            "backends": cfg.enable_backends,
            "devices": [d.__dict__ for d in devices],
            "metrics": metrics,
            "models": models,
            "discovered_gguf": discovered,
        }

    @app.get("/v1/models")
    async def list_models() -> Dict[str, Any]:
        # Expose configured models as OpenAI models (node-local names).
        data = [
            {"id": m.model_id, "object": "model", "owned_by": cfg.node_id}
            for m in cfg.models
        ]
        return {"object": "list", "data": data}

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request) -> Any:
        """Proxy an OpenAI chat-completions request to a node-local llama-server.

        Resolves the requested model, ensures a server is running for it on the
        first discovered CUDA device, and forwards the request (streaming or
        not) through the shared upstream client.
        """
        # Reject malformed bodies with a clean 400 instead of a 500 traceback.
        try:
            body = await request.json()
        except ValueError:
            return _error("Request body is not valid JSON.", code="bad_request", status_code=400)
        stream = bool(body.get("stream", False))
        model_id = body.get("model")
        if not model_id:
            return _error("Missing 'model' in request body.", code="bad_request", status_code=400)

        # Find model entry
        model_entry = next((m for m in cfg.models if m.model_id == model_id), None)
        if not model_entry:
            return _error(f"Unknown model_id '{model_id}'.", code="unknown_model", status_code=404)

        # Select device (first CUDA GPU for now)
        devices = await app.state.cuda.discover_devices()
        if not devices:
            return _error("No CUDA GPUs discovered on this node.", code="no_device", status_code=503)
        device = devices[0]

        try:
            base_url = await app.state.cuda.ensure_server(
                device,
                model_path=str(model_entry.path),
                model_id=model_entry.model_id,
                server_args=model_entry.server_args,
            )
        except ServerStartupError as e:
            return _error(str(e), code="server_startup_error", status_code=503)

        upstream = app.state.upstream
        try:
            if not stream:
                out = await upstream.chat_completions(base_url, body)
                return JSONResponse(status_code=200, content=out)
            else:
                # NOTE(review): UpstreamError raised while iterating the
                # generator (after headers are sent) cannot be converted to a
                # JSON error response anymore — the stream just ends.
                async def gen():
                    async for chunk in upstream.stream_chat_completions(base_url, body):
                        yield chunk

                return StreamingResponse(gen(), media_type="text/event-stream")
        except UpstreamError as e:
            return _error(str(e), code="upstream_error", status_code=502)

    return app


async def _heartbeat_loop(app: FastAPI) -> None:
    """Periodically POST node metrics to the configured dispatcher.

    Best-effort: a failed post is logged and retried on the next interval,
    never crashing the loop. Cancellation (``asyncio.CancelledError``) is a
    ``BaseException`` and propagates normally to stop the task.
    """
    cfg: NodeAgentConfig = app.state.cfg
    http: httpx.AsyncClient = app.state.http
    while True:
        try:
            inv = await app.state.cuda.get_metrics()
            payload = {
                "node_id": cfg.node_id,
                "timestamp": time.time(),
                "metrics": [m.__dict__ | {"device": m.device.__dict__} for m in inv],
            }
            url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat"
            headers: Dict[str, str] = {}
            if cfg.dispatcher_node_key:
                headers["X-RoleMesh-Node-Key"] = cfg.dispatcher_node_key
            await http.post(url, json=payload, headers=headers)
        except Exception:
            # Best-effort heartbeat: log the failure instead of silently
            # swallowing it, then wait for the next tick.
            _log.warning("Heartbeat post to dispatcher failed", exc_info=True)
        await asyncio.sleep(cfg.heartbeat_interval_sec)