# RoleMesh-Gateway/src/rolemesh_node_agent/main.py
from __future__ import annotations

import asyncio
import logging
import time
from contextlib import asynccontextmanager, suppress
from typing import Any, Dict

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from rolemesh_gateway.upstream import UpstreamClient, UpstreamError  # reuse gateway client

from .adapters.base import DeviceRef
from .adapters.cuda import CudaAdapter, ServerStartupError
from .config import NodeAgentConfig
from .inventory import discover_gguf_models
def _error(message: str, code: str = "node_error", status_code: int = 500) -> JSONResponse:
    """Wrap *message* in an OpenAI-style error envelope and return it as JSON."""
    envelope = {"error": {"message": message, "type": "node_error", "code": code}}
    return JSONResponse(status_code=status_code, content=envelope)
def create_app(cfg: NodeAgentConfig) -> FastAPI:
    """Build the node-agent FastAPI application.

    Exposes a local OpenAI-compatible surface (``/v1/models``,
    ``/v1/chat/completions``) backed by llama-server processes started on
    demand per model, plus node-local ``/health`` and inventory endpoints.
    """
    # One shared HTTP client for all upstream traffic; the long read timeout
    # accommodates slow token generation from the local inference server.
    http = httpx.AsyncClient(
        timeout=httpx.Timeout(connect=5.0, read=3600.0, write=30.0, pool=30.0)
    )
    upstream = UpstreamClient(client=http)
    cuda = CudaAdapter(
        llama_server_bin=cfg.llama_server_bin,
        startup_timeout_s=cfg.llama_server_startup_timeout_s,
        probe_interval_s=cfg.llama_server_probe_interval_s,
    )

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Start the dispatcher heartbeat on startup; tear everything down on exit."""
        heartbeat_task: asyncio.Task[None] | None = None
        # Heartbeating only makes sense when a dispatcher is configured.
        if cfg.dispatcher_base_url and cfg.dispatcher_roles:
            heartbeat_task = asyncio.create_task(_heartbeat_loop(app))
        try:
            yield
        finally:
            if heartbeat_task is not None:
                heartbeat_task.cancel()
                with suppress(asyncio.CancelledError):
                    await heartbeat_task
            await http.aclose()
            await cuda.shutdown()

    app = FastAPI(title="RoleMesh Node Agent", version="0.1.0", lifespan=lifespan)
    app.state.cfg = cfg
    app.state.http = http
    app.state.upstream = upstream
    # Adapters
    app.state.cuda = cuda
    # State: role -> (device, model)
    # This is intentionally simple for the scaffold: pick first GPU and first matching model.
    app.state.role_bindings: Dict[str, Dict[str, str]] = {}

    @app.get("/health")
    async def health() -> Dict[str, Any]:
        """Liveness probe: node id plus current wall-clock time."""
        return {"status": "ok", "node_id": cfg.node_id, "time": time.time()}

    @app.get("/v1/node/inventory")
    async def inventory() -> Dict[str, Any]:
        """Report discovered devices, per-device metrics, configured models,
        and GGUF files found under the configured model roots."""
        devices = await app.state.cuda.discover_devices()
        metrics = [m.__dict__ | {"device": m.device.__dict__} for m in await app.state.cuda.get_metrics()]
        models = [{"model_id": m.model_id, "path": str(m.path), "roles": m.roles} for m in cfg.models]
        discovered = discover_gguf_models(cfg.model_roots)
        return {
            "node_id": cfg.node_id,
            "backends": cfg.enable_backends,
            "devices": [d.__dict__ for d in devices],
            "metrics": metrics,
            "models": models,
            "discovered_gguf": discovered,
        }

    @app.get("/v1/models")
    async def list_models() -> Dict[str, Any]:
        # Expose configured models as OpenAI models (node-local names).
        data = [{"id": m.model_id, "object": "model", "owned_by": cfg.node_id} for m in cfg.models]
        return {"object": "list", "data": data}

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request) -> Any:
        """OpenAI-compatible chat endpoint.

        Ensures a llama-server is running for the requested model on the
        first discovered GPU, then proxies the request to it — streamed as
        SSE when the body sets ``stream: true``, otherwise as plain JSON.
        """
        body = await request.json()
        stream = bool(body.get("stream", False))
        model_id = body.get("model")
        if not model_id:
            return _error("Missing 'model' in request body.", code="bad_request", status_code=400)
        # Find model entry
        model_entry = next((m for m in cfg.models if m.model_id == model_id), None)
        if not model_entry:
            return _error(f"Unknown model_id '{model_id}'.", code="unknown_model", status_code=404)
        # Select device (first CUDA GPU for now)
        devices = await app.state.cuda.discover_devices()
        if not devices:
            return _error("No CUDA GPUs discovered on this node.", code="no_device", status_code=503)
        device = devices[0]
        try:
            base_url = await app.state.cuda.ensure_server(
                device,
                model_path=str(model_entry.path),
                model_id=model_entry.model_id,
                server_args=model_entry.server_args,
            )
        except ServerStartupError as e:
            return _error(str(e), code="server_startup_error", status_code=503)
        upstream = app.state.upstream
        try:
            if not stream:
                out = await upstream.chat_completions(base_url, body)
                return JSONResponse(status_code=200, content=out)
            else:
                async def gen():
                    async for chunk in upstream.stream_chat_completions(base_url, body):
                        yield chunk
                # NOTE(review): an UpstreamError raised while the client consumes
                # the stream occurs AFTER this return, so the handler below only
                # covers errors raised before the response starts — confirm this
                # is acceptable for streaming failures.
                return StreamingResponse(gen(), media_type="text/event-stream")
        except UpstreamError as e:
            return _error(str(e), code="upstream_error", status_code=502)

    return app
async def _heartbeat_loop(app: FastAPI) -> None:
    """Periodically POST this node's GPU metrics to the dispatcher.

    Runs forever until the task is cancelled by the app's lifespan teardown.
    Heartbeats are best-effort: a failed attempt is logged and retried on the
    next interval instead of crashing the loop.
    """
    cfg: NodeAgentConfig = app.state.cfg
    http: httpx.AsyncClient = app.state.http
    log = logging.getLogger(__name__)
    # Hoist loop invariants: the dispatcher URL and auth header are derived
    # from static config and never change between iterations.
    url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat"
    headers: Dict[str, str] = {}
    if cfg.dispatcher_node_key:
        headers["X-RoleMesh-Node-Key"] = cfg.dispatcher_node_key
    while True:
        try:
            metrics = await app.state.cuda.get_metrics()
            payload = {
                "node_id": cfg.node_id,
                "timestamp": time.time(),
                "metrics": [m.__dict__ | {"device": m.device.__dict__} for m in metrics],
            }
            await http.post(url, json=payload, headers=headers)
        except Exception:
            # Was a silent `pass`, which hid every failure (including bugs).
            # Log and keep going; asyncio.CancelledError derives from
            # BaseException, so task cancellation still propagates cleanly.
            log.warning("heartbeat POST to %s failed", url, exc_info=True)
        await asyncio.sleep(cfg.heartbeat_interval_sec)