Apply Codex changes and add test code.

This commit is contained in:
welsberr 2026-03-16 09:42:57 -04:00
parent 1908b42499
commit 87fcdaaacc
10 changed files with 413 additions and 41 deletions

1
.gitignore vendored
View File

@ -13,3 +13,4 @@ build/
.vscode/
state/registry.json
configs/models.yaml
tmp-codex/

View File

@ -25,10 +25,10 @@ def _key_id(token: str) -> str:
return token[-6:] if len(token) >= 6 else token
def require_client_auth(
async def require_client_auth(
request: Request,
authorization: Optional[str] = Header(default=None),
x_api_key: Optional[str] = Header(default=None, convert_underscores=False),
x_api_key: Optional[str] = Header(default=None, alias="X-Api-Key"),
) -> Optional[Principal]:
"""Gate OpenAI-style endpoints with an API key.
@ -56,10 +56,10 @@ def require_client_auth(
return Principal(kind="client", key_id=_key_id(token))
def require_node_auth(
async def require_node_auth(
request: Request,
authorization: Optional[str] = Header(default=None),
x_rolemesh_node_key: Optional[str] = Header(default=None, convert_underscores=False),
x_rolemesh_node_key: Optional[str] = Header(default=None, alias="X-RoleMesh-Node-Key"),
) -> Optional[Principal]:
"""Gate node registration/heartbeat endpoints with a node key.

View File

@ -1,9 +1,9 @@
from __future__ import annotations
from contextlib import asynccontextmanager
import logging
import os
from pathlib import Path
from fastapi import FastAPI
from rolemesh_gateway.config import load_config
@ -24,20 +24,30 @@ def _get_logger() -> logging.Logger:
return logger
def create_app() -> FastAPI:
app = FastAPI(title="RoleMesh Gateway", version="0.1.0")
cfg_path = os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml")
def create_app(
config_path: str | Path | None = None,
registry_path: str | Path | None = None,
) -> FastAPI:
cfg_path = config_path or os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml")
cfg = load_config(cfg_path)
registry_path = os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json")
registry = Registry(persist_path=Path(registry_path))
resolved_registry_path = registry_path or os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json")
registry = Registry(persist_path=Path(resolved_registry_path))
upstream = UpstreamClient(
connect_timeout_s=float(os.environ.get("ROLE_MESH_CONNECT_TIMEOUT_S", "10")),
read_timeout_s=float(os.environ.get("ROLE_MESH_READ_TIMEOUT_S", "600")),
)
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
yield
finally:
await upstream.close()
app = FastAPI(title="RoleMesh Gateway", version="0.1.0", lifespan=lifespan)
app.state.cfg = cfg
app.state.registry = registry
app.state.upstream = upstream
@ -45,10 +55,6 @@ def create_app() -> FastAPI:
app.include_router(openai_router)
@app.on_event("shutdown")
async def _shutdown():
await upstream.close()
return app

View File

@ -22,6 +22,8 @@ class NodeHeartbeat(BaseModel):
timestamp: float = Field(default_factory=lambda: time.time())
status: Dict[str, Any] = Field(default_factory=dict)
metrics: List[Dict[str, Any]] = Field(default_factory=list)
class RegisteredNode(BaseModel):
node_id: str
base_url: HttpUrl
@ -85,11 +87,18 @@ class Registry:
self._save()
return node
def heartbeat(self, node_id: str) -> None:
n = self._nodes.get(node_id)
if n:
def heartbeat(self, hb: NodeHeartbeat) -> Optional[RegisteredNode]:
n = self._nodes.get(hb.node_id)
if not n:
return None
n.status = {
**hb.status,
"timestamp": hb.timestamp,
"metrics": hb.metrics,
}
n.last_seen = time.time()
self._save()
return n
def list_nodes(self) -> List[RegisteredNode]:
return list(self._nodes.values())

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import json
from typing import Any, AsyncIterator, Dict, Optional
import httpx
@ -18,10 +17,17 @@ def _timeout(connect_s: float, read_s: float) -> httpx.Timeout:
class UpstreamClient:
def __init__(self, connect_timeout_s: float = 10.0, read_timeout_s: float = 600.0) -> None:
self._client = httpx.AsyncClient(timeout=_timeout(connect_timeout_s, read_timeout_s))
def __init__(
self,
connect_timeout_s: float = 10.0,
read_timeout_s: float = 600.0,
client: httpx.AsyncClient | None = None,
) -> None:
self._owns_client = client is None
self._client = client or httpx.AsyncClient(timeout=_timeout(connect_timeout_s, read_timeout_s))
async def close(self) -> None:
if self._owns_client:
await self._client.aclose()
async def get_models(self, base_url: str) -> Dict[str, Any]:
@ -58,3 +64,7 @@ class UpstreamClient:
yield chunk
except httpx.RequestError as e:
raise UpstreamError(f"Upstream unreachable: {e!s}") from e
async def stream_chat_completions(self, base_url: str, payload: Dict[str, Any]) -> AsyncIterator[bytes]:
async for chunk in self.chat_completions_stream(base_url, payload):
yield chunk

View File

@ -1,9 +1,9 @@
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager, suppress
import time
import uuid
from typing import Any, Dict, Optional
from typing import Any, Dict
import httpx
from fastapi import FastAPI, Request
@ -24,30 +24,40 @@ def _error(message: str, code: str = "node_error", status_code: int = 500) -> JS
def create_app(cfg: NodeAgentConfig) -> FastAPI:
app = FastAPI(title="RoleMesh Node Agent", version="0.1.0")
http = httpx.AsyncClient(
timeout=httpx.Timeout(connect=5.0, read=3600.0, write=30.0, pool=30.0)
)
upstream = UpstreamClient(client=http)
cuda = CudaAdapter(llama_server_bin=cfg.llama_server_bin)
@asynccontextmanager
async def lifespan(app: FastAPI):
heartbeat_task: asyncio.Task[None] | None = None
if cfg.dispatcher_base_url and cfg.dispatcher_roles:
heartbeat_task = asyncio.create_task(_heartbeat_loop(app))
try:
yield
finally:
if heartbeat_task is not None:
heartbeat_task.cancel()
with suppress(asyncio.CancelledError):
await heartbeat_task
await http.aclose()
await cuda.shutdown()
app = FastAPI(title="RoleMesh Node Agent", version="0.1.0", lifespan=lifespan)
app.state.cfg = cfg
app.state.http = httpx.AsyncClient(timeout=httpx.Timeout(connect=5.0, read=3600.0))
app.state.upstream = UpstreamClient(client=app.state.http)
app.state.http = http
app.state.upstream = upstream
# Adapters
app.state.cuda = CudaAdapter(llama_server_bin=cfg.llama_server_bin)
app.state.cuda = cuda
# State: role -> (device, model)
# This is intentionally simple for the scaffold: pick first GPU and first matching model.
app.state.role_bindings: Dict[str, Dict[str, str]] = {}
@app.on_event("startup")
async def _startup() -> None:
# optional: dispatcher registration loop
if cfg.dispatcher_base_url and cfg.dispatcher_roles:
asyncio.create_task(_heartbeat_loop(app))
@app.on_event("shutdown")
async def _shutdown() -> None:
await app.state.http.aclose()
await app.state.cuda.shutdown()
@app.get("/health")
async def health() -> Dict[str, Any]:
return {"status": "ok", "node_id": cfg.node_id, "time": time.time()}

11
tests/conftest.py Normal file
View File

@ -0,0 +1,11 @@
from __future__ import annotations
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))

