# RoleMesh-Gateway/src/rolemesh_gateway/api/openai.py
from __future__ import annotations
import asyncio
import time
import uuid
from typing import Any, Dict, List
from fastapi import APIRouter, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError
from rolemesh_gateway.auth import require_client_auth, require_node_auth
router = APIRouter()
def _openai_error(message: str, code: str = "upstream_error", status_code: int = 502) -> JSONResponse:
    """Wrap *message* in an OpenAI-style error envelope and return it as JSON.

    The shape mirrors OpenAI's error body ({"error": {message, type, param,
    code}}) so OpenAI client SDKs can parse gateway failures transparently.
    """
    payload: Dict[str, Any] = {
        "error": {
            "message": message,
            "type": "gateway_error",
            "param": None,
            "code": code,
        }
    }
    return JSONResponse(status_code=status_code, content=payload)
async def _model_status(
    alias: str,
    entry: ProxyModel | DiscoveredModel,
    registry: Registry,
    upstream: UpstreamClient,
) -> Dict[str, Any]:
    """Probe the upstream(s) behind *alias* and report availability.

    Proxy models are probed at their fixed URL; discovered models try each
    fresh registered node for the role until one answers GET /models.
    The returned dict always carries "alias" and "available" keys.
    """
    if isinstance(entry, ProxyModel):
        target = str(entry.proxy_url).rstrip("/")
        try:
            await upstream.get_models(target)
        except UpstreamError as exc:
            return {"alias": alias, "available": False, "base_url": target, "error": str(exc)}
        return {"alias": alias, "available": True, "base_url": target}

    fresh = registry.nodes_for_role(entry.role, include_stale=False)
    if not fresh:
        # Distinguish "nothing ever registered" from "registered but stale".
        has_stale = bool(registry.nodes_for_role(entry.role, include_stale=True))
        reason = "no_fresh_registered_nodes" if has_stale else "no_registered_nodes"
        return {"alias": alias, "available": False, "role": entry.role, "error": reason}

    for candidate in fresh:
        target = str(candidate.base_url).rstrip("/")
        try:
            await upstream.get_models(target)
        except UpstreamError:
            continue  # try the next node for this role
        return {
            "alias": alias,
            "available": True,
            "role": entry.role,
            "base_url": target,
            "node_id": candidate.node_id,
        }

    # Nodes exist and are fresh, but none answered the probe.
    return {
        "alias": alias,
        "available": False,
        "role": entry.role,
        "error": "no_healthy_registered_nodes",
    }
async def _collect_model_statuses(
    cfg: Config,
    registry: Registry,
    upstream: UpstreamClient,
) -> Dict[str, Dict[str, Any]]:
    """Probe every configured model concurrently; map alias -> status dict."""
    probes = [
        _model_status(alias, entry, registry, upstream)
        for alias, entry in cfg.models.items()
    ]
    results = await asyncio.gather(*probes)
    return {item["alias"]: item for item in results}
@router.get("/v1/models")
async def list_models(request: Request, _=Depends(require_client_auth)) -> Dict[str, Any]:
    """OpenAI-compatible model listing, extended with rolemesh diagnostics.

    Only models whose upstream probe succeeded are included in "data"; the
    extra "rolemesh" key carries registered-node state and the reasons any
    configured models are currently unavailable.
    """
    cfg: Config = request.app.state.cfg
    registry: Registry = request.app.state.registry
    upstream: UpstreamClient = request.app.state.upstream

    statuses = await _collect_model_statuses(cfg, registry, upstream)

    data: List[Dict[str, Any]] = []
    for alias, entry in cfg.models.items():
        if not statuses[alias]["available"]:
            continue
        if isinstance(entry, ProxyModel):
            meta: Dict[str, Any] = {
                "type": "proxy",
                "proxy_url": str(entry.proxy_url),
                "available": True,
            }
        else:
            meta = {
                "type": "discovered",
                "role": entry.role,
                "strategy": entry.strategy,
                "available": True,
            }
        data.append(
            {
                "id": entry.openai_model_name,
                "object": "model",
                "owned_by": "local",
                "rolemesh": meta,
            }
        )

    # Expose currently registered nodes (informational), flagged with staleness.
    nodes = [
        node.model_dump(mode="json") | {"stale": registry.is_stale(node)}
        for node in registry.list_nodes(include_stale=True)
    ]
    unavailable = [s for s in statuses.values() if not s["available"]]
    return {
        "object": "list",
        "data": data,
        "rolemesh": {
            "registered_nodes": nodes,
            "unavailable_models": unavailable,
        },
    }
