154 lines
5.7 KiB
Python
154 lines
5.7 KiB
Python
from __future__ import annotations

import asyncio
import logging
import time
from contextlib import asynccontextmanager, suppress
from typing import Any, Dict

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from rolemesh_gateway.upstream import UpstreamClient, UpstreamError  # reuse gateway client

from .adapters.base import DeviceRef
from .adapters.cuda import CudaAdapter, ServerStartupError
from .config import NodeAgentConfig
from .inventory import discover_gguf_models
|
|
|
|
|
|
def _error(message: str, code: str = "node_error", status_code: int = 500) -> JSONResponse:
    """Build an OpenAI-style error response.

    The body mirrors the OpenAI error envelope: ``{"error": {"message", "type",
    "code"}}``. The ``type`` field is always ``"node_error"``; ``code`` carries
    the machine-readable reason.
    """
    envelope = {
        "error": {
            "message": message,
            "type": "node_error",
            "code": code,
        }
    }
    return JSONResponse(status_code=status_code, content=envelope)
|
|
|
|
|
|
def create_app(cfg: NodeAgentConfig) -> FastAPI:
    """Create the node-agent FastAPI application.

    Wires together a shared httpx client, the gateway-style upstream client,
    and the CUDA adapter, then registers the node-local routes
    (/health, /v1/node/inventory, /v1/models, /v1/chat/completions).
    """
    # Shared async HTTP client. The very long read timeout (1 h) keeps
    # long-running streamed completions from being cut off mid-stream.
    http = httpx.AsyncClient(
        timeout=httpx.Timeout(connect=5.0, read=3600.0, write=30.0, pool=30.0)
    )
    upstream = UpstreamClient(client=http)
    cuda = CudaAdapter(
        llama_server_bin=cfg.llama_server_bin,
        startup_timeout_s=cfg.llama_server_startup_timeout_s,
        probe_interval_s=cfg.llama_server_probe_interval_s,
    )

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """App lifespan: optional dispatcher heartbeat task + resource cleanup."""
        heartbeat_task: asyncio.Task[None] | None = None
        # Heartbeat only runs when a dispatcher URL AND roles are configured.
        if cfg.dispatcher_base_url and cfg.dispatcher_roles:
            heartbeat_task = asyncio.create_task(_heartbeat_loop(app))
        try:
            yield
        finally:
            # Cancel the heartbeat first, then close the HTTP client and
            # shut down any llama servers the CUDA adapter started.
            if heartbeat_task is not None:
                heartbeat_task.cancel()
                with suppress(asyncio.CancelledError):
                    await heartbeat_task
            await http.aclose()
            await cuda.shutdown()

    app = FastAPI(title="RoleMesh Node Agent", version="0.1.0", lifespan=lifespan)

    # Expose shared objects on app.state so handlers and _heartbeat_loop
    # can reach them without module-level globals.
    app.state.cfg = cfg
    app.state.http = http
    app.state.upstream = upstream

    # Adapters
    app.state.cuda = cuda

    # State: role -> (device, model)
    # This is intentionally simple for the scaffold: pick first GPU and first matching model.
    app.state.role_bindings: Dict[str, Dict[str, str]] = {}

    @app.get("/health")
    async def health() -> Dict[str, Any]:
        """Liveness probe: node id plus current wall-clock time."""
        return {"status": "ok", "node_id": cfg.node_id, "time": time.time()}

    @app.get("/v1/node/inventory")
    async def inventory() -> Dict[str, Any]:
        """Report devices, per-device metrics, configured models, and GGUF files found on disk."""
        devices = await app.state.cuda.discover_devices()
        # Flatten dataclass-style metric objects into plain dicts for JSON.
        metrics = [m.__dict__ | {"device": m.device.__dict__} for m in await app.state.cuda.get_metrics()]
        models = [{"model_id": m.model_id, "path": str(m.path), "roles": m.roles} for m in cfg.models]
        discovered = discover_gguf_models(cfg.model_roots)
        return {
            "node_id": cfg.node_id,
            "backends": cfg.enable_backends,
            "devices": [d.__dict__ for d in devices],
            "metrics": metrics,
            "models": models,
            "discovered_gguf": discovered,
        }

    @app.get("/v1/models")
    async def list_models() -> Dict[str, Any]:
        # Expose configured models as OpenAI models (node-local names).
        data = [{"id": m.model_id, "object": "model", "owned_by": cfg.node_id} for m in cfg.models]
        return {"object": "list", "data": data}

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request) -> Any:
        """OpenAI-compatible chat endpoint: resolve model -> pick GPU -> ensure
        a llama server is running -> proxy the request (optionally streamed)."""
        body = await request.json()
        stream = bool(body.get("stream", False))
        model_id = body.get("model")
        if not model_id:
            return _error("Missing 'model' in request body.", code="bad_request", status_code=400)

        # Find model entry
        model_entry = next((m for m in cfg.models if m.model_id == model_id), None)
        if not model_entry:
            return _error(f"Unknown model_id '{model_id}'.", code="unknown_model", status_code=404)

        # Select device (first CUDA GPU for now)
        devices = await app.state.cuda.discover_devices()
        if not devices:
            return _error("No CUDA GPUs discovered on this node.", code="no_device", status_code=503)
        device = devices[0]

        try:
            # Starts (or reuses) a llama server for this model on the device;
            # returns its base URL.
            base_url = await app.state.cuda.ensure_server(
                device,
                model_path=str(model_entry.path),
                model_id=model_entry.model_id,
                server_args=model_entry.server_args,
            )
        except ServerStartupError as e:
            return _error(str(e), code="server_startup_error", status_code=503)

        upstream = app.state.upstream
        try:
            if not stream:
                out = await upstream.chat_completions(base_url, body)
                return JSONResponse(status_code=200, content=out)
            else:
                async def gen():
                    async for chunk in upstream.stream_chat_completions(base_url, body):
                        yield chunk
                # NOTE(review): once StreamingResponse is returned, gen() runs
                # outside this try block, so an UpstreamError raised mid-stream
                # will NOT be converted to a 502 here — confirm intended.
                return StreamingResponse(gen(), media_type="text/event-stream")
        except UpstreamError as e:
            return _error(str(e), code="upstream_error", status_code=502)

    return app
|
|
|
|
|
|
async def _heartbeat_loop(app: FastAPI) -> None:
    """Periodically POST this node's metrics to the dispatcher.

    Runs until cancelled (by the app's lifespan shutdown). Each iteration is
    best-effort: any failure is logged at debug level and the loop continues
    after the configured interval rather than crashing the agent.
    """
    cfg: NodeAgentConfig = app.state.cfg
    http: httpx.AsyncClient = app.state.http
    log = logging.getLogger(__name__)
    while True:
        try:
            inv = await app.state.cuda.get_metrics()
            payload = {
                "node_id": cfg.node_id,
                "timestamp": time.time(),
                # Flatten metric objects (and their device refs) to plain dicts.
                "metrics": [m.__dict__ | {"device": m.device.__dict__} for m in inv],
            }
            url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat"
            headers: Dict[str, str] = {}
            if cfg.dispatcher_node_key:
                headers["X-RoleMesh-Node-Key"] = cfg.dispatcher_node_key
            await http.post(url, json=payload, headers=headers)
        except Exception:
            # Best-effort by design, but the previous silent `pass` hid every
            # failure; leave a trace for debugging. asyncio.CancelledError is a
            # BaseException, so cancellation still propagates and stops the loop.
            log.debug("Heartbeat to dispatcher failed", exc_info=True)
        await asyncio.sleep(cfg.heartbeat_interval_sec)