Applying Codex changes, added test code.
This commit is contained in:
parent
1908b42499
commit
87fcdaaacc
|
|
@ -13,3 +13,4 @@ build/
|
||||||
.vscode/
|
.vscode/
|
||||||
state/registry.json
|
state/registry.json
|
||||||
configs/models.yaml
|
configs/models.yaml
|
||||||
|
tmp-codex/
|
||||||
|
|
|
||||||
|
|
@ -25,10 +25,10 @@ def _key_id(token: str) -> str:
|
||||||
return token[-6:] if len(token) >= 6 else token
|
return token[-6:] if len(token) >= 6 else token
|
||||||
|
|
||||||
|
|
||||||
def require_client_auth(
|
async def require_client_auth(
|
||||||
request: Request,
|
request: Request,
|
||||||
authorization: Optional[str] = Header(default=None),
|
authorization: Optional[str] = Header(default=None),
|
||||||
x_api_key: Optional[str] = Header(default=None, convert_underscores=False),
|
x_api_key: Optional[str] = Header(default=None, alias="X-Api-Key"),
|
||||||
) -> Optional[Principal]:
|
) -> Optional[Principal]:
|
||||||
"""Gate OpenAI-style endpoints with an API key.
|
"""Gate OpenAI-style endpoints with an API key.
|
||||||
|
|
||||||
|
|
@ -56,10 +56,10 @@ def require_client_auth(
|
||||||
return Principal(kind="client", key_id=_key_id(token))
|
return Principal(kind="client", key_id=_key_id(token))
|
||||||
|
|
||||||
|
|
||||||
def require_node_auth(
|
async def require_node_auth(
|
||||||
request: Request,
|
request: Request,
|
||||||
authorization: Optional[str] = Header(default=None),
|
authorization: Optional[str] = Header(default=None),
|
||||||
x_rolemesh_node_key: Optional[str] = Header(default=None, convert_underscores=False),
|
x_rolemesh_node_key: Optional[str] = Header(default=None, alias="X-RoleMesh-Node-Key"),
|
||||||
) -> Optional[Principal]:
|
) -> Optional[Principal]:
|
||||||
"""Gate node registration/heartbeat endpoints with a node key.
|
"""Gate node registration/heartbeat endpoints with a node key.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from rolemesh_gateway.config import load_config
|
from rolemesh_gateway.config import load_config
|
||||||
|
|
@ -24,20 +24,30 @@ def _get_logger() -> logging.Logger:
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app(
|
||||||
app = FastAPI(title="RoleMesh Gateway", version="0.1.0")
|
config_path: str | Path | None = None,
|
||||||
|
registry_path: str | Path | None = None,
|
||||||
cfg_path = os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml")
|
) -> FastAPI:
|
||||||
|
cfg_path = config_path or os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml")
|
||||||
cfg = load_config(cfg_path)
|
cfg = load_config(cfg_path)
|
||||||
|
|
||||||
registry_path = os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json")
|
resolved_registry_path = registry_path or os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json")
|
||||||
registry = Registry(persist_path=Path(registry_path))
|
registry = Registry(persist_path=Path(resolved_registry_path))
|
||||||
|
|
||||||
upstream = UpstreamClient(
|
upstream = UpstreamClient(
|
||||||
connect_timeout_s=float(os.environ.get("ROLE_MESH_CONNECT_TIMEOUT_S", "10")),
|
connect_timeout_s=float(os.environ.get("ROLE_MESH_CONNECT_TIMEOUT_S", "10")),
|
||||||
read_timeout_s=float(os.environ.get("ROLE_MESH_READ_TIMEOUT_S", "600")),
|
read_timeout_s=float(os.environ.get("ROLE_MESH_READ_TIMEOUT_S", "600")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
await upstream.close()
|
||||||
|
|
||||||
|
app = FastAPI(title="RoleMesh Gateway", version="0.1.0", lifespan=lifespan)
|
||||||
|
|
||||||
app.state.cfg = cfg
|
app.state.cfg = cfg
|
||||||
app.state.registry = registry
|
app.state.registry = registry
|
||||||
app.state.upstream = upstream
|
app.state.upstream = upstream
|
||||||
|
|
@ -45,10 +55,6 @@ def create_app() -> FastAPI:
|
||||||
|
|
||||||
app.include_router(openai_router)
|
app.include_router(openai_router)
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
|
||||||
async def _shutdown():
|
|
||||||
await upstream.close()
|
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,8 @@ class NodeHeartbeat(BaseModel):
|
||||||
timestamp: float = Field(default_factory=lambda: time.time())
|
timestamp: float = Field(default_factory=lambda: time.time())
|
||||||
status: Dict[str, Any] = Field(default_factory=dict)
|
status: Dict[str, Any] = Field(default_factory=dict)
|
||||||
metrics: List[Dict[str, Any]] = Field(default_factory=list)
|
metrics: List[Dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class RegisteredNode(BaseModel):
|
class RegisteredNode(BaseModel):
|
||||||
node_id: str
|
node_id: str
|
||||||
base_url: HttpUrl
|
base_url: HttpUrl
|
||||||
|
|
@ -85,11 +87,18 @@ class Registry:
|
||||||
self._save()
|
self._save()
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def heartbeat(self, node_id: str) -> None:
|
def heartbeat(self, hb: NodeHeartbeat) -> Optional[RegisteredNode]:
|
||||||
n = self._nodes.get(node_id)
|
n = self._nodes.get(hb.node_id)
|
||||||
if n:
|
if not n:
|
||||||
n.last_seen = time.time()
|
return None
|
||||||
self._save()
|
n.status = {
|
||||||
|
**hb.status,
|
||||||
|
"timestamp": hb.timestamp,
|
||||||
|
"metrics": hb.metrics,
|
||||||
|
}
|
||||||
|
n.last_seen = time.time()
|
||||||
|
self._save()
|
||||||
|
return n
|
||||||
|
|
||||||
def list_nodes(self) -> List[RegisteredNode]:
|
def list_nodes(self) -> List[RegisteredNode]:
|
||||||
return list(self._nodes.values())
|
return list(self._nodes.values())
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, AsyncIterator, Dict, Optional
|
from typing import Any, AsyncIterator, Dict, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
@ -18,11 +17,18 @@ def _timeout(connect_s: float, read_s: float) -> httpx.Timeout:
|
||||||
|
|
||||||
|
|
||||||
class UpstreamClient:
|
class UpstreamClient:
|
||||||
def __init__(self, connect_timeout_s: float = 10.0, read_timeout_s: float = 600.0) -> None:
|
def __init__(
|
||||||
self._client = httpx.AsyncClient(timeout=_timeout(connect_timeout_s, read_timeout_s))
|
self,
|
||||||
|
connect_timeout_s: float = 10.0,
|
||||||
|
read_timeout_s: float = 600.0,
|
||||||
|
client: httpx.AsyncClient | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._owns_client = client is None
|
||||||
|
self._client = client or httpx.AsyncClient(timeout=_timeout(connect_timeout_s, read_timeout_s))
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
await self._client.aclose()
|
if self._owns_client:
|
||||||
|
await self._client.aclose()
|
||||||
|
|
||||||
async def get_models(self, base_url: str) -> Dict[str, Any]:
|
async def get_models(self, base_url: str) -> Dict[str, Any]:
|
||||||
url = base_url.rstrip("/") + "/v1/models"
|
url = base_url.rstrip("/") + "/v1/models"
|
||||||
|
|
@ -58,3 +64,7 @@ class UpstreamClient:
|
||||||
yield chunk
|
yield chunk
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
raise UpstreamError(f"Upstream unreachable: {e!s}") from e
|
raise UpstreamError(f"Upstream unreachable: {e!s}") from e
|
||||||
|
|
||||||
|
async def stream_chat_completions(self, base_url: str, payload: Dict[str, Any]) -> AsyncIterator[bytes]:
|
||||||
|
async for chunk in self.chat_completions_stream(base_url, payload):
|
||||||
|
yield chunk
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from contextlib import asynccontextmanager, suppress
|
||||||
import time
|
import time
|
||||||
import uuid
|
from typing import Any, Dict
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
|
|
@ -24,30 +24,40 @@ def _error(message: str, code: str = "node_error", status_code: int = 500) -> JS
|
||||||
|
|
||||||
|
|
||||||
def create_app(cfg: NodeAgentConfig) -> FastAPI:
|
def create_app(cfg: NodeAgentConfig) -> FastAPI:
|
||||||
app = FastAPI(title="RoleMesh Node Agent", version="0.1.0")
|
http = httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(connect=5.0, read=3600.0, write=30.0, pool=30.0)
|
||||||
|
)
|
||||||
|
upstream = UpstreamClient(client=http)
|
||||||
|
cuda = CudaAdapter(llama_server_bin=cfg.llama_server_bin)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
heartbeat_task: asyncio.Task[None] | None = None
|
||||||
|
if cfg.dispatcher_base_url and cfg.dispatcher_roles:
|
||||||
|
heartbeat_task = asyncio.create_task(_heartbeat_loop(app))
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if heartbeat_task is not None:
|
||||||
|
heartbeat_task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await heartbeat_task
|
||||||
|
await http.aclose()
|
||||||
|
await cuda.shutdown()
|
||||||
|
|
||||||
|
app = FastAPI(title="RoleMesh Node Agent", version="0.1.0", lifespan=lifespan)
|
||||||
|
|
||||||
app.state.cfg = cfg
|
app.state.cfg = cfg
|
||||||
app.state.http = httpx.AsyncClient(timeout=httpx.Timeout(connect=5.0, read=3600.0))
|
app.state.http = http
|
||||||
app.state.upstream = UpstreamClient(client=app.state.http)
|
app.state.upstream = upstream
|
||||||
|
|
||||||
# Adapters
|
# Adapters
|
||||||
app.state.cuda = CudaAdapter(llama_server_bin=cfg.llama_server_bin)
|
app.state.cuda = cuda
|
||||||
|
|
||||||
# State: role -> (device, model)
|
# State: role -> (device, model)
|
||||||
# This is intentionally simple for the scaffold: pick first GPU and first matching model.
|
# This is intentionally simple for the scaffold: pick first GPU and first matching model.
|
||||||
app.state.role_bindings: Dict[str, Dict[str, str]] = {}
|
app.state.role_bindings: Dict[str, Dict[str, str]] = {}
|
||||||
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def _startup() -> None:
|
|
||||||
# optional: dispatcher registration loop
|
|
||||||
if cfg.dispatcher_base_url and cfg.dispatcher_roles:
|
|
||||||
asyncio.create_task(_heartbeat_loop(app))
|
|
||||||
|
|
||||||
@app.on_event("shutdown")
|
|
||||||
async def _shutdown() -> None:
|
|
||||||
await app.state.http.aclose()
|
|
||||||
await app.state.cuda.shutdown()
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health() -> Dict[str, Any]:
|
async def health() -> Dict[str, Any]:
|
||||||
return {"status": "ok", "node_id": cfg.node_id, "time": time.time()}
|
return {"status": "ok", "node_id": cfg.node_id, "time": time.time()}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
SRC = ROOT / "src"
|
||||||
|
|
||||||
|
if str(SRC) not in sys.path:
|
||||||
|
sys.path.insert(0, str(SRC))
|
||||||
|
|
@ -0,0 +1,158 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
def _write_gateway_config(path: Path) -> Path:
|
||||||
|
path.write_text(
|
||||||
|
yaml.safe_dump(
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"default_model": "writer",
|
||||||
|
"auth": {
|
||||||
|
"client_api_keys": ["client-secret"],
|
||||||
|
"node_api_keys": ["node-secret"],
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"writer": {
|
||||||
|
"type": "proxy",
|
||||||
|
"openai_model_name": "writer",
|
||||||
|
"proxy_url": "http://127.0.0.1:8012",
|
||||||
|
"defaults": {"temperature": 0.6},
|
||||||
|
},
|
||||||
|
"reviewer": {
|
||||||
|
"type": "discovered",
|
||||||
|
"openai_model_name": "reviewer",
|
||||||
|
"role": "reviewer",
|
||||||
|
"strategy": "round_robin",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def _create_gateway_app(tmp_path: Path):
|
||||||
|
from rolemesh_gateway.main import create_app
|
||||||
|
|
||||||
|
return create_app(
|
||||||
|
_write_gateway_config(tmp_path / "models.yaml"),
|
||||||
|
registry_path=tmp_path / "registry.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _request(app, method: str, path: str, **kwargs) -> httpx.Response:
|
||||||
|
transport = httpx.ASGITransport(app=app)
|
||||||
|
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||||
|
return await client.request(method, path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_models_requires_client_auth(tmp_path):
|
||||||
|
app = _create_gateway_app(tmp_path)
|
||||||
|
unauthorized = asyncio.run(_request(app, "GET", "/v1/models"))
|
||||||
|
assert unauthorized.status_code == 401
|
||||||
|
|
||||||
|
authorized = asyncio.run(
|
||||||
|
_request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
|
||||||
|
)
|
||||||
|
assert authorized.status_code == 200
|
||||||
|
body = authorized.json()
|
||||||
|
assert {item["id"] for item in body["data"]} == {"writer", "reviewer"}
|
||||||
|
asyncio.run(app.state.upstream.close())
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completions_applies_defaults_for_proxy_model(tmp_path):
|
||||||
|
app = _create_gateway_app(tmp_path)
|
||||||
|
calls = {}
|
||||||
|
|
||||||
|
async def fake_chat(base_url, payload):
|
||||||
|
calls["base_url"] = base_url
|
||||||
|
calls["payload"] = payload
|
||||||
|
return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
|
||||||
|
|
||||||
|
app.state.upstream.chat_completions = fake_chat
|
||||||
|
|
||||||
|
response = asyncio.run(
|
||||||
|
_request(
|
||||||
|
app,
|
||||||
|
"POST",
|
||||||
|
"/v1/chat/completions",
|
||||||
|
headers={"x-api-key": "client-secret"},
|
||||||
|
json={
|
||||||
|
"model": "writer",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert calls["base_url"] == "http://127.0.0.1:8012"
|
||||||
|
assert calls["payload"]["temperature"] == 0.6
|
||||||
|
assert response.json()["choices"][0]["message"]["content"] == "ok"
|
||||||
|
asyncio.run(app.state.upstream.close())
|
||||||
|
|
||||||
|
|
||||||
|
def test_discovered_model_without_registered_node_returns_503(tmp_path):
|
||||||
|
app = _create_gateway_app(tmp_path)
|
||||||
|
response = asyncio.run(
|
||||||
|
_request(
|
||||||
|
app,
|
||||||
|
"POST",
|
||||||
|
"/v1/chat/completions",
|
||||||
|
headers={"x-api-key": "client-secret"},
|
||||||
|
json={
|
||||||
|
"model": "reviewer",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 503
|
||||||
|
assert response.json()["error"]["code"] == "no_upstream"
|
||||||
|
asyncio.run(app.state.upstream.close())
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_registration_and_heartbeat_update_registry_state(tmp_path):
|
||||||
|
app = _create_gateway_app(tmp_path)
|
||||||
|
register = asyncio.run(
|
||||||
|
_request(
|
||||||
|
app,
|
||||||
|
"POST",
|
||||||
|
"/v1/nodes/register",
|
||||||
|
headers={"x-rolemesh-node-key": "node-secret"},
|
||||||
|
json={
|
||||||
|
"node_id": "node-a",
|
||||||
|
"base_url": "http://127.0.0.1:9001",
|
||||||
|
"roles": ["reviewer"],
|
||||||
|
"meta": {"gpu": "test"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert register.status_code == 200
|
||||||
|
|
||||||
|
heartbeat = asyncio.run(
|
||||||
|
_request(
|
||||||
|
app,
|
||||||
|
"POST",
|
||||||
|
"/v1/nodes/heartbeat",
|
||||||
|
headers={"x-rolemesh-node-key": "node-secret"},
|
||||||
|
json={
|
||||||
|
"node_id": "node-a",
|
||||||
|
"timestamp": 123.0,
|
||||||
|
"status": {"healthy": True},
|
||||||
|
"metrics": [{"device": {"id": "gpu:0"}}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert heartbeat.status_code == 200
|
||||||
|
node = heartbeat.json()["node"]
|
||||||
|
assert node["status"]["healthy"] is True
|
||||||
|
assert node["status"]["timestamp"] == 123.0
|
||||||
|
assert node["status"]["metrics"] == [{"device": {"id": "gpu:0"}}]
|
||||||
|
asyncio.run(app.state.upstream.close())
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from rolemesh_node_agent.adapters.base import DeviceMetrics, DeviceRef
|
||||||
|
from rolemesh_node_agent.config import ModelEntry, NodeAgentConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _node_config(tmp_path: Path) -> NodeAgentConfig:
|
||||||
|
model_path = tmp_path / "model.gguf"
|
||||||
|
model_path.write_bytes(b"GGUF")
|
||||||
|
return NodeAgentConfig(
|
||||||
|
node_id="node-1",
|
||||||
|
model_roots=[tmp_path],
|
||||||
|
models=[ModelEntry(model_id="planner-gguf", path=model_path, roles=["planner"])],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _request(app, method: str, path: str, **kwargs) -> httpx.Response:
|
||||||
|
transport = httpx.ASGITransport(app=app)
|
||||||
|
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
|
||||||
|
return await client.request(method, path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_inventory_reports_models_metrics_and_discovered_gguf(tmp_path):
|
||||||
|
from rolemesh_node_agent.main import create_app
|
||||||
|
|
||||||
|
cfg = _node_config(tmp_path)
|
||||||
|
app = create_app(cfg)
|
||||||
|
|
||||||
|
async def fake_discover_devices():
|
||||||
|
return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")]
|
||||||
|
|
||||||
|
async def fake_get_metrics():
|
||||||
|
return [
|
||||||
|
DeviceMetrics(
|
||||||
|
device=DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
|
||||||
|
loaded_model_id="planner-gguf",
|
||||||
|
queue_depth=1,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
app.state.cuda.discover_devices = fake_discover_devices
|
||||||
|
app.state.cuda.get_metrics = fake_get_metrics
|
||||||
|
|
||||||
|
response = asyncio.run(_request(app, "GET", "/v1/node/inventory"))
|
||||||
|
|
||||||
|
body = response.json()
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert body["models"][0]["model_id"] == "planner-gguf"
|
||||||
|
assert body["metrics"][0]["loaded_model_id"] == "planner-gguf"
|
||||||
|
assert body["discovered_gguf"][0]["name"] == "model.gguf"
|
||||||
|
asyncio.run(app.state.http.aclose())
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completions_routes_to_local_server_and_streams(tmp_path):
|
||||||
|
from rolemesh_node_agent.main import create_app
|
||||||
|
|
||||||
|
cfg = _node_config(tmp_path)
|
||||||
|
app = create_app(cfg)
|
||||||
|
calls = {}
|
||||||
|
|
||||||
|
async def fake_discover_devices():
|
||||||
|
return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")]
|
||||||
|
|
||||||
|
async def fake_ensure_server(device, *, model_path, model_id, server_args):
|
||||||
|
calls["device"] = device.id
|
||||||
|
calls["model_path"] = model_path
|
||||||
|
calls["model_id"] = model_id
|
||||||
|
return "http://127.0.0.1:9100"
|
||||||
|
|
||||||
|
async def fake_chat(base_url, payload):
|
||||||
|
calls["base_url"] = base_url
|
||||||
|
calls["payload"] = payload
|
||||||
|
return {"id": "node-cmpl", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
|
||||||
|
|
||||||
|
async def fake_stream(base_url, payload):
|
||||||
|
calls["stream_base_url"] = base_url
|
||||||
|
calls["stream_payload"] = payload
|
||||||
|
yield b"data: first\n\n"
|
||||||
|
yield b"data: [DONE]\n\n"
|
||||||
|
|
||||||
|
app.state.cuda.discover_devices = fake_discover_devices
|
||||||
|
app.state.cuda.ensure_server = fake_ensure_server
|
||||||
|
app.state.upstream.chat_completions = fake_chat
|
||||||
|
app.state.upstream.stream_chat_completions = fake_stream
|
||||||
|
|
||||||
|
response = asyncio.run(
|
||||||
|
_request(
|
||||||
|
app,
|
||||||
|
"POST",
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "planner-gguf",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
stream_response = asyncio.run(
|
||||||
|
_request(
|
||||||
|
app,
|
||||||
|
"POST",
|
||||||
|
"/v1/chat/completions",
|
||||||
|
json={
|
||||||
|
"model": "planner-gguf",
|
||||||
|
"stream": True,
|
||||||
|
"messages": [{"role": "user", "content": "hello"}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["choices"][0]["message"]["content"] == "ok"
|
||||||
|
assert calls["device"] == "gpu:0"
|
||||||
|
assert calls["base_url"] == "http://127.0.0.1:9100"
|
||||||
|
assert "data: first" in stream_response.text
|
||||||
|
assert calls["stream_base_url"] == "http://127.0.0.1:9100"
|
||||||
|
asyncio.run(app.state.http.aclose())
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_persists_round_robin_and_heartbeat_state(tmp_path):
|
||||||
|
persist_path = tmp_path / "registry.json"
|
||||||
|
registry = Registry(persist_path=persist_path)
|
||||||
|
|
||||||
|
registry.register(
|
||||||
|
NodeRegistration(
|
||||||
|
node_id="node-a",
|
||||||
|
base_url="http://127.0.0.1:9001",
|
||||||
|
roles=["reviewer"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
registry.register(
|
||||||
|
NodeRegistration(
|
||||||
|
node_id="node-b",
|
||||||
|
base_url="http://127.0.0.1:9002",
|
||||||
|
roles=["reviewer"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert registry.pick_node_for_role("reviewer").node_id == "node-a"
|
||||||
|
assert registry.pick_node_for_role("reviewer").node_id == "node-b"
|
||||||
|
|
||||||
|
node = registry.heartbeat(
|
||||||
|
NodeHeartbeat(
|
||||||
|
node_id="node-a",
|
||||||
|
timestamp=456.0,
|
||||||
|
status={"healthy": True},
|
||||||
|
metrics=[{"queue_depth": 1}],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert node is not None
|
||||||
|
assert node.status["metrics"] == [{"queue_depth": 1}]
|
||||||
|
|
||||||
|
reloaded = Registry(persist_path=persist_path)
|
||||||
|
assert reloaded.pick_node_for_role("reviewer").node_id == "node-a"
|
||||||
|
assert reloaded.list_nodes()[0].status["timestamp"] == 456.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_heartbeat_returns_none_for_unknown_node():
|
||||||
|
registry = Registry()
|
||||||
|
assert registry.heartbeat(NodeHeartbeat(node_id="missing")) is None
|
||||||
Loading…
Reference in New Issue