@router.post("/v1/chat/completions")
async def chat_completions(request: Request, _=Depends(require_client_auth)) -> Any:
    """Proxy an OpenAI chat-completions request to the resolved upstream.

    Model resolution: the body's "model" (else cfg.default_model) is matched
    first as a configured alias, then as an advertised ``openai_model_name``
    — GET /v1/models exposes the latter as the model "id", so clients that
    echo it back must not be rejected.

    Returns the upstream JSON (non-stream) or an SSE passthrough (stream);
    OpenAI-style error bodies on unknown model (400), no upstream (503),
    bad config (500), or upstream failure (502 / upstream status).
    """
    cfg: Config = request.app.state.cfg
    upstream: UpstreamClient = request.app.state.upstream
    registry: Registry = request.app.state.registry

    req_id = str(uuid.uuid4())
    started = time.time()
    body = await request.json()

    model = body.get("model") or cfg.default_model
    stream = bool(body.get("stream", False))

    entry = cfg.models.get(model)
    if entry is None:
        # FIX: /v1/models advertises entry.openai_model_name as the id, but
        # only alias keys were accepted here. Accept the advertised id too.
        for alias, candidate in cfg.models.items():
            if candidate.openai_model_name == model:
                entry, model = candidate, alias
                break
    if entry is None:
        return _openai_error(f"Unknown model '{model}'. Check GET /v1/models.", code="unknown_model", status_code=400)

    # Apply per-model defaults only for keys the caller did not set.
    # getattr guard: proxy/discovered entries may lack (or null out) defaults.
    for key, value in (getattr(entry, "defaults", None) or {}).items():
        body.setdefault(key, value)

    # Resolve upstream base URL
    if isinstance(entry, ProxyModel):
        base_url = str(entry.proxy_url).rstrip("/")
    elif isinstance(entry, DiscoveredModel):
        node = registry.pick_node_for_role(entry.role, strategy=entry.strategy)
        if not node:
            return _openai_error(
                f"No registered nodes available for role '{entry.role}'. "
                f"Register a node via POST /v1/nodes/register, or use proxy mode.",
                code="no_upstream",
                status_code=503,
            )
        base_url = str(node.base_url).rstrip("/")
    else:
        return _openai_error("Invalid model configuration.", code="bad_config", status_code=500)

    # Proxy request
    try:
        if not stream:
            out = await upstream.chat_completions(base_url, body)
            request.app.logger.info(
                "chat_completions ok",
                extra={"req_id": req_id, "model": model, "upstream": base_url, "ms": int(1000 * (time.time() - started))},
            )
            return out

        # streaming: passthrough bytes (SSE).
        # NOTE(review): an UpstreamError raised *during* iteration happens
        # after headers are sent, so the except below cannot convert it to an
        # error body — the stream simply terminates early.
        async def gen():
            async for chunk in upstream.chat_completions_stream(base_url, body):
                yield chunk

        return StreamingResponse(gen(), media_type="text/event-stream")
    except UpstreamError as e:
        request.app.logger.warning(
            "chat_completions upstream error",
            extra={"req_id": req_id, "model": model, "upstream": base_url, "err": str(e)},
        )
        return _openai_error(str(e), code="upstream_error", status_code=e.status_code or 502)
@router.get("/health")
async def health() -> Dict[str, str]:
    """Liveness probe: always reports ok while the process is serving."""
    return dict(status="ok")
@router.get("/ready")
async def ready(request: Request) -> JSONResponse:
    """Readiness probe: 200 when the default model is reachable, else 503.

    The payload always lists available/unavailable models so a failing probe
    is self-explanatory. (Refactor: the two near-identical response branches
    are collapsed into one parameterized response; behavior is unchanged.)
    """
    cfg: Config = request.app.state.cfg
    registry: Registry = request.app.state.registry
    upstream: UpstreamClient = request.app.state.upstream

    statuses = await _collect_model_statuses(cfg, registry, upstream)
    default_status = statuses.get(cfg.default_model)
    is_ready = bool(default_status and default_status["available"])

    return JSONResponse(
        status_code=200 if is_ready else 503,
        content={
            "status": "ready" if is_ready else "not_ready",
            "default_model": cfg.default_model,
            "available_models": [alias for alias, s in statuses.items() if s["available"]],
            "unavailable_models": [s for s in statuses.values() if not s["available"]],
        },
    )
@router.post("/v1/nodes/register")
async def register_node(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
    """
    Allow a remote machine to register which roles it serves.

    Auth: gated by `require_node_auth` (see the Depends in the signature) —
    the old scaffold note about this endpoint being unauthenticated no longer
    applies. Consider mTLS as well for untrusted networks.

    Body: JSON matching the NodeRegistration schema.
    Returns: {"status": "ok", "node": <registered node as JSON>}.
    """
    registry: Registry = request.app.state.registry
    payload = await request.json()
    # Validate the raw JSON against the registration schema before touching the registry.
    reg = NodeRegistration.model_validate(payload)
    node = registry.register(reg)
    return {"status": "ok", "node": node.model_dump(mode="json")}
@router.post("/v1/nodes/heartbeat")
async def nodes_heartbeat(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
    """
    Allow a node agent to push status/metrics updates.

    Auth: gated by `require_node_auth` (see the Depends in the signature) —
    the old scaffold note about this endpoint being unauthenticated no longer
    applies. Consider mTLS as well for real deployments.

    Body: JSON matching the NodeHeartbeat schema.
    Returns: {"status": "ok", "node": ...} on success, or a 404 JSON body
    when the node_id was never registered (agent must register first).
    """
    registry: Registry = request.app.state.registry
    payload = await request.json()
    hb = NodeHeartbeat.model_validate(payload)
    node = registry.heartbeat(hb)
    if not node:
        # Unknown node_id: heartbeat without prior registration.
        return JSONResponse(status_code=404, content={"error": "unknown_node", "node_id": hb.node_id})
    return {"status": "ok", "node": node.model_dump(mode="json")}