158
tests/test_gateway.py Normal file
View File

@ -0,0 +1,158 @@
from __future__ import annotations
import asyncio
from pathlib import Path
import httpx
import yaml
def _write_gateway_config(path: Path) -> Path:
path.write_text(
yaml.safe_dump(
{
"version": 1,
"default_model": "writer",
"auth": {
"client_api_keys": ["client-secret"],
"node_api_keys": ["node-secret"],
},
"models": {
"writer": {
"type": "proxy",
"openai_model_name": "writer",
"proxy_url": "http://127.0.0.1:8012",
"defaults": {"temperature": 0.6},
},
"reviewer": {
"type": "discovered",
"openai_model_name": "reviewer",
"role": "reviewer",
"strategy": "round_robin",
},
},
}
)
)
return path
def _create_gateway_app(tmp_path: Path):
from rolemesh_gateway.main import create_app
return create_app(
_write_gateway_config(tmp_path / "models.yaml"),
registry_path=tmp_path / "registry.json",
)
async def _request(app, method: str, path: str, **kwargs) -> httpx.Response:
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
return await client.request(method, path, **kwargs)
def test_models_requires_client_auth(tmp_path):
app = _create_gateway_app(tmp_path)
unauthorized = asyncio.run(_request(app, "GET", "/v1/models"))
assert unauthorized.status_code == 401
authorized = asyncio.run(
_request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
)
assert authorized.status_code == 200
body = authorized.json()
assert {item["id"] for item in body["data"]} == {"writer", "reviewer"}
asyncio.run(app.state.upstream.close())
def test_chat_completions_applies_defaults_for_proxy_model(tmp_path):
app = _create_gateway_app(tmp_path)
calls = {}
async def fake_chat(base_url, payload):
calls["base_url"] = base_url
calls["payload"] = payload
return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
app.state.upstream.chat_completions = fake_chat
response = asyncio.run(
_request(
app,
"POST",
"/v1/chat/completions",
headers={"x-api-key": "client-secret"},
json={
"model": "writer",
"messages": [{"role": "user", "content": "hello"}],
},
)
)
assert response.status_code == 200
assert calls["base_url"] == "http://127.0.0.1:8012"
assert calls["payload"]["temperature"] == 0.6
assert response.json()["choices"][0]["message"]["content"] == "ok"
asyncio.run(app.state.upstream.close())
def test_discovered_model_without_registered_node_returns_503(tmp_path):
app = _create_gateway_app(tmp_path)
response = asyncio.run(
_request(
app,
"POST",
"/v1/chat/completions",
headers={"x-api-key": "client-secret"},
json={
"model": "reviewer",
"messages": [{"role": "user", "content": "hello"}],
},
)
)
assert response.status_code == 503
assert response.json()["error"]["code"] == "no_upstream"
asyncio.run(app.state.upstream.close())
def test_node_registration_and_heartbeat_update_registry_state(tmp_path):
app = _create_gateway_app(tmp_path)
register = asyncio.run(
_request(
app,
"POST",
"/v1/nodes/register",
headers={"x-rolemesh-node-key": "node-secret"},
json={
"node_id": "node-a",
"base_url": "http://127.0.0.1:9001",
"roles": ["reviewer"],
"meta": {"gpu": "test"},
},
)
)
assert register.status_code == 200
heartbeat = asyncio.run(
_request(
app,
"POST",
"/v1/nodes/heartbeat",
headers={"x-rolemesh-node-key": "node-secret"},
json={
"node_id": "node-a",
"timestamp": 123.0,
"status": {"healthy": True},
"metrics": [{"device": {"id": "gpu:0"}}],
},
)
)
assert heartbeat.status_code == 200
node = heartbeat.json()["node"]
assert node["status"]["healthy"] is True
assert node["status"]["timestamp"] == 123.0
assert node["status"]["metrics"] == [{"device": {"id": "gpu:0"}}]
asyncio.run(app.state.upstream.close())

