Applying Codex changes, added test code.
This commit is contained in:
parent
1908b42499
commit
87fcdaaacc
|
|
@ -13,3 +13,4 @@ build/
|
|||
.vscode/
|
||||
state/registry.json
|
||||
configs/models.yaml
|
||||
tmp-codex/
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@ def _key_id(token: str) -> str:
|
|||
return token[-6:] if len(token) >= 6 else token
|
||||
|
||||
|
||||
def require_client_auth(
|
||||
async def require_client_auth(
|
||||
request: Request,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
x_api_key: Optional[str] = Header(default=None, convert_underscores=False),
|
||||
x_api_key: Optional[str] = Header(default=None, alias="X-Api-Key"),
|
||||
) -> Optional[Principal]:
|
||||
"""Gate OpenAI-style endpoints with an API key.
|
||||
|
||||
|
|
@ -56,10 +56,10 @@ def require_client_auth(
|
|||
return Principal(kind="client", key_id=_key_id(token))
|
||||
|
||||
|
||||
def require_node_auth(
|
||||
async def require_node_auth(
|
||||
request: Request,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
x_rolemesh_node_key: Optional[str] = Header(default=None, convert_underscores=False),
|
||||
x_rolemesh_node_key: Optional[str] = Header(default=None, alias="X-RoleMesh-Node-Key"),
|
||||
) -> Optional[Principal]:
|
||||
"""Gate node registration/heartbeat endpoints with a node key.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from rolemesh_gateway.config import load_config
|
||||
|
|
@ -24,20 +24,30 @@ def _get_logger() -> logging.Logger:
|
|||
return logger
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(title="RoleMesh Gateway", version="0.1.0")
|
||||
|
||||
cfg_path = os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml")
|
||||
def create_app(
|
||||
config_path: str | Path | None = None,
|
||||
registry_path: str | Path | None = None,
|
||||
) -> FastAPI:
|
||||
cfg_path = config_path or os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml")
|
||||
cfg = load_config(cfg_path)
|
||||
|
||||
registry_path = os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json")
|
||||
registry = Registry(persist_path=Path(registry_path))
|
||||
resolved_registry_path = registry_path or os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json")
|
||||
registry = Registry(persist_path=Path(resolved_registry_path))
|
||||
|
||||
upstream = UpstreamClient(
|
||||
connect_timeout_s=float(os.environ.get("ROLE_MESH_CONNECT_TIMEOUT_S", "10")),
|
||||
read_timeout_s=float(os.environ.get("ROLE_MESH_READ_TIMEOUT_S", "600")),
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
await upstream.close()
|
||||
|
||||
app = FastAPI(title="RoleMesh Gateway", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
app.state.cfg = cfg
|
||||
app.state.registry = registry
|
||||
app.state.upstream = upstream
|
||||
|
|
@ -45,10 +55,6 @@ def create_app() -> FastAPI:
|
|||
|
||||
app.include_router(openai_router)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def _shutdown():
|
||||
await upstream.close()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ class NodeHeartbeat(BaseModel):
|
|||
timestamp: float = Field(default_factory=lambda: time.time())
|
||||
status: Dict[str, Any] = Field(default_factory=dict)
|
||||
metrics: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RegisteredNode(BaseModel):
|
||||
node_id: str
|
||||
base_url: HttpUrl
|
||||
|
|
@ -85,11 +87,18 @@ class Registry:
|
|||
self._save()
|
||||
return node
|
||||
|
||||
def heartbeat(self, node_id: str) -> None:
|
||||
n = self._nodes.get(node_id)
|
||||
if n:
|
||||
def heartbeat(self, hb: NodeHeartbeat) -> Optional[RegisteredNode]:
|
||||
n = self._nodes.get(hb.node_id)
|
||||
if not n:
|
||||
return None
|
||||
n.status = {
|
||||
**hb.status,
|
||||
"timestamp": hb.timestamp,
|
||||
"metrics": hb.metrics,
|
||||
}
|
||||
n.last_seen = time.time()
|
||||
self._save()
|
||||
return n
|
||||
|
||||
def list_nodes(self) -> List[RegisteredNode]:
|
||||
return list(self._nodes.values())
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, AsyncIterator, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
|
@ -18,10 +17,17 @@ def _timeout(connect_s: float, read_s: float) -> httpx.Timeout:
|
|||
|
||||
|
||||
class UpstreamClient:
|
||||
def __init__(self, connect_timeout_s: float = 10.0, read_timeout_s: float = 600.0) -> None:
|
||||
self._client = httpx.AsyncClient(timeout=_timeout(connect_timeout_s, read_timeout_s))
|
||||
def __init__(
|
||||
self,
|
||||
connect_timeout_s: float = 10.0,
|
||||
read_timeout_s: float = 600.0,
|
||||
client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self._owns_client = client is None
|
||||
self._client = client or httpx.AsyncClient(timeout=_timeout(connect_timeout_s, read_timeout_s))
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._owns_client:
|
||||
await self._client.aclose()
|
||||
|
||||
async def get_models(self, base_url: str) -> Dict[str, Any]:
|
||||
|
|
@ -58,3 +64,7 @@ class UpstreamClient:
|
|||
yield chunk
|
||||
except httpx.RequestError as e:
|
||||
raise UpstreamError(f"Upstream unreachable: {e!s}") from e
|
||||
|
||||
async def stream_chat_completions(self, base_url: str, payload: Dict[str, Any]) -> AsyncIterator[bytes]:
|
||||
async for chunk in self.chat_completions_stream(base_url, payload):
|
||||
yield chunk
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
|
|
@ -24,30 +24,40 @@ def _error(message: str, code: str = "node_error", status_code: int = 500) -> JS
|
|||
|
||||
|
||||
def create_app(cfg: NodeAgentConfig) -> FastAPI:
|
||||
app = FastAPI(title="RoleMesh Node Agent", version="0.1.0")
|
||||
http = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(connect=5.0, read=3600.0, write=30.0, pool=30.0)
|
||||
)
|
||||
upstream = UpstreamClient(client=http)
|
||||
cuda = CudaAdapter(llama_server_bin=cfg.llama_server_bin)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
heartbeat_task: asyncio.Task[None] | None = None
|
||||
if cfg.dispatcher_base_url and cfg.dispatcher_roles:
|
||||
heartbeat_task = asyncio.create_task(_heartbeat_loop(app))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if heartbeat_task is not None:
|
||||
heartbeat_task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await heartbeat_task
|
||||
await http.aclose()
|
||||
await cuda.shutdown()
|
||||
|
||||
app = FastAPI(title="RoleMesh Node Agent", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
app.state.cfg = cfg
|
||||
app.state.http = httpx.AsyncClient(timeout=httpx.Timeout(connect=5.0, read=3600.0))
|
||||
app.state.upstream = UpstreamClient(client=app.state.http)
|
||||
app.state.http = http
|
||||
app.state.upstream = upstream
|
||||
|
||||
# Adapters
|
||||
app.state.cuda = CudaAdapter(llama_server_bin=cfg.llama_server_bin)
|
||||
app.state.cuda = cuda
|
||||
|
||||
# State: role -> (device, model)
|
||||
# This is intentionally simple for the scaffold: pick first GPU and first matching model.
|
||||
app.state.role_bindings: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _startup() -> None:
|
||||
# optional: dispatcher registration loop
|
||||
if cfg.dispatcher_base_url and cfg.dispatcher_roles:
|
||||
asyncio.create_task(_heartbeat_loop(app))
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def _shutdown() -> None:
|
||||
await app.state.http.aclose()
|
||||
await app.state.cuda.shutdown()
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Dict[str, Any]:
|
||||
return {"status": "ok", "node_id": cfg.node_id, "time": time.time()}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
SRC = ROOT / "src"
|
||||
|
||||
if str(SRC) not in sys.path:
|
||||
sys.path.insert(0, str(SRC))
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
|
||||
def _write_gateway_config(path: Path) -> Path:
|
||||
path.write_text(
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"version": 1,
|
||||
"default_model": "writer",
|
||||
"auth": {
|
||||
"client_api_keys": ["client-secret"],
|
||||
"node_api_keys": ["node-secret"],
|
||||
},
|
||||
"models": {
|
||||
"writer": {
|
||||
"type": "proxy",
|
||||
"openai_model_name": "writer",
|
||||
"proxy_url": "http://127.0.0.1:8012",
|
||||
"defaults": {"temperature": 0.6},
|
||||
},
|
||||
"reviewer": {
|
||||
"type": "discovered",
|
||||
"openai_model_name": "reviewer",
|
||||
"role": "reviewer",
|
||||
"strategy": "round_robin",
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def _create_gateway_app(tmp_path: Path):
|
||||
from rolemesh_gateway.main import create_app
|
||||
|
||||
return create_app(
|
||||
_write_gateway_config(tmp_path / "models.yaml"),
|
||||
registry_path=tmp_path / "registry.json",
|
||||
)
|
||||
|
||||
|
||||
async def _request(app, method: str, path: str, **kwargs) -> httpx.Response:
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||
return await client.request(method, path, **kwargs)
|
||||
|
||||
|
||||
def test_models_requires_client_auth(tmp_path):
|
||||
app = _create_gateway_app(tmp_path)
|
||||
unauthorized = asyncio.run(_request(app, "GET", "/v1/models"))
|
||||
assert unauthorized.status_code == 401
|
||||
|
||||
authorized = asyncio.run(
|
||||
_request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
|
||||
)
|
||||
assert authorized.status_code == 200
|
||||
body = authorized.json()
|
||||
assert {item["id"] for item in body["data"]} == {"writer", "reviewer"}
|
||||
asyncio.run(app.state.upstream.close())
|
||||
|
||||
|
||||
def test_chat_completions_applies_defaults_for_proxy_model(tmp_path):
|
||||
app = _create_gateway_app(tmp_path)
|
||||
calls = {}
|
||||
|
||||
async def fake_chat(base_url, payload):
|
||||
calls["base_url"] = base_url
|
||||
calls["payload"] = payload
|
||||
return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
|
||||
|
||||
app.state.upstream.chat_completions = fake_chat
|
||||
|
||||
response = asyncio.run(
|
||||
_request(
|
||||
app,
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
headers={"x-api-key": "client-secret"},
|
||||
json={
|
||||
"model": "writer",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert calls["base_url"] == "http://127.0.0.1:8012"
|
||||
assert calls["payload"]["temperature"] == 0.6
|
||||
assert response.json()["choices"][0]["message"]["content"] == "ok"
|
||||
asyncio.run(app.state.upstream.close())
|
||||
|
||||
|
||||
def test_discovered_model_without_registered_node_returns_503(tmp_path):
|
||||
app = _create_gateway_app(tmp_path)
|
||||
response = asyncio.run(
|
||||
_request(
|
||||
app,
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
headers={"x-api-key": "client-secret"},
|
||||
json={
|
||||
"model": "reviewer",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 503
|
||||
assert response.json()["error"]["code"] == "no_upstream"
|
||||
asyncio.run(app.state.upstream.close())
|
||||
|
||||
|
||||
def test_node_registration_and_heartbeat_update_registry_state(tmp_path):
|
||||
app = _create_gateway_app(tmp_path)
|
||||
register = asyncio.run(
|
||||
_request(
|
||||
app,
|
||||
"POST",
|
||||
"/v1/nodes/register",
|
||||
headers={"x-rolemesh-node-key": "node-secret"},
|
||||
json={
|
||||
"node_id": "node-a",
|
||||
"base_url": "http://127.0.0.1:9001",
|
||||
"roles": ["reviewer"],
|
||||
"meta": {"gpu": "test"},
|
||||
},
|
||||
)
|
||||
)
|
||||
assert register.status_code == 200
|
||||
|
||||
heartbeat = asyncio.run(
|
||||
_request(
|
||||
app,
|
||||
"POST",
|
||||
"/v1/nodes/heartbeat",
|
||||
headers={"x-rolemesh-node-key": "node-secret"},
|
||||
json={
|
||||
"node_id": "node-a",
|
||||
"timestamp": 123.0,
|
||||
"status": {"healthy": True},
|
||||
"metrics": [{"device": {"id": "gpu:0"}}],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert heartbeat.status_code == 200
|
||||
node = heartbeat.json()["node"]
|
||||
assert node["status"]["healthy"] is True
|
||||
assert node["status"]["timestamp"] == 123.0
|
||||
assert node["status"]["metrics"] == [{"device": {"id": "gpu:0"}}]
|
||||
asyncio.run(app.state.upstream.close())
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from rolemesh_node_agent.adapters.base import DeviceMetrics, DeviceRef
|
||||
from rolemesh_node_agent.config import ModelEntry, NodeAgentConfig
|
||||
|
||||
|
||||
def _node_config(tmp_path: Path) -> NodeAgentConfig:
|
||||
model_path = tmp_path / "model.gguf"
|
||||
model_path.write_bytes(b"GGUF")
|
||||
return NodeAgentConfig(
|
||||
node_id="node-1",
|
||||
model_roots=[tmp_path],
|
||||
models=[ModelEntry(model_id="planner-gguf", path=model_path, roles=["planner"])],
|
||||
)
|
||||
|
||||
|
||||
async def _request(app, method: str, path: str, **kwargs) -> httpx.Response:
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||
return await client.request(method, path, **kwargs)
|
||||
|
||||
|
||||
def test_inventory_reports_models_metrics_and_discovered_gguf(tmp_path):
|
||||
from rolemesh_node_agent.main import create_app
|
||||
|
||||
cfg = _node_config(tmp_path)
|
||||
app = create_app(cfg)
|
||||
|
||||
async def fake_discover_devices():
|
||||
return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")]
|
||||
|
||||
async def fake_get_metrics():
|
||||
return [
|
||||
DeviceMetrics(
|
||||
device=DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
|
||||
loaded_model_id="planner-gguf",
|
||||
queue_depth=1,
|
||||
)
|
||||
]
|
||||
|
||||
app.state.cuda.discover_devices = fake_discover_devices
|
||||
app.state.cuda.get_metrics = fake_get_metrics
|
||||
|
||||
response = asyncio.run(_request(app, "GET", "/v1/node/inventory"))
|
||||
|
||||
body = response.json()
|
||||
assert response.status_code == 200
|
||||
assert body["models"][0]["model_id"] == "planner-gguf"
|
||||
assert body["metrics"][0]["loaded_model_id"] == "planner-gguf"
|
||||
assert body["discovered_gguf"][0]["name"] == "model.gguf"
|
||||
asyncio.run(app.state.http.aclose())
|
||||
|
||||
|
||||
def test_chat_completions_routes_to_local_server_and_streams(tmp_path):
|
||||
from rolemesh_node_agent.main import create_app
|
||||
|
||||
cfg = _node_config(tmp_path)
|
||||
app = create_app(cfg)
|
||||
calls = {}
|
||||
|
||||
async def fake_discover_devices():
|
||||
return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")]
|
||||
|
||||
async def fake_ensure_server(device, *, model_path, model_id, server_args):
|
||||
calls["device"] = device.id
|
||||
calls["model_path"] = model_path
|
||||
calls["model_id"] = model_id
|
||||
return "http://127.0.0.1:9100"
|
||||
|
||||
async def fake_chat(base_url, payload):
|
||||
calls["base_url"] = base_url
|
||||
calls["payload"] = payload
|
||||
return {"id": "node-cmpl", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
|
||||
|
||||
async def fake_stream(base_url, payload):
|
||||
calls["stream_base_url"] = base_url
|
||||
calls["stream_payload"] = payload
|
||||
yield b"data: first\n\n"
|
||||
yield b"data: [DONE]\n\n"
|
||||
|
||||
app.state.cuda.discover_devices = fake_discover_devices
|
||||
app.state.cuda.ensure_server = fake_ensure_server
|
||||
app.state.upstream.chat_completions = fake_chat
|
||||
app.state.upstream.stream_chat_completions = fake_stream
|
||||
|
||||
response = asyncio.run(
|
||||
_request(
|
||||
app,
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "planner-gguf",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
},
|
||||
)
|
||||
)
|
||||
stream_response = asyncio.run(
|
||||
_request(
|
||||
app,
|
||||
"POST",
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "planner-gguf",
|
||||
"stream": True,
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["choices"][0]["message"]["content"] == "ok"
|
||||
assert calls["device"] == "gpu:0"
|
||||
assert calls["base_url"] == "http://127.0.0.1:9100"
|
||||
assert "data: first" in stream_response.text
|
||||
assert calls["stream_base_url"] == "http://127.0.0.1:9100"
|
||||
asyncio.run(app.state.http.aclose())
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
|
||||
|
||||
|
||||
def test_registry_persists_round_robin_and_heartbeat_state(tmp_path):
|
||||
persist_path = tmp_path / "registry.json"
|
||||
registry = Registry(persist_path=persist_path)
|
||||
|
||||
registry.register(
|
||||
NodeRegistration(
|
||||
node_id="node-a",
|
||||
base_url="http://127.0.0.1:9001",
|
||||
roles=["reviewer"],
|
||||
)
|
||||
)
|
||||
registry.register(
|
||||
NodeRegistration(
|
||||
node_id="node-b",
|
||||
base_url="http://127.0.0.1:9002",
|
||||
roles=["reviewer"],
|
||||
)
|
||||
)
|
||||
|
||||
assert registry.pick_node_for_role("reviewer").node_id == "node-a"
|
||||
assert registry.pick_node_for_role("reviewer").node_id == "node-b"
|
||||
|
||||
node = registry.heartbeat(
|
||||
NodeHeartbeat(
|
||||
node_id="node-a",
|
||||
timestamp=456.0,
|
||||
status={"healthy": True},
|
||||
metrics=[{"queue_depth": 1}],
|
||||
)
|
||||
)
|
||||
assert node is not None
|
||||
assert node.status["metrics"] == [{"queue_depth": 1}]
|
||||
|
||||
reloaded = Registry(persist_path=persist_path)
|
||||
assert reloaded.pick_node_for_role("reviewer").node_id == "node-a"
|
||||
assert reloaded.list_nodes()[0].status["timestamp"] == 456.0
|
||||
|
||||
|
||||
def test_registry_heartbeat_returns_none_for_unknown_node():
|
||||
registry = Registry()
|
||||
assert registry.heartbeat(NodeHeartbeat(node_id="missing")) is None
|
||||
Loading…
Reference in New Issue