128 lines
5.2 KiB
Python
128 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from pathlib import Path
|
|
|
|
from fastapi import Depends, FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from .auth import require_client_auth, require_node_auth
|
|
from .chat import ProxyError, proxy_chat_completion, proxy_embeddings
|
|
from .config import ControlConfig, load_config
|
|
from .models import HostHeartbeat, HostRegistration
|
|
from .roles import load_role_catalog
|
|
from .registry import Registry
|
|
from .upstream import UpstreamClient, UpstreamError
|
|
|
|
|
|
def _openai_error(status_code: int, message: str, code: str) -> JSONResponse:
    """Return the error envelope used by the client-facing proxy routes.

    The body shape is ``{"error": {"message", "type": "geniehive_error", "code"}}``,
    identical to what the chat-completions and embeddings routes return for
    proxy and upstream failures.
    """
    return JSONResponse(
        status_code=status_code,
        content={"error": {"message": message, "type": "geniehive_error", "code": code}},
    )


async def _proxy_json_request(request: Request, proxy, proxy_error_code: str):
    """Parse a JSON-object request body and forward it through ``proxy``.

    Args:
        request: the incoming client request.
        proxy: coroutine function invoked as
            ``proxy(body, registry=..., upstream=...)`` — e.g.
            ``proxy_chat_completion`` or ``proxy_embeddings``.
        proxy_error_code: the ``code`` reported when ``proxy`` raises
            ``ProxyError``.

    Returns:
        Whatever ``proxy`` returns on success; otherwise a ``JSONResponse``:
        400 when the body is not valid JSON or not a JSON object,
        ``exc.status_code`` for ``ProxyError``, and the upstream status
        (falling back to 502) for ``UpstreamError``.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError;
        # unhandled, a malformed or empty body would surface as a 500.
        return _openai_error(400, "request body is not valid JSON", "invalid_json")
    if not isinstance(body, dict):
        # The proxy helpers receive the body as a JSON object; reject arrays/scalars up front.
        return _openai_error(400, "request body must be a JSON object", "invalid_request")

    state = request.app.state
    try:
        return await proxy(body, registry=state.registry, upstream=state.upstream)
    except ProxyError as exc:
        return _openai_error(exc.status_code, str(exc), proxy_error_code)
    except UpstreamError as exc:
        return _openai_error(exc.status_code or 502, str(exc), "upstream_error")


async def _parse_node_payload(request: Request, model: type):
    """Parse the request body as JSON and validate it with ``model.model_validate``.

    Returns the validated instance, or a ``JSONResponse`` that the caller
    should return unchanged: 400 ``invalid_json`` for a malformed body, 422
    ``invalid_payload`` when validation fails. The flat ``{"error": ...}``
    shape matches the other node-route errors (e.g. ``unknown_host``).
    """
    try:
        payload = await request.json()
    except ValueError:
        return JSONResponse(status_code=400, content={"error": "invalid_json"})
    try:
        return model.model_validate(payload)
    except ValueError as exc:
        # NOTE(review): relies on the model's validation error subclassing
        # ValueError (true for pydantic's ValidationError) — confirm if the
        # classes in .models ever change base.
        return JSONResponse(
            status_code=422,
            content={"error": "invalid_payload", "detail": str(exc)},
        )


def create_app(
    config_path: str | Path | None = None,
    *,
    upstream_client: UpstreamClient | None = None,
) -> FastAPI:
    """Build the GenieHive control-plane FastAPI application.

    Args:
        config_path: path to the control config. When omitted, the
            ``GENIEHIVE_CONTROL_CONFIG`` environment variable is used; when
            neither is set, a default ``ControlConfig()`` is used.
        upstream_client: client used by the proxy routes; a fresh
            ``UpstreamClient()`` is created when omitted.

    Returns:
        The configured application. ``app.state`` carries ``cfg``,
        ``registry`` and ``upstream``, which the route handlers read.

    If ``cfg.roles_path`` (or, failing that, ``GENIEHIVE_ROLES_CONFIG``) is
    set, the role catalog it names is loaded and upserted into the registry.
    """
    cfg_path = config_path or os.environ.get("GENIEHIVE_CONTROL_CONFIG")
    cfg = load_config(cfg_path) if cfg_path else ControlConfig()

    registry = Registry(cfg.storage.sqlite_path)
    roles_path = cfg.roles_path or os.environ.get("GENIEHIVE_ROLES_CONFIG")
    if roles_path:
        registry.upsert_roles(load_role_catalog(roles_path).roles)

    # NOTE(review): neither the registry nor the upstream client is closed when
    # the app shuts down; if they hold connections, add a lifespan/shutdown hook.
    upstream = upstream_client or UpstreamClient()

    app = FastAPI(title="GenieHive Control", version="0.1.0")
    app.state.cfg = cfg
    app.state.registry = registry
    app.state.upstream = upstream

    @app.get("/health")
    async def health() -> dict[str, str]:
        """Unauthenticated health check."""
        return {"status": "ok"}

    # Node-facing routes (node auth). Returning a Response from a handler
    # annotated ``-> dict`` bypasses FastAPI's inferred response model, so the
    # error JSONResponses below pass through unchanged.

    @app.post("/v1/nodes/register")
    async def register_node(request: Request, _=Depends(require_node_auth)) -> dict:
        """Register a host described by a ``HostRegistration`` body."""
        reg = await _parse_node_payload(request, HostRegistration)
        if isinstance(reg, JSONResponse):
            return reg
        host = request.app.state.registry.register_host(reg)
        return {"status": "ok", "host": host}

    @app.post("/v1/nodes/heartbeat")
    async def heartbeat_node(request: Request, _=Depends(require_node_auth)):
        """Record a ``HostHeartbeat``; 404 when the registry does not know the host."""
        hb = await _parse_node_payload(request, HostHeartbeat)
        if isinstance(hb, JSONResponse):
            return hb
        host = request.app.state.registry.heartbeat_host(hb)
        if host is None:
            # heartbeat_host returned None: the host id is not registered
            # (presumably the node must register first — confirm).
            return JSONResponse(status_code=404, content={"error": "unknown_host", "host_id": hb.host_id})
        return {"status": "ok", "host": host}

    # Client-facing routes (client auth).

    @app.get("/v1/cluster/hosts")
    async def list_hosts(request: Request, _=Depends(require_client_auth)) -> dict:
        return {"object": "list", "data": request.app.state.registry.list_hosts()}

    @app.get("/v1/models")
    async def list_models(request: Request, _=Depends(require_client_auth)) -> dict:
        return {"object": "list", "data": request.app.state.registry.list_client_models()}

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request, _=Depends(require_client_auth)):
        """Forward a chat-completions request through ``proxy_chat_completion``."""
        return await _proxy_json_request(request, proxy_chat_completion, "chat_proxy_error")

    @app.post("/v1/embeddings")
    async def embeddings(request: Request, _=Depends(require_client_auth)):
        """Forward an embeddings request through ``proxy_embeddings``."""
        return await _proxy_json_request(request, proxy_embeddings, "embeddings_proxy_error")

    @app.get("/v1/cluster/services")
    async def list_services(request: Request, _=Depends(require_client_auth)) -> dict:
        return {"object": "list", "data": request.app.state.registry.list_services()}

    @app.get("/v1/cluster/roles")
    async def list_roles(request: Request, _=Depends(require_client_auth)) -> dict:
        return {"object": "list", "data": request.app.state.registry.list_roles()}

    @app.get("/v1/cluster/health")
    async def cluster_health(request: Request, _=Depends(require_client_auth)) -> dict:
        """Cluster health summary, using the configured staleness threshold."""
        cfg: ControlConfig = request.app.state.cfg
        return request.app.state.registry.cluster_health(cfg.routing.health_stale_after_s)

    @app.get("/v1/cluster/routes/resolve")
    async def resolve_route(model: str, request: Request, kind: str | None = None, _=Depends(require_client_auth)) -> dict:
        """Resolve which route serves ``model`` (optionally filtered by ``kind``); 404 if none."""
        resolved = request.app.state.registry.resolve_route(model, kind=kind)
        if resolved is None:
            return JSONResponse(status_code=404, content={"error": "no_route", "model": model, "kind": kind})
        return {"status": "ok", "resolution": resolved}

    return app
|
|
|
|
|
|
# Module-level application instance, built at import time with no explicit
# config path: configuration comes from GENIEHIVE_CONTROL_CONFIG and roles
# from the config's roles_path or GENIEHIVE_ROLES_CONFIG (defaults otherwise).
app: FastAPI = create_app()
|