Addressing reliability issues surfaced by Didactopus use.

This commit is contained in:
welsberr 2026-03-16 12:07:34 -04:00
parent 5befa6d7f6
commit 79923983a0
15 changed files with 820 additions and 87 deletions

View File

@ -19,6 +19,26 @@ deployments where **different machines host different models** (e.g., GPU box fo
- Robust proxying with **explicit httpx timeouts** (no “hang forever”)
- Structured logging with request IDs
## Roles Are Project-Defined
The role names in this repository are examples, not a fixed taxonomy.
- `planner`, `writer`, `coder`, and `reviewer` are only sample aliases
- You can add, remove, or rename roles per project
- A role is simply the `model` alias clients send to RoleMesh Gateway
- Each role can point at any OpenAI-compatible backend that fits that project's workflow
Examples of project-specific roles:
- `researcher`
- `summarizer`
- `tool-user`
- `swe-backend`
- `swe-frontend`
- `test-writer`
- `security-reviewer`
If your workflow changes, update the `models:` section in config rather than treating the example roles as required.
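A minimal sketch of how a client selects one of these roles (assuming the Quick Start gateway on `127.0.0.1:8000` and a configured `researcher` alias; add an API key header if client auth is enabled):

```bash
curl -sS -X POST http://127.0.0.1:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{
        "model": "researcher",
        "messages": [{"role": "user", "content": "Survey recent work on this topic."}]
      }'
```

The only coupling between client and gateway is the alias string in `model`.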
## Quick Start
This is the fastest path to a working local setup.
@ -68,6 +88,8 @@ models:
Save that as `configs/models.yaml`.
You are not limited to `planner` and `writer`. Those are just placeholders for whatever roles your project needs.
### 4. Run the gateway
```bash
@ -166,6 +188,21 @@ This repository is a **preliminary scaffold**:
- Gateway proxying has been exercised live with Ollama and `llamafile`.
- Node-agent managed inference has been exercised live with `llama-server` on CUDA hardware.
## Availability Semantics
RoleMesh Gateway now distinguishes between configured aliases and currently usable aliases.
- `GET /v1/models` advertises only aliases whose upstreams are reachable right now
- Unavailable aliases are reported under `rolemesh.unavailable_models`
- `GET /ready` returns `200` only when the configured `default_model` is currently usable
- `GET /health` remains a lightweight process health check and does not probe upstreams
- Discovered nodes are removed from routing once they become stale
This makes the API surface more truthful for clients that rely on the advertised role list.
By default, registered nodes become stale after `30` seconds without a fresh heartbeat or registration event.
You can change that with `ROLE_MESH_NODE_STALE_AFTER_S`.
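As a sketch (assuming a local gateway on `127.0.0.1:8000`, `jq` installed, and an API key header added if client auth is enabled), you can compare advertised and unavailable aliases directly:

```bash
# Aliases with reachable upstreams vs. configured-but-unavailable aliases.
curl -sS http://127.0.0.1:8000/v1/models \
  | jq '{available: [.data[].id],
         unavailable: [.rolemesh.unavailable_models[].alias]}'
```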
## License
MIT. See `LICENSE`.

View File

@ -17,6 +17,8 @@ auth:
# Models may be:
# - type: proxy (static URL to an OpenAI-compatible upstream)
# - type: discovered (resolved from registered nodes by role)
# The names under "models" are project-defined role aliases, not a fixed built-in list.
# Rename or replace planner/writer/coder/reviewer with whatever your workflow needs.
models:
planner:
type: proxy

View File

@ -36,6 +36,6 @@ This is deliberately small so you can swap it out later for something stronger:
- Auth is optional and config-driven rather than enforced by default
- No TTL/health polling
- No automatic config reload
- Selection strategies are intentionally simple and limited to `round_robin` and `random`
These are tracked in `docs/DEPLOYMENT.md` as next steps.

View File

@ -24,6 +24,33 @@ models:
- `<alias>` is what clients pass as `model` in `/v1/chat/completions`.
- `openai_model_name` is the model id returned by `/v1/models` (usually same as alias).
## Roles are aliases, not a fixed list
RoleMesh Gateway does not reserve a built-in set of roles.
- The keys under `models:` are your project-specific role names
- Clients send those keys in the OpenAI `model` field
- You can rename or replace the sample roles entirely
- Different projects can use different role layouts with the same gateway
Example custom role set:
```yaml
models:
researcher:
type: proxy
openai_model_name: researcher
proxy_url: http://127.0.0.1:8011
summarizer:
type: proxy
openai_model_name: summarizer
proxy_url: http://127.0.0.1:8012
security-reviewer:
type: proxy
openai_model_name: security-reviewer
proxy_url: http://127.0.0.1:8013
```
## Proxy models
Route to a fixed upstream (any host reachable from the gateway):
@ -57,6 +84,10 @@ models:
strategy: round_robin
```
Supported discovered-node strategies:
- `round_robin`: rotate requests across fresh matching nodes
- `random`: choose a fresh matching node at random for each request
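For example, a discovered alias that spreads load randomly across fresh nodes (a sketch; `reviewer-any` is an arbitrary alias name):

```yaml
models:
  reviewer-any:
    type: discovered
    openai_model_name: reviewer-any
    role: reviewer
    strategy: random
```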
### Registering nodes
Nodes register to `POST /v1/nodes/register`:
@ -78,6 +109,21 @@ Supported headers:
- Clients: `Authorization: Bearer <key>` or `X-Api-Key: <key>`
- Nodes: `Authorization: Bearer <node_key>` or `X-RoleMesh-Node-Key: <node_key>`
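Equivalent curl sketches (assuming keys `client-secret` and `node-secret` are configured under `auth:`):

```bash
# Client calls: either header form works.
curl -sS http://127.0.0.1:8000/v1/models -H 'Authorization: Bearer client-secret'
curl -sS http://127.0.0.1:8000/v1/models -H 'X-Api-Key: client-secret'

# Node registration with the dedicated node-key header.
curl -sS -X POST http://127.0.0.1:8000/v1/nodes/register \
  -H 'X-RoleMesh-Node-Key: node-secret' \
  -H 'Content-Type: application/json' \
  -d '{"node_id": "node-a", "base_url": "http://127.0.0.1:9001", "roles": ["reviewer"]}'
```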
## Availability behavior
Configured aliases are not automatically assumed healthy.
- `GET /v1/models` probes configured upstreams and only returns aliases that are currently reachable
- Unavailable aliases are included separately in `rolemesh.unavailable_models`
- `GET /ready` returns success only when the configured `default_model` is currently usable
- Discovered nodes are only considered routable while they are fresh
- Gateway metadata marks stale registered nodes so operators can distinguish them from healthy nodes
This is especially important when a config contains multiple optional roles but only some backends are up.
Discovered-node freshness is controlled by the `ROLE_MESH_NODE_STALE_AFTER_S` environment variable (default: `30` seconds).
## Quick example
```yaml

View File

@ -53,6 +53,22 @@ curl -sS -X POST http://127.0.0.1:8000/v1/chat/completions \
}'
```
### Readiness and model advertisement
- `GET /health` only checks that the gateway process is up
- `GET /ready` checks whether the configured default route is actually usable
- `GET /v1/models` only lists aliases with a currently reachable upstream
- Aliases that are configured but currently unavailable are reported in `rolemesh.unavailable_models`
- Discovered nodes that have not checked in recently are marked stale and excluded from routing
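A deployment script can gate on readiness with plain curl (a sketch; adjust host and port to your setup):

```bash
# Block until the default role is actually routable (HTTP 200 from /ready).
until [ "$(curl -sS -o /dev/null -w '%{http_code}' http://127.0.0.1:8000/ready)" = "200" ]; do
  echo "gateway not ready yet"
  sleep 2
done
```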
### Stale node timeout
Registered nodes age out of discovered-role routing after a heartbeat timeout.
- Default timeout: `30` seconds
- Configure it with `ROLE_MESH_NODE_STALE_AFTER_S`
- Stale nodes remain visible to operators in the gateway metadata, but they no longer receive traffic
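A sketch of tuning the timeout (the variable is read at gateway startup; how you launch the gateway is deployment-specific):

```bash
# Tolerate roughly three missed 15 s heartbeats before a node goes stale.
export ROLE_MESH_NODE_STALE_AFTER_S=45
# A value of 0 or less disables staleness checks entirely (nodes never age out).
```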
## Network binding and exposure (Step 2 hardening)
**Safe by default:** the gateway and node-agent CLIs bind to `127.0.0.1` (localhost) unless you explicitly configure otherwise.
@ -101,6 +117,7 @@ This scaffold supports two patterns.
- `http://10.0.0.13:8011` (planner)
- Update `proxy_url` entries to those LAN URLs, **or** use discovery:
- Set model to `type: discovered` with `role: writer`, etc.
- Choose `strategy: round_robin` or `strategy: random` per discovered alias
- Each host registers itself with the gateway.
### Minimal registration call
@ -118,5 +135,5 @@ curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \
- Configure API keys for:
- inference endpoints via `auth.client_api_keys`
- node registration and heartbeat via `auth.node_api_keys`
- Tune `ROLE_MESH_NODE_STALE_AFTER_S` for your heartbeat interval and failure tolerance
- Consider mTLS if registration happens over untrusted networks

View File

@ -16,6 +16,12 @@ Model switching is handled by **restart** in the scaffold.
The agent now waits for the replacement `llama-server` to report readiness before proxying the first request.
If startup or switching takes too long, the request fails with a `503` instead of passing through a transient upstream
"Loading model" error.
Device selection is still simple, but it is no longer hard-coded to the first GPU:
- First preference: a device that already has the requested model loaded
- Otherwise: the device with the most free VRAM and least queue pressure
- Requests are serialized per device
- Each device has a bounded pending-request limit for backpressure
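To see the signals the selector works from (a sketch; the node agent's port and the `.metrics` response key are assumptions here):

```bash
# Per-device metrics with live queue_depth / in_flight_jobs merged in.
curl -sS http://127.0.0.1:8011/v1/node/inventory \
  | jq '.metrics[] | {device: .device.id, loaded_model_id, queue_depth, in_flight_jobs}'
```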
## Backends
@ -41,10 +47,12 @@ These config knobs control how long the node agent waits for a managed `llama-server`
```yaml
llama_server_startup_timeout_s: 30.0
llama_server_probe_interval_s: 0.5
max_pending_requests_per_device: 2
```
- `llama_server_startup_timeout_s`: maximum time to wait for a newly started or switched model
- `llama_server_probe_interval_s`: polling interval for readiness checks
- `max_pending_requests_per_device`: maximum in-flight plus queued requests allowed per device before new requests are rejected
The readiness probe checks the managed server's local `GET /health` and `GET /v1/models` endpoints.
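A rough way to see the per-device backpressure limit in action (a sketch; the port and `planner-gguf` alias are assumptions, and this presumes a single device so all three requests land on one queue):

```bash
# With the default max_pending_requests_per_device of 2, the third
# concurrent request is rejected with HTTP 429 (error code "queue_full").
for i in 1 2 3; do
  curl -sS -o /dev/null -w "request $i -> %{http_code}\n" \
    -X POST http://127.0.0.1:8011/v1/chat/completions \
    -H 'Content-Type: application/json' \
    -d '{"model": "planner-gguf", "messages": [{"role": "user", "content": "hi"}]}' &
done
wait
```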

View File

@ -1,9 +1,9 @@
from __future__ import annotations
import os
import asyncio
import time
import uuid
from typing import Any, Dict, List
from fastapi import APIRouter, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse
@ -30,28 +30,107 @@ def _openai_error(message: str, code: str = "upstream_error", status_code: int =
)
async def _model_status(
alias: str,
entry: ProxyModel | DiscoveredModel,
registry: Registry,
upstream: UpstreamClient,
) -> Dict[str, Any]:
if isinstance(entry, ProxyModel):
base_url = str(entry.proxy_url).rstrip("/")
try:
await upstream.get_models(base_url)
return {"alias": alias, "available": True, "base_url": base_url}
except UpstreamError as exc:
return {"alias": alias, "available": False, "base_url": base_url, "error": str(exc)}
matching_nodes = registry.nodes_for_role(entry.role, include_stale=False)
if not matching_nodes:
stale_nodes = registry.nodes_for_role(entry.role, include_stale=True)
error = "no_registered_nodes"
if stale_nodes:
error = "no_fresh_registered_nodes"
return {"alias": alias, "available": False, "role": entry.role, "error": error}
for node in matching_nodes:
base_url = str(node.base_url).rstrip("/")
try:
await upstream.get_models(base_url)
return {
"alias": alias,
"available": True,
"role": entry.role,
"base_url": base_url,
"node_id": node.node_id,
}
except UpstreamError:
continue
return {
"alias": alias,
"available": False,
"role": entry.role,
"error": "no_healthy_registered_nodes",
}
async def _collect_model_statuses(
cfg: Config,
registry: Registry,
upstream: UpstreamClient,
) -> Dict[str, Dict[str, Any]]:
statuses = await asyncio.gather(
*[
_model_status(alias, entry, registry, upstream)
for alias, entry in cfg.models.items()
]
)
return {status["alias"]: status for status in statuses}
@router.get("/v1/models")
async def list_models(request: Request, _=Depends(require_client_auth)) -> Dict[str, Any]:
cfg: Config = request.app.state.cfg
registry: Registry = request.app.state.registry
upstream: UpstreamClient = request.app.state.upstream
statuses = await _collect_model_statuses(cfg, registry, upstream)
data = []
for name, entry in cfg.models.items():
status = statuses[name]
if not status["available"]:
continue
item = {
"id": entry.openai_model_name,
"object": "model",
"owned_by": "local",
}
if isinstance(entry, ProxyModel):
item["rolemesh"] = {"type": "proxy", "proxy_url": str(entry.proxy_url)}
item["rolemesh"] = {"type": "proxy", "proxy_url": str(entry.proxy_url), "available": True}
else:
item["rolemesh"] = {"type": "discovered", "role": entry.role, "strategy": entry.strategy}
item["rolemesh"] = {
"type": "discovered",
"role": entry.role,
"strategy": entry.strategy,
"available": True,
}
data.append(item)
# Expose currently registered nodes (informational)
nodes = [
n.model_dump(mode="json") | {"stale": registry.is_stale(n)}
for n in registry.list_nodes(include_stale=True)
]
unavailable = [status for status in statuses.values() if not status["available"]]
return {"object": "list", "data": data, "rolemesh": {"registered_nodes": nodes}}
return {
"object": "list",
"data": data,
"rolemesh": {
"registered_nodes": nodes,
"unavailable_models": unavailable,
},
}
@router.post("/v1/chat/completions")
@ -81,7 +160,7 @@ async def chat_completions(request: Request, _=Depends(require_client_auth)) ->
if isinstance(entry, ProxyModel):
base_url = str(entry.proxy_url).rstrip("/")
elif isinstance(entry, DiscoveredModel):
node = registry.pick_node_for_role(entry.role, strategy=entry.strategy)
if not node:
return _openai_error(
f"No registered nodes available for role '{entry.role}'. "
@ -125,12 +204,33 @@ async def health() -> Dict[str, str]:
@router.get("/ready")
async def ready(request: Request) -> JSONResponse:
"""
Readiness checks for presence of config and (optionally) upstreams.
For now: verifies config loads and returns 200.
"""
cfg: Config = request.app.state.cfg
return JSONResponse(status_code=200, content={"status": "ready", "default_model": cfg.default_model})
registry: Registry = request.app.state.registry
upstream: UpstreamClient = request.app.state.upstream
statuses = await _collect_model_statuses(cfg, registry, upstream)
default_status = statuses.get(cfg.default_model)
available_aliases = [alias for alias, status in statuses.items() if status["available"]]
if default_status and default_status["available"]:
return JSONResponse(
status_code=200,
content={
"status": "ready",
"default_model": cfg.default_model,
"available_models": available_aliases,
"unavailable_models": [s for s in statuses.values() if not s["available"]],
},
)
return JSONResponse(
status_code=503,
content={
"status": "not_ready",
"default_model": cfg.default_model,
"available_models": available_aliases,
"unavailable_models": [s for s in statuses.values() if not s["available"]],
},
)
@router.post("/v1/nodes/register")

View File

@ -27,12 +27,19 @@ def _get_logger() -> logging.Logger:
def create_app(
config_path: str | Path | None = None,
registry_path: str | Path | None = None,
registry_stale_after_s: float | None = None,
) -> FastAPI:
cfg_path = config_path or os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml")
cfg = load_config(cfg_path)
resolved_registry_path = registry_path or os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json")
resolved_registry_stale_after_s = registry_stale_after_s
if resolved_registry_stale_after_s is None:
resolved_registry_stale_after_s = float(os.environ.get("ROLE_MESH_NODE_STALE_AFTER_S", "30"))
registry = Registry(
persist_path=Path(resolved_registry_path),
stale_after_s=resolved_registry_stale_after_s,
)
upstream = UpstreamClient(
connect_timeout_s=float(os.environ.get("ROLE_MESH_CONNECT_TIMEOUT_S", "10")),

View File

@ -1,9 +1,10 @@
from __future__ import annotations
import json
import random
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, HttpUrl
@ -44,10 +45,11 @@ class Registry:
- TTL + health checks
"""
def __init__(self, persist_path: Optional[Path] = None, stale_after_s: float = 30.0) -> None:
self._nodes: Dict[str, RegisteredNode] = {}
self._rr_counters: Dict[str, int] = {}
self._persist_path = persist_path
self._stale_after_s = stale_after_s
if self._persist_path:
self._load()
@ -100,13 +102,29 @@ class Registry:
self._save()
return n
def is_stale(self, node: RegisteredNode, now: Optional[float] = None) -> bool:
if self._stale_after_s <= 0:
return False
now = time.time() if now is None else now
return (now - node.last_seen) > self._stale_after_s
def list_nodes(self, *, include_stale: bool = True) -> List[RegisteredNode]:
nodes = list(self._nodes.values())
if include_stale:
return nodes
now = time.time()
return [node for node in nodes if not self.is_stale(node, now=now)]
def nodes_for_role(self, role: str, *, include_stale: bool = False) -> List[RegisteredNode]:
nodes = self.list_nodes(include_stale=include_stale)
return [node for node in nodes if role in node.roles]
def pick_node_for_role(self, role: str, strategy: str = "round_robin") -> Optional[RegisteredNode]:
candidates = self.nodes_for_role(role, include_stale=False)
if not candidates:
return None
if strategy == "random":
return random.choice(candidates)
idx = self._rr_counters.get(role, 0) % len(candidates)
self._rr_counters[role] = idx + 1
self._save()

View File

@ -40,3 +40,4 @@ class NodeAgentConfig(BaseModel):
llama_server_bin: str = "llama-server"
llama_server_startup_timeout_s: float = 30.0
llama_server_probe_interval_s: float = 0.5
max_pending_requests_per_device: int = 2

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager, suppress
import time
from typing import Any, Dict, Iterable
import httpx
from fastapi import FastAPI, Request
@ -11,9 +11,10 @@ from fastapi.responses import JSONResponse, StreamingResponse
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError # reuse gateway client
from .adapters.cuda import CudaAdapter, ServerStartupError
from .adapters.base import DeviceMetrics, DeviceRef
from .config import NodeAgentConfig
from .inventory import discover_gguf_models
from .scheduler import AdmissionError, DeviceQueue
def _error(message: str, code: str = "node_error", status_code: int = 500) -> JSONResponse:
@ -23,6 +24,54 @@ def _error(message: str, code: str = "node_error", status_code: int = 500) -> JS
)
def _merge_scheduler_metrics(
    metrics: Iterable[DeviceMetrics],
    queues: Dict[str, DeviceQueue],
) -> list[DeviceMetrics]:
    # Overlay live scheduler state (queued and in-flight counts) onto the
    # adapter-reported device metrics.
    out: list[DeviceMetrics] = []
    for metric in metrics:
        device_queue = queues.get(metric.device.id)
        if device_queue is not None:
            snap = device_queue.snapshot()
            metric.queue_depth = snap.queue_depth
            metric.in_flight_jobs = snap.in_flight
        out.append(metric)
    return out
def _select_device(
devices: Iterable[DeviceRef],
metrics: Iterable[DeviceMetrics],
*,
model_id: str,
) -> DeviceRef | None:
device_list = list(devices)
if not device_list:
return None
metrics_by_id = {metric.device.id: metric for metric in metrics}
# Reuse a device that already has the requested model loaded.
for device in device_list:
metric = metrics_by_id.get(device.id)
if metric and metric.loaded_model_id == model_id:
return device
    # Prefer more free VRAM, then less queue pressure, then lower GPU
    # utilization; the device id serves as a deterministic tiebreaker.
    def score(device: DeviceRef) -> tuple[float, float, float, str]:
metric = metrics_by_id.get(device.id)
if metric is None:
return (float("-inf"), 0.0, 0.0, device.id)
free_mem = 0.0
if metric.mem_total_gb is not None and metric.mem_used_gb is not None:
free_mem = metric.mem_total_gb - metric.mem_used_gb
queue_penalty = float(metric.queue_depth + metric.in_flight_jobs)
util_penalty = float(metric.utilization_pct or 0.0)
return (free_mem, -queue_penalty, -util_penalty, device.id)
return max(device_list, key=score)
def create_app(cfg: NodeAgentConfig) -> FastAPI:
http = httpx.AsyncClient(
timeout=httpx.Timeout(connect=5.0, read=3600.0, write=30.0, pool=30.0)
@ -57,6 +106,7 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:
# Adapters
app.state.cuda = cuda
app.state.device_queues: Dict[str, DeviceQueue] = {}
# State: role -> (device, model)
# This is intentionally simple for the scaffold: pick first GPU and first matching model.
@ -69,7 +119,9 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:
@app.get("/v1/node/inventory")
async def inventory() -> Dict[str, Any]:
devices = await app.state.cuda.discover_devices()
raw_metrics = await app.state.cuda.get_metrics()
merged_metrics = _merge_scheduler_metrics(raw_metrics, app.state.device_queues)
metrics = [m.__dict__ | {"device": m.device.__dict__} for m in merged_metrics]
models = [{"model_id": m.model_id, "path": str(m.path), "roles": m.roles} for m in cfg.models]
discovered = discover_gguf_models(cfg.model_roots)
return {
@ -104,8 +156,20 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:
devices = await app.state.cuda.discover_devices()
if not devices:
return _error("No CUDA GPUs discovered on this node.", code="no_device", status_code=503)
metrics = _merge_scheduler_metrics(await app.state.cuda.get_metrics(), app.state.device_queues)
device = _select_device(devices, metrics, model_id=model_entry.model_id)
if device is None:
return _error("No eligible CUDA GPUs discovered on this node.", code="no_device", status_code=503)
queue = app.state.device_queues.setdefault(
device.id,
DeviceQueue(max_pending=cfg.max_pending_requests_per_device),
)
try:
lease = await queue.acquire()
except AdmissionError:
return _error("Device queue is full.", code="queue_full", status_code=429)
try:
try:
base_url = await app.state.cuda.ensure_server(
device,
@ -123,11 +187,17 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:
return JSONResponse(status_code=200, content=out)
else:
async def gen():
try:
async for chunk in upstream.stream_chat_completions(base_url, body):
yield chunk
finally:
await lease.release()
return StreamingResponse(gen(), media_type="text/event-stream")
except UpstreamError as e:
return _error(str(e), code="upstream_error", status_code=502)
finally:
if not stream:
await lease.release()
return app

View File

@ -1,59 +1,66 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Any, Dict
class AdmissionError(RuntimeError):
pass
@dataclass
class Job:
job_id: str
submitted_at: float
base_url: str
body: Dict[str, Any]
stream: bool
@dataclass
class QueueSnapshot:
queue_depth: int
in_flight: int
class DeviceLease:
def __init__(self, queue: "DeviceQueue") -> None:
self._queue = queue
self._released = False
async def release(self) -> None:
if self._released:
return
self._released = True
await self._queue._release()
async def __aenter__(self) -> "DeviceLease":
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
await self.release()
class DeviceQueue:
"""Simple FIFO queue per device; one worker per device for latency predictability."""
"""Serialize work per device and enforce a bounded number of pending requests."""
def __init__(self, *, max_pending: int = 2) -> None:
self._max_pending = max_pending
self._lock = asyncio.Lock()
self._guard = asyncio.Lock()
self._queued = 0
self._in_flight = 0
async def acquire(self) -> DeviceLease:
async with self._guard:
pending = self._queued + self._in_flight
if self._max_pending > 0 and pending >= self._max_pending:
raise AdmissionError("device queue is full")
self._queued += 1
        # Block until the previous lease on this device is released;
        # the lock serializes actual work per device.
        await self._lock.acquire()
async with self._guard:
self._queued -= 1
self._in_flight += 1
return DeviceLease(self)
async def _release(self) -> None:
async with self._guard:
self._in_flight -= 1
self._lock.release()
def snapshot(self) -> QueueSnapshot:
return QueueSnapshot(queue_depth=self._queued, in_flight=self._in_flight)

View File

@ -30,6 +30,12 @@ def _write_gateway_config(path: Path) -> Path:
"role": "reviewer",
"strategy": "round_robin",
},
"random-reviewer": {
"type": "discovered",
"openai_model_name": "random-reviewer",
"role": "reviewer",
"strategy": "random",
},
},
}
)
@ -43,6 +49,7 @@ def _create_gateway_app(tmp_path: Path):
return create_app(
_write_gateway_config(tmp_path / "models.yaml"),
registry_path=tmp_path / "registry.json",
registry_stale_after_s=30.0,
)
@ -54,6 +61,10 @@ async def _request(app, method: str, path: str, **kwargs) -> httpx.Response:
def test_models_requires_client_auth(tmp_path):
app = _create_gateway_app(tmp_path)
async def fake_get_models(base_url):
return {"object": "list", "data": []}
app.state.upstream.get_models = fake_get_models
unauthorized = asyncio.run(_request(app, "GET", "/v1/models"))
assert unauthorized.status_code == 401
@ -62,7 +73,8 @@ def test_models_requires_client_auth(tmp_path):
)
assert authorized.status_code == 200
body = authorized.json()
assert {item["id"] for item in body["data"]} == {"writer", "reviewer"}
assert {item["id"] for item in body["data"]} == {"writer"}
assert {item["alias"] for item in body["rolemesh"]["unavailable_models"]} == {"reviewer", "random-reviewer"}
asyncio.run(app.state.upstream.close())
@ -156,3 +168,184 @@ def test_node_registration_and_heartbeat_update_registry_state(tmp_path):
assert node["status"]["timestamp"] == 123.0
assert node["status"]["metrics"] == [{"device": {"id": "gpu:0"}}]
asyncio.run(app.state.upstream.close())
def test_models_filters_unhealthy_aliases_and_reports_them(tmp_path):
app = _create_gateway_app(tmp_path)
    from rolemesh_gateway.upstream import UpstreamError

    async def fake_get_models(base_url):
        if base_url == "http://127.0.0.1:8012":
            raise UpstreamError("Upstream unreachable: boom")
        return {"object": "list", "data": []}

    app.state.upstream.get_models = fake_get_models
response = asyncio.run(
_request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
)
assert response.status_code == 200
body = response.json()
assert body["data"] == []
assert {item["alias"] for item in body["rolemesh"]["unavailable_models"]} == {"writer", "reviewer", "random-reviewer"}
asyncio.run(app.state.upstream.close())
def test_ready_returns_503_when_default_model_is_unavailable(tmp_path):
app = _create_gateway_app(tmp_path)
from rolemesh_gateway.upstream import UpstreamError
async def fake_get_models(base_url):
raise UpstreamError("Upstream unreachable: boom")
app.state.upstream.get_models = fake_get_models
response = asyncio.run(_request(app, "GET", "/ready"))
assert response.status_code == 503
body = response.json()
assert body["status"] == "not_ready"
assert body["default_model"] == "writer"
asyncio.run(app.state.upstream.close())
def test_ready_returns_200_when_default_model_is_available(tmp_path):
app = _create_gateway_app(tmp_path)
    from rolemesh_gateway.upstream import UpstreamError

    async def fake_get_models(base_url):
        if base_url == "http://127.0.0.1:8012":
            return {"object": "list", "data": []}
        raise UpstreamError("no reviewer nodes")

    app.state.upstream.get_models = fake_get_models
response = asyncio.run(_request(app, "GET", "/ready"))
assert response.status_code == 200
body = response.json()
assert body["status"] == "ready"
assert body["default_model"] == "writer"
assert body["available_models"] == ["writer"]
asyncio.run(app.state.upstream.close())
def test_stale_discovered_nodes_are_not_advertised_as_available(tmp_path):
app = _create_gateway_app(tmp_path)
register = asyncio.run(
_request(
app,
"POST",
"/v1/nodes/register",
headers={"x-rolemesh-node-key": "node-secret"},
json={
"node_id": "node-a",
"base_url": "http://127.0.0.1:9001",
"roles": ["reviewer"],
},
)
)
assert register.status_code == 200
node = app.state.registry.list_nodes(include_stale=True)[0]
node.last_seen = node.last_seen - 120
async def fake_get_models(base_url):
return {"object": "list", "data": []}
app.state.upstream.get_models = fake_get_models
models = asyncio.run(
_request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
)
ready = asyncio.run(_request(app, "GET", "/ready"))
models_body = models.json()
assert {item["id"] for item in models_body["data"]} == {"writer"}
assert models_body["rolemesh"]["registered_nodes"][0]["stale"] is True
unavailable = {item["alias"]: item for item in models_body["rolemesh"]["unavailable_models"]}
assert unavailable["reviewer"]["error"] == "no_fresh_registered_nodes"
assert unavailable["random-reviewer"]["error"] == "no_fresh_registered_nodes"
assert ready.status_code == 200
asyncio.run(app.state.upstream.close())
def test_discovered_model_uses_random_strategy_when_configured(tmp_path):
app = _create_gateway_app(tmp_path)
calls = {}
register_a = asyncio.run(
_request(
app,
"POST",
"/v1/nodes/register",
headers={"x-rolemesh-node-key": "node-secret"},
json={
"node_id": "node-a",
"base_url": "http://127.0.0.1:9001",
"roles": ["reviewer"],
},
)
)
register_b = asyncio.run(
_request(
app,
"POST",
"/v1/nodes/register",
headers={"x-rolemesh-node-key": "node-secret"},
json={
"node_id": "node-b",
"base_url": "http://127.0.0.1:9002",
"roles": ["reviewer"],
},
)
)
assert register_a.status_code == 200
assert register_b.status_code == 200
import random
original_choice = random.choice
random.choice = lambda candidates: candidates[-1]
async def fake_chat(base_url, payload):
calls["base_url"] = base_url
return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
app.state.upstream.chat_completions = fake_chat
    async def fake_get_models(base_url):
        return {"object": "list", "data": []}

    app.state.upstream.get_models = fake_get_models
try:
response = asyncio.run(
_request(
app,
"POST",
"/v1/chat/completions",
headers={"x-api-key": "client-secret"},
json={
"model": "random-reviewer",
"messages": [{"role": "user", "content": "hello"}],
},
)
)
finally:
random.choice = original_choice
assert response.status_code == 200
assert calls["base_url"] == "http://127.0.0.1:9002"
asyncio.run(app.state.upstream.close())

View File

@ -7,6 +7,8 @@ import httpx
from rolemesh_node_agent.adapters.base import DeviceMetrics, DeviceRef
from rolemesh_node_agent.config import ModelEntry, NodeAgentConfig
from rolemesh_node_agent.main import _merge_scheduler_metrics, _select_device
from rolemesh_node_agent.scheduler import AdmissionError, DeviceQueue
def _node_config(tmp_path: Path) -> NodeAgentConfig:
@ -152,3 +154,174 @@ def test_chat_completions_returns_503_when_server_startup_fails(tmp_path):
assert response.status_code == 503
assert response.json()["error"]["code"] == "server_startup_error"
asyncio.run(app.state.http.aclose())
def test_select_device_prefers_already_loaded_model():
devices = [
DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
DeviceRef(kind="gpu", backend="cuda", id="gpu:1"),
]
metrics = [
DeviceMetrics(
device=devices[0],
loaded_model_id="other-model",
mem_total_gb=24.0,
mem_used_gb=4.0,
),
DeviceMetrics(
device=devices[1],
loaded_model_id="target-model",
mem_total_gb=24.0,
mem_used_gb=20.0,
),
]
picked = _select_device(devices, metrics, model_id="target-model")
assert picked == devices[1]
def test_select_device_prefers_more_free_memory_and_lower_pressure():
devices = [
DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
DeviceRef(kind="gpu", backend="cuda", id="gpu:1"),
]
metrics = [
DeviceMetrics(
device=devices[0],
mem_total_gb=24.0,
mem_used_gb=18.0,
queue_depth=0,
in_flight_jobs=0,
utilization_pct=5.0,
),
DeviceMetrics(
device=devices[1],
mem_total_gb=24.0,
mem_used_gb=4.0,
queue_depth=0,
in_flight_jobs=0,
utilization_pct=10.0,
),
]
picked = _select_device(devices, metrics, model_id="target-model")
assert picked == devices[1]
def test_chat_completions_uses_selected_device_not_first_device(tmp_path):
from rolemesh_node_agent.main import create_app
cfg = _node_config(tmp_path)
app = create_app(cfg)
calls = {}
devices = [
DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
DeviceRef(kind="gpu", backend="cuda", id="gpu:1"),
]
async def fake_discover_devices():
return devices
async def fake_get_metrics():
return [
DeviceMetrics(
device=devices[0],
mem_total_gb=24.0,
mem_used_gb=20.0,
),
DeviceMetrics(
device=devices[1],
mem_total_gb=24.0,
mem_used_gb=2.0,
),
]
async def fake_ensure_server(device, *, model_path, model_id, server_args):
calls["device"] = device.id
return "http://127.0.0.1:9100"
async def fake_chat(base_url, payload):
return {"id": "node-cmpl", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
app.state.cuda.discover_devices = fake_discover_devices
app.state.cuda.get_metrics = fake_get_metrics
app.state.cuda.ensure_server = fake_ensure_server
app.state.upstream.chat_completions = fake_chat
response = asyncio.run(
_request(
app,
"POST",
"/v1/chat/completions",
json={
"model": "planner-gguf",
"messages": [{"role": "user", "content": "hello"}],
},
)
)
assert response.status_code == 200
assert calls["device"] == "gpu:1"
asyncio.run(app.state.http.aclose())
def test_merge_scheduler_metrics_overlays_queue_state():
device = DeviceRef(kind="gpu", backend="cuda", id="gpu:0")
metric = DeviceMetrics(device=device)
queue = DeviceQueue(max_pending=2)
queue._queued = 1
queue._in_flight = 1
merged = _merge_scheduler_metrics([metric], {"gpu:0": queue})
assert merged[0].queue_depth == 1
assert merged[0].in_flight_jobs == 1
def test_device_queue_rejects_when_full():
queue = DeviceQueue(max_pending=1)
queue._in_flight = 1
try:
asyncio.run(queue.acquire())
except AdmissionError as exc:
assert "full" in str(exc)
else:
raise AssertionError("expected AdmissionError")
def test_chat_completions_returns_429_when_device_queue_is_full(tmp_path):
from rolemesh_node_agent.main import create_app
cfg = _node_config(tmp_path)
cfg.max_pending_requests_per_device = 1
app = create_app(cfg)
async def fake_discover_devices():
return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")]
async def fake_get_metrics():
return [DeviceMetrics(device=DeviceRef(kind="gpu", backend="cuda", id="gpu:0"))]
app.state.cuda.discover_devices = fake_discover_devices
app.state.cuda.get_metrics = fake_get_metrics
saturated = DeviceQueue(max_pending=1)
saturated._in_flight = 1
app.state.device_queues["gpu:0"] = saturated
response = asyncio.run(
_request(
app,
"POST",
"/v1/chat/completions",
json={
"model": "planner-gguf",
"messages": [{"role": "user", "content": "hello"}],
},
)
)
assert response.status_code == 429
assert response.json()["error"]["code"] == "queue_full"
asyncio.run(app.state.http.aclose())

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import random
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
@ -44,3 +46,55 @@ def test_registry_persists_round_robin_and_heartbeat_state(tmp_path):
def test_registry_heartbeat_returns_none_for_unknown_node():
registry = Registry()
assert registry.heartbeat(NodeHeartbeat(node_id="missing")) is None
def test_registry_filters_stale_nodes_from_routing(tmp_path):
persist_path = tmp_path / "registry.json"
registry = Registry(persist_path=persist_path, stale_after_s=0.01)
fresh = registry.register(
NodeRegistration(
node_id="fresh-node",
base_url="http://127.0.0.1:9001",
roles=["reviewer"],
)
)
stale = registry.register(
NodeRegistration(
node_id="stale-node",
base_url="http://127.0.0.1:9002",
roles=["reviewer"],
)
)
stale.last_seen = stale.last_seen - 60
registry._save()
assert registry.is_stale(stale) is True
assert registry.is_stale(fresh) is False
assert [node.node_id for node in registry.nodes_for_role("reviewer", include_stale=False)] == ["fresh-node"]
assert registry.pick_node_for_role("reviewer").node_id == "fresh-node"
def test_registry_supports_random_selection(monkeypatch):
registry = Registry()
registry.register(
NodeRegistration(
node_id="node-a",
base_url="http://127.0.0.1:9001",
roles=["reviewer"],
)
)
registry.register(
NodeRegistration(
node_id="node-b",
base_url="http://127.0.0.1:9002",
roles=["reviewer"],
)
)
monkeypatch.setattr(random, "choice", lambda candidates: candidates[-1])
picked = registry.pick_node_for_role("reviewer", strategy="random")
assert picked is not None
assert picked.node_id == "node-b"