265 lines
9.0 KiB
Python
265 lines
9.0 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import time
|
|
import uuid
|
|
from typing import Any, Dict, List
|
|
|
|
from fastapi import APIRouter, Request, Depends
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel
|
|
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
|
|
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError
|
|
from rolemesh_gateway.auth import require_client_auth, require_node_auth
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
def _openai_error(message: str, code: str = "upstream_error", status_code: int = 502) -> JSONResponse:
    """Build an OpenAI-style error envelope as a JSONResponse.

    The payload mirrors what OpenAI clients expect: a top-level ``error``
    object carrying ``message``, ``type``, ``param`` and ``code``.
    """
    error_payload = {
        "message": message,
        "type": "gateway_error",
        "param": None,
        "code": code,
    }
    return JSONResponse(status_code=status_code, content={"error": error_payload})
|
|
|
|
|
|
async def _model_status(
    alias: str,
    entry: ProxyModel | DiscoveredModel,
    registry: Registry,
    upstream: UpstreamClient,
) -> Dict[str, Any]:
    """Probe one configured model and report whether it is reachable.

    Proxy models are checked by querying their fixed upstream's /models
    endpoint. Discovered models are checked against the fresh nodes
    registered for their role; the first node that answers wins.
    """
    # Proxy mode: a single fixed upstream, healthy iff /models answers.
    if isinstance(entry, ProxyModel):
        proxy_base = str(entry.proxy_url).rstrip("/")
        try:
            await upstream.get_models(proxy_base)
        except UpstreamError as exc:
            return {"alias": alias, "available": False, "base_url": proxy_base, "error": str(exc)}
        return {"alias": alias, "available": True, "base_url": proxy_base}

    # Discovered mode: consult the registry for fresh nodes serving the role.
    fresh_nodes = registry.nodes_for_role(entry.role, include_stale=False)
    if not fresh_nodes:
        # Distinguish "nothing ever registered" from "registered but stale".
        if registry.nodes_for_role(entry.role, include_stale=True):
            reason = "no_fresh_registered_nodes"
        else:
            reason = "no_registered_nodes"
        return {"alias": alias, "available": False, "role": entry.role, "error": reason}

    # Probe fresh nodes in registry order; return on the first healthy one.
    for candidate in fresh_nodes:
        node_base = str(candidate.base_url).rstrip("/")
        try:
            await upstream.get_models(node_base)
        except UpstreamError:
            continue
        return {
            "alias": alias,
            "available": True,
            "role": entry.role,
            "base_url": node_base,
            "node_id": candidate.node_id,
        }

    # Nodes exist and are fresh, but none answered the probe.
    return {
        "alias": alias,
        "available": False,
        "role": entry.role,
        "error": "no_healthy_registered_nodes",
    }
|
|
|
|
|
|
async def _collect_model_statuses(
    cfg: Config,
    registry: Registry,
    upstream: UpstreamClient,
) -> Dict[str, Dict[str, Any]]:
    """Probe every configured model concurrently.

    Returns a mapping of model alias -> status dict as produced by
    ``_model_status``.
    """
    probes = [
        _model_status(alias, entry, registry, upstream)
        for alias, entry in cfg.models.items()
    ]
    results = await asyncio.gather(*probes)
    return {result["alias"]: result for result in results}
|
|
|
|
|
|
@router.get("/v1/models")
async def list_models(request: Request, _=Depends(require_client_auth)) -> Dict[str, Any]:
    """OpenAI-compatible model listing, extended with rolemesh metadata.

    Only models whose upstream is currently reachable appear in ``data``.
    Unavailable models and the full registered-node inventory (stale nodes
    included) are reported under the ``rolemesh`` key for operators.
    """
    cfg: Config = request.app.state.cfg
    registry: Registry = request.app.state.registry
    upstream: UpstreamClient = request.app.state.upstream
    statuses = await _collect_model_statuses(cfg, registry, upstream)

    data: List[Dict[str, Any]] = []
    for alias, entry in cfg.models.items():
        if not statuses[alias]["available"]:
            continue
        if isinstance(entry, ProxyModel):
            mesh_info = {
                "type": "proxy",
                "proxy_url": str(entry.proxy_url),
                "available": True,
            }
        else:
            mesh_info = {
                "type": "discovered",
                "role": entry.role,
                "strategy": entry.strategy,
                "available": True,
            }
        data.append(
            {
                "id": entry.openai_model_name,
                "object": "model",
                "owned_by": "local",
                "rolemesh": mesh_info,
            }
        )

    # Expose currently registered nodes (informational), stale ones flagged.
    nodes = []
    for node in registry.list_nodes(include_stale=True):
        nodes.append(node.model_dump(mode="json") | {"stale": registry.is_stale(node)})

    unavailable = [s for s in statuses.values() if not s["available"]]

    return {
        "object": "list",
        "data": data,
        "rolemesh": {
            "registered_nodes": nodes,
            "unavailable_models": unavailable,
        },
    }
|
|
|
|
|
|
@router.post("/v1/chat/completions")
async def chat_completions(request: Request, _=Depends(require_client_auth)) -> Any:
    """Proxy an OpenAI chat-completions request to the resolved upstream.

    Resolution:
      * ProxyModel      -> its configured ``proxy_url``.
      * DiscoveredModel -> a registry node picked by the model's
                           load-balancing ``strategy``.

    Returns the upstream JSON body, an SSE ``StreamingResponse`` when the
    client requested ``stream``, or an OpenAI-style error envelope.
    """
    cfg: Config = request.app.state.cfg
    upstream: UpstreamClient = request.app.state.upstream
    registry: Registry = request.app.state.registry

    req_id = str(uuid.uuid4())
    started = time.time()

    # Reject malformed bodies with a 400 instead of an unhandled 500.
    # (json.JSONDecodeError is a ValueError subclass.)
    try:
        body = await request.json()
    except ValueError:
        return _openai_error("Request body is not valid JSON.", code="invalid_request", status_code=400)
    if not isinstance(body, dict):
        return _openai_error("Request body must be a JSON object.", code="invalid_request", status_code=400)

    model = body.get("model") or cfg.default_model
    stream = bool(body.get("stream", False))

    if model not in cfg.models:
        return _openai_error(f"Unknown model '{model}'. Check GET /v1/models.", code="unknown_model", status_code=400)

    entry = cfg.models[model]

    # Apply per-model defaults for keys the request didn't specify.
    # ``or {}`` also guards a ``defaults`` attribute explicitly set to None.
    defaults = getattr(entry, "defaults", None) or {}
    for key, value in defaults.items():
        body.setdefault(key, value)

    # Resolve upstream base URL.
    if isinstance(entry, ProxyModel):
        base_url = str(entry.proxy_url).rstrip("/")
    elif isinstance(entry, DiscoveredModel):
        node = registry.pick_node_for_role(entry.role, strategy=entry.strategy)
        if not node:
            return _openai_error(
                f"No registered nodes available for role '{entry.role}'. "
                f"Register a node via POST /v1/nodes/register, or use proxy mode.",
                code="no_upstream",
                status_code=503,
            )
        base_url = str(node.base_url).rstrip("/")
    else:
        return _openai_error("Invalid model configuration.", code="bad_config", status_code=500)

    # Proxy the request.
    try:
        if not stream:
            out = await upstream.chat_completions(base_url, body)
            # NOTE(review): ``request.app.logger`` is not a stock
            # FastAPI/Starlette attribute — presumably attached at app
            # construction; verify against the app factory.
            request.app.logger.info(
                "chat_completions ok",
                extra={"req_id": req_id, "model": model, "upstream": base_url, "ms": int(1000 * (time.time() - started))},
            )
            return out

        # Streaming: pass upstream SSE bytes through untouched.
        async def gen():
            async for chunk in upstream.chat_completions_stream(base_url, body):
                yield chunk

        return StreamingResponse(gen(), media_type="text/event-stream")

    except UpstreamError as e:
        request.app.logger.warning(
            "chat_completions upstream error",
            extra={"req_id": req_id, "model": model, "upstream": base_url, "err": str(e)},
        )
        return _openai_error(str(e), code="upstream_error", status_code=e.status_code or 502)
|
|
|
|
|
|
@router.get("/health")
async def health() -> Dict[str, str]:
    """Liveness probe: always reports ok while the process is running."""
    return {"status": "ok"}
|
|
|
|
|
|
@router.get("/ready")
async def ready(request: Request) -> JSONResponse:
    """Readiness probe: ready only when the default model is reachable.

    The payload always lists available/unavailable models so operators
    can see exactly why readiness failed.
    """
    cfg: Config = request.app.state.cfg
    registry: Registry = request.app.state.registry
    upstream: UpstreamClient = request.app.state.upstream
    statuses = await _collect_model_statuses(cfg, registry, upstream)

    available_aliases = [alias for alias, status in statuses.items() if status["available"]]
    unavailable = [s for s in statuses.values() if not s["available"]]
    default_status = statuses.get(cfg.default_model)
    is_ready = bool(default_status and default_status["available"])

    return JSONResponse(
        status_code=200 if is_ready else 503,
        content={
            "status": "ready" if is_ready else "not_ready",
            "default_model": cfg.default_model,
            "available_models": available_aliases,
            "unavailable_models": unavailable,
        },
    )
|
|
|
|
|
|
@router.post("/v1/nodes/register")
async def register_node(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
    """
    Allow a remote machine to register which roles it serves.

    The JSON payload is validated as a ``NodeRegistration`` and stored in
    the registry; the stored node record is echoed back to the caller.

    NOTE: access is gated by ``require_node_auth`` (see the ``Depends``
    above) — the earlier "unauthenticated scaffold" caveat no longer
    applies.
    """
    registry: Registry = request.app.state.registry
    payload = await request.json()
    # Pydantic validation; an invalid payload raises and surfaces as a 500.
    reg = NodeRegistration.model_validate(payload)
    node = registry.register(reg)
    return {"status": "ok", "node": node.model_dump(mode="json")}
|
|
|
|
|
|
@router.post("/v1/nodes/heartbeat")
async def nodes_heartbeat(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
    """
    Allow a node agent to push status/metrics updates.

    The node is looked up by ``node_id``; an unknown node gets a 404 JSON
    error so the agent knows it must re-register.

    NOTE: access is gated by ``require_node_auth`` (see the ``Depends``
    above) — the earlier "unauthenticated scaffold" caveat no longer
    applies.
    """
    registry: Registry = request.app.state.registry
    payload = await request.json()
    # Pydantic validation; an invalid payload raises and surfaces as a 500.
    hb = NodeHeartbeat.model_validate(payload)
    node = registry.heartbeat(hb)
    if not node:
        return JSONResponse(status_code=404, content={"error": "unknown_node", "node_id": hb.node_id})
    return {"status": "ok", "node": node.model_dump(mode="json")}
|