diff --git a/.gitignore b/.gitignore index e6a1fb9..3ffff1e 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ build/ .vscode/ state/registry.json configs/models.yaml +tmp-codex/ diff --git a/src/rolemesh_gateway/auth.py b/src/rolemesh_gateway/auth.py index 65b6cae..c409771 100644 --- a/src/rolemesh_gateway/auth.py +++ b/src/rolemesh_gateway/auth.py @@ -25,10 +25,10 @@ def _key_id(token: str) -> str: return token[-6:] if len(token) >= 6 else token -def require_client_auth( +async def require_client_auth( request: Request, authorization: Optional[str] = Header(default=None), - x_api_key: Optional[str] = Header(default=None, convert_underscores=False), + x_api_key: Optional[str] = Header(default=None, alias="X-Api-Key"), ) -> Optional[Principal]: """Gate OpenAI-style endpoints with an API key. @@ -56,10 +56,10 @@ def require_client_auth( return Principal(kind="client", key_id=_key_id(token)) -def require_node_auth( +async def require_node_auth( request: Request, authorization: Optional[str] = Header(default=None), - x_rolemesh_node_key: Optional[str] = Header(default=None, convert_underscores=False), + x_rolemesh_node_key: Optional[str] = Header(default=None, alias="X-RoleMesh-Node-Key"), ) -> Optional[Principal]: """Gate node registration/heartbeat endpoints with a node key. diff --git a/src/rolemesh_gateway/main.py b/src/rolemesh_gateway/main.py index 2a8eee4..5f3e281 100644 --- a/src/rolemesh_gateway/main.py +++ b/src/rolemesh_gateway/main.py @@ -1,9 +1,9 @@ from __future__ import annotations +from contextlib import asynccontextmanager import logging import os from pathlib import Path - from fastapi import FastAPI from rolemesh_gateway.config import load_config @@ -24,20 +24,30 @@ def _get_logger() -> logging.Logger: return logger -def create_app() -> FastAPI: - app = FastAPI(title="RoleMesh Gateway", version="0.1.0") - - cfg_path = os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml") +def create_app( + config_path: str | Path | None = None, + registry_path: str | Path | None = None, +) -> FastAPI: + cfg_path = config_path or os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml") cfg = load_config(cfg_path) - registry_path = os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json") - registry = Registry(persist_path=Path(registry_path)) + resolved_registry_path = registry_path or os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json") + registry = Registry(persist_path=Path(resolved_registry_path)) upstream = UpstreamClient( connect_timeout_s=float(os.environ.get("ROLE_MESH_CONNECT_TIMEOUT_S", "10")), read_timeout_s=float(os.environ.get("ROLE_MESH_READ_TIMEOUT_S", "600")), ) + @asynccontextmanager + async def lifespan(app: FastAPI): + try: + yield + finally: + await upstream.close() + + app = FastAPI(title="RoleMesh Gateway", version="0.1.0", lifespan=lifespan) + app.state.cfg = cfg app.state.registry = registry app.state.upstream = upstream @@ -45,10 +55,6 @@ def create_app() -> FastAPI: app.include_router(openai_router) - @app.on_event("shutdown") - async def _shutdown(): - await upstream.close() - return app diff --git a/src/rolemesh_gateway/registry.py b/src/rolemesh_gateway/registry.py index f0f5863..8f623c0 100644 --- a/src/rolemesh_gateway/registry.py +++ b/src/rolemesh_gateway/registry.py @@ -22,6 +22,8 @@ class NodeHeartbeat(BaseModel): timestamp: float = Field(default_factory=lambda: time.time()) status: Dict[str, Any] = Field(default_factory=dict) metrics: List[Dict[str, Any]] = Field(default_factory=list) + + class RegisteredNode(BaseModel): node_id: str base_url: HttpUrl @@ -85,11 +87,18 @@ class Registry: self._save() return node - def heartbeat(self, node_id: str) -> None: - n = self._nodes.get(node_id) - if n: - n.last_seen = time.time() - self._save() + def heartbeat(self, hb: NodeHeartbeat) -> Optional[RegisteredNode]: + n = self._nodes.get(hb.node_id) + if not n: + return None + n.status = { + **hb.status, + "timestamp": hb.timestamp, + "metrics": hb.metrics, + } + n.last_seen = time.time() + self._save() + return n def list_nodes(self) -> List[RegisteredNode]: return list(self._nodes.values()) diff --git a/src/rolemesh_gateway/upstream.py b/src/rolemesh_gateway/upstream.py index dcfeaf6..9f8dfc7 100644 --- a/src/rolemesh_gateway/upstream.py +++ b/src/rolemesh_gateway/upstream.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json from typing import Any, AsyncIterator, Dict, Optional import httpx @@ -18,11 +17,18 @@ def _timeout(connect_s: float, read_s: float) -> httpx.Timeout: class UpstreamClient: - def __init__(self, connect_timeout_s: float = 10.0, read_timeout_s: float = 600.0) -> None: - self._client = httpx.AsyncClient(timeout=_timeout(connect_timeout_s, read_timeout_s)) + def __init__( + self, + connect_timeout_s: float = 10.0, + read_timeout_s: float = 600.0, + client: httpx.AsyncClient | None = None, + ) -> None: + self._owns_client = client is None + self._client = client or httpx.AsyncClient(timeout=_timeout(connect_timeout_s, read_timeout_s)) async def close(self) -> None: - await self._client.aclose() + if self._owns_client: + await self._client.aclose() async def get_models(self, base_url: str) -> Dict[str, Any]: url = base_url.rstrip("/") + "/v1/models" @@ -58,3 +64,7 @@ class UpstreamClient: yield chunk except httpx.RequestError as e: raise UpstreamError(f"Upstream unreachable: {e!s}") from e + + async def stream_chat_completions(self, base_url: str, payload: Dict[str, Any]) -> AsyncIterator[bytes]: + async for chunk in self.chat_completions_stream(base_url, payload): + yield chunk diff --git a/src/rolemesh_node_agent/main.py b/src/rolemesh_node_agent/main.py index e909c7e..3348e69 100644 --- a/src/rolemesh_node_agent/main.py +++ b/src/rolemesh_node_agent/main.py @@ -1,9 +1,9 @@ from __future__ import annotations import asyncio +from contextlib import asynccontextmanager, suppress import time -import uuid -from typing import Any, Dict, Optional +from typing import Any, Dict import httpx from fastapi import FastAPI, Request @@ -24,30 +24,40 @@ def _error(message: str, code: str = "node_error", status_code: int = 500) -> JS def create_app(cfg: NodeAgentConfig) -> FastAPI: - app = FastAPI(title="RoleMesh Node Agent", version="0.1.0") + http = httpx.AsyncClient( + timeout=httpx.Timeout(connect=5.0, read=3600.0, write=30.0, pool=30.0) + ) + upstream = UpstreamClient(client=http) + cuda = CudaAdapter(llama_server_bin=cfg.llama_server_bin) + + @asynccontextmanager + async def lifespan(app: FastAPI): + heartbeat_task: asyncio.Task[None] | None = None + if cfg.dispatcher_base_url and cfg.dispatcher_roles: + heartbeat_task = asyncio.create_task(_heartbeat_loop(app)) + try: + yield + finally: + if heartbeat_task is not None: + heartbeat_task.cancel() + with suppress(asyncio.CancelledError): + await heartbeat_task + await http.aclose() + await cuda.shutdown() + + app = FastAPI(title="RoleMesh Node Agent", version="0.1.0", lifespan=lifespan) app.state.cfg = cfg - app.state.http = httpx.AsyncClient(timeout=httpx.Timeout(connect=5.0, read=3600.0)) - app.state.upstream = UpstreamClient(client=app.state.http) + app.state.http = http + app.state.upstream = upstream # Adapters - app.state.cuda = CudaAdapter(llama_server_bin=cfg.llama_server_bin) + app.state.cuda = cuda # State: role -> (device, model) # This is intentionally simple for the scaffold: pick first GPU and first matching model. app.state.role_bindings: Dict[str, Dict[str, str]] = {} - @app.on_event("startup") - async def _startup() -> None: - # optional: dispatcher registration loop - if cfg.dispatcher_base_url and cfg.dispatcher_roles: - asyncio.create_task(_heartbeat_loop(app)) - - @app.on_event("shutdown") - async def _shutdown() -> None: - await app.state.http.aclose() - await app.state.cuda.shutdown() - @app.get("/health") async def health() -> Dict[str, Any]: return {"status": "ok", "node_id": cfg.node_id, "time": time.time()} diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..df150dc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +import sys +from pathlib import Path + + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" + +if str(SRC) not in sys.path: + sys.path.insert(0, str(SRC)) diff --git a/tests/test_gateway.py b/tests/test_gateway.py new file mode 100644 index 0000000..022997d --- /dev/null +++ b/tests/test_gateway.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path + +import httpx +import yaml + + +def _write_gateway_config(path: Path) -> Path: + path.write_text( + yaml.safe_dump( + { + "version": 1, + "default_model": "writer", + "auth": { + "client_api_keys": ["client-secret"], + "node_api_keys": ["node-secret"], + }, + "models": { + "writer": { + "type": "proxy", + "openai_model_name": "writer", + "proxy_url": "http://127.0.0.1:8012", + "defaults": {"temperature": 0.6}, + }, + "reviewer": { + "type": "discovered", + "openai_model_name": "reviewer", + "role": "reviewer", + "strategy": "round_robin", + }, + }, + } + ) + ) + return path + + +def _create_gateway_app(tmp_path: Path): + from rolemesh_gateway.main import create_app + + return create_app( + _write_gateway_config(tmp_path / "models.yaml"), + registry_path=tmp_path / "registry.json", + ) + + +async def _request(app, method: str, path: str, **kwargs) -> httpx.Response: + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + return await client.request(method, path, **kwargs) + + +def test_models_requires_client_auth(tmp_path): + app = _create_gateway_app(tmp_path) + unauthorized = asyncio.run(_request(app, "GET", "/v1/models")) + assert unauthorized.status_code == 401 + + authorized = asyncio.run( + _request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"}) + ) + assert authorized.status_code == 200 + body = authorized.json() + assert {item["id"] for item in body["data"]} == {"writer", "reviewer"} + asyncio.run(app.state.upstream.close()) + + +def test_chat_completions_applies_defaults_for_proxy_model(tmp_path): + app = _create_gateway_app(tmp_path) + calls = {} + + async def fake_chat(base_url, payload): + calls["base_url"] = base_url + calls["payload"] = payload + return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]} + + app.state.upstream.chat_completions = fake_chat + + response = asyncio.run( + _request( + app, + "POST", + "/v1/chat/completions", + headers={"x-api-key": "client-secret"}, + json={ + "model": "writer", + "messages": [{"role": "user", "content": "hello"}], + }, + ) + ) + + assert response.status_code == 200 + assert calls["base_url"] == "http://127.0.0.1:8012" + assert calls["payload"]["temperature"] == 0.6 + assert response.json()["choices"][0]["message"]["content"] == "ok" + asyncio.run(app.state.upstream.close()) + + +def test_discovered_model_without_registered_node_returns_503(tmp_path): + app = _create_gateway_app(tmp_path) + response = asyncio.run( + _request( + app, + "POST", + "/v1/chat/completions", + headers={"x-api-key": "client-secret"}, + json={ + "model": "reviewer", + "messages": [{"role": "user", "content": "hello"}], + }, + ) + ) + + assert response.status_code == 503 + assert response.json()["error"]["code"] == "no_upstream" + asyncio.run(app.state.upstream.close()) + + +def test_node_registration_and_heartbeat_update_registry_state(tmp_path): + app = _create_gateway_app(tmp_path) + register = asyncio.run( + _request( + app, + "POST", + "/v1/nodes/register", + headers={"x-rolemesh-node-key": "node-secret"}, + json={ + "node_id": "node-a", + "base_url": "http://127.0.0.1:9001", + "roles": ["reviewer"], + "meta": {"gpu": "test"}, + }, + ) + ) + assert register.status_code == 200 + + heartbeat = asyncio.run( + _request( + app, + "POST", + "/v1/nodes/heartbeat", + headers={"x-rolemesh-node-key": "node-secret"}, + json={ + "node_id": "node-a", + "timestamp": 123.0, + "status": {"healthy": True}, + "metrics": [{"device": {"id": "gpu:0"}}], + }, + ) + ) + + assert heartbeat.status_code == 200 + node = heartbeat.json()["node"] + assert node["status"]["healthy"] is True + assert node["status"]["timestamp"] == 123.0 + assert node["status"]["metrics"] == [{"device": {"id": "gpu:0"}}] + asyncio.run(app.state.upstream.close()) diff --git a/tests/test_node_agent.py b/tests/test_node_agent.py new file mode 100644 index 0000000..59dd478 --- /dev/null +++ b/tests/test_node_agent.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path + +import httpx + +from rolemesh_node_agent.adapters.base import DeviceMetrics, DeviceRef +from rolemesh_node_agent.config import ModelEntry, NodeAgentConfig + + +def _node_config(tmp_path: Path) -> NodeAgentConfig: + model_path = tmp_path / "model.gguf" + model_path.write_bytes(b"GGUF") + return NodeAgentConfig( + node_id="node-1", + model_roots=[tmp_path], + models=[ModelEntry(model_id="planner-gguf", path=model_path, roles=["planner"])], + ) + + +async def _request(app, method: str, path: str, **kwargs) -> httpx.Response: + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client: + return await client.request(method, path, **kwargs) + + +def test_inventory_reports_models_metrics_and_discovered_gguf(tmp_path): + from rolemesh_node_agent.main import create_app + + cfg = _node_config(tmp_path) + app = create_app(cfg) + + async def fake_discover_devices(): + return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")] + + async def fake_get_metrics(): + return [ + DeviceMetrics( + device=DeviceRef(kind="gpu", backend="cuda", id="gpu:0"), + loaded_model_id="planner-gguf", + queue_depth=1, + ) + ] + + app.state.cuda.discover_devices = fake_discover_devices + app.state.cuda.get_metrics = fake_get_metrics + + response = asyncio.run(_request(app, "GET", "/v1/node/inventory")) + + body = response.json() + assert response.status_code == 200 + assert body["models"][0]["model_id"] == "planner-gguf" + assert body["metrics"][0]["loaded_model_id"] == "planner-gguf" + assert body["discovered_gguf"][0]["name"] == "model.gguf" + asyncio.run(app.state.http.aclose()) + + +def test_chat_completions_routes_to_local_server_and_streams(tmp_path): + from rolemesh_node_agent.main import create_app + + cfg = _node_config(tmp_path) + app = create_app(cfg) + calls = {} + + async def fake_discover_devices(): + return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")] + + async def fake_ensure_server(device, *, model_path, model_id, server_args): + calls["device"] = device.id + calls["model_path"] = model_path + calls["model_id"] = model_id + return "http://127.0.0.1:9100" + + async def fake_chat(base_url, payload): + calls["base_url"] = base_url + calls["payload"] = payload + return {"id": "node-cmpl", "choices": [{"message": {"role": "assistant", "content": "ok"}}]} + + async def fake_stream(base_url, payload): + calls["stream_base_url"] = base_url + calls["stream_payload"] = payload + yield b"data: first\n\n" + yield b"data: [DONE]\n\n" + + app.state.cuda.discover_devices = fake_discover_devices + app.state.cuda.ensure_server = fake_ensure_server + app.state.upstream.chat_completions = fake_chat + app.state.upstream.stream_chat_completions = fake_stream + + response = asyncio.run( + _request( + app, + "POST", + "/v1/chat/completions", + json={ + "model": "planner-gguf", + "messages": [{"role": "user", "content": "hello"}], + }, + ) + ) + stream_response = asyncio.run( + _request( + app, + "POST", + "/v1/chat/completions", + json={ + "model": "planner-gguf", + "stream": True, + "messages": [{"role": "user", "content": "hello"}], + }, + ) + ) + + assert response.status_code == 200 + assert response.json()["choices"][0]["message"]["content"] == "ok" + assert calls["device"] == "gpu:0" + assert calls["base_url"] == "http://127.0.0.1:9100" + assert "data: first" in stream_response.text + assert calls["stream_base_url"] == "http://127.0.0.1:9100" + asyncio.run(app.state.http.aclose()) diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000..2701424 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry + + +def test_registry_persists_round_robin_and_heartbeat_state(tmp_path): + persist_path = tmp_path / "registry.json" + registry = Registry(persist_path=persist_path) + + registry.register( + NodeRegistration( + node_id="node-a", + base_url="http://127.0.0.1:9001", + roles=["reviewer"], + ) + ) + registry.register( + NodeRegistration( + node_id="node-b", + base_url="http://127.0.0.1:9002", + roles=["reviewer"], + ) + ) + + assert registry.pick_node_for_role("reviewer").node_id == "node-a" + assert registry.pick_node_for_role("reviewer").node_id == "node-b" + + node = registry.heartbeat( + NodeHeartbeat( + node_id="node-a", + timestamp=456.0, + status={"healthy": True}, + metrics=[{"queue_depth": 1}], + ) + ) + assert node is not None + assert node.status["metrics"] == [{"queue_depth": 1}] + + reloaded = Registry(persist_path=persist_path) + assert reloaded.pick_node_for_role("reviewer").node_id == "node-a" + assert reloaded.list_nodes()[0].status["timestamp"] == 456.0 + + +def test_registry_heartbeat_returns_none_for_unknown_node(): + registry = Registry() + assert registry.heartbeat(NodeHeartbeat(node_id="missing")) is None