121
tests/test_node_agent.py Normal file
View File

@ -0,0 +1,121 @@
from __future__ import annotations
import asyncio
from pathlib import Path
import httpx
from rolemesh_node_agent.adapters.base import DeviceMetrics, DeviceRef
from rolemesh_node_agent.config import ModelEntry, NodeAgentConfig
def _node_config(tmp_path: Path) -> NodeAgentConfig:
model_path = tmp_path / "model.gguf"
model_path.write_bytes(b"GGUF")
return NodeAgentConfig(
node_id="node-1",
model_roots=[tmp_path],
models=[ModelEntry(model_id="planner-gguf", path=model_path, roles=["planner"])],
)
async def _request(app, method: str, path: str, **kwargs) -> httpx.Response:
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
return await client.request(method, path, **kwargs)
def test_inventory_reports_models_metrics_and_discovered_gguf(tmp_path):
from rolemesh_node_agent.main import create_app
cfg = _node_config(tmp_path)
app = create_app(cfg)
async def fake_discover_devices():
return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")]
async def fake_get_metrics():
return [
DeviceMetrics(
device=DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
loaded_model_id="planner-gguf",
queue_depth=1,
)
]
app.state.cuda.discover_devices = fake_discover_devices
app.state.cuda.get_metrics = fake_get_metrics
response = asyncio.run(_request(app, "GET", "/v1/node/inventory"))
body = response.json()
assert response.status_code == 200
assert body["models"][0]["model_id"] == "planner-gguf"
assert body["metrics"][0]["loaded_model_id"] == "planner-gguf"
assert body["discovered_gguf"][0]["name"] == "model.gguf"
asyncio.run(app.state.http.aclose())
def test_chat_completions_routes_to_local_server_and_streams(tmp_path):
from rolemesh_node_agent.main import create_app
cfg = _node_config(tmp_path)
app = create_app(cfg)
calls = {}
async def fake_discover_devices():
return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")]
async def fake_ensure_server(device, *, model_path, model_id, server_args):
calls["device"] = device.id
calls["model_path"] = model_path
calls["model_id"] = model_id
return "http://127.0.0.1:9100"
async def fake_chat(base_url, payload):
calls["base_url"] = base_url
calls["payload"] = payload
return {"id": "node-cmpl", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
async def fake_stream(base_url, payload):
calls["stream_base_url"] = base_url
calls["stream_payload"] = payload
yield b"data: first\n\n"
yield b"data: [DONE]\n\n"
app.state.cuda.discover_devices = fake_discover_devices
app.state.cuda.ensure_server = fake_ensure_server
app.state.upstream.chat_completions = fake_chat
app.state.upstream.stream_chat_completions = fake_stream
response = asyncio.run(
_request(
app,
"POST",
"/v1/chat/completions",
json={
"model": "planner-gguf",
"messages": [{"role": "user", "content": "hello"}],
},
)
)
stream_response = asyncio.run(
_request(
app,
"POST",
"/v1/chat/completions",
json={
"model": "planner-gguf",
"stream": True,
"messages": [{"role": "user", "content": "hello"}],
},
)
)
assert response.status_code == 200
assert response.json()["choices"][0]["message"]["content"] == "ok"
assert calls["device"] == "gpu:0"
assert calls["base_url"] == "http://127.0.0.1:9100"
assert "data: first" in stream_response.text
assert calls["stream_base_url"] == "http://127.0.0.1:9100"
asyncio.run(app.state.http.aclose())

46
tests/test_registry.py Normal file
View File

@ -0,0 +1,46 @@
from __future__ import annotations
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
def test_registry_persists_round_robin_and_heartbeat_state(tmp_path):
persist_path = tmp_path / "registry.json"
registry = Registry(persist_path=persist_path)
registry.register(
NodeRegistration(
node_id="node-a",
base_url="http://127.0.0.1:9001",
roles=["reviewer"],
)
)
registry.register(
NodeRegistration(
node_id="node-b",
base_url="http://127.0.0.1:9002",
roles=["reviewer"],
)
)
assert registry.pick_node_for_role("reviewer").node_id == "node-a"
assert registry.pick_node_for_role("reviewer").node_id == "node-b"
node = registry.heartbeat(
NodeHeartbeat(
node_id="node-a",
timestamp=456.0,
status={"healthy": True},
metrics=[{"queue_depth": 1}],
)
)
assert node is not None
assert node.status["metrics"] == [{"queue_depth": 1}]
reloaded = Registry(persist_path=persist_path)
assert reloaded.pick_node_for_role("reviewer").node_id == "node-a"
assert reloaded.list_nodes()[0].status["timestamp"] == 456.0
def test_registry_heartbeat_returns_none_for_unknown_node():
registry = Registry()
assert registry.heartbeat(NodeHeartbeat(node_id="missing")) is None