diff --git a/README.md b/README.md
index be9cd87..a5fb22f 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,26 @@ deployments where **different machines host different models** (e.g., GPU box fo
 - Robust proxying with **explicit httpx timeouts** (no “hang forever”)
 - Structured logging with request IDs
 
+## Roles Are Project-Defined
+
+The role names in this repository are examples, not a fixed taxonomy.
+
+- `planner`, `writer`, `coder`, and `reviewer` are only sample aliases
+- you can add, remove, or rename roles per project
+- a role is simply the `model` alias clients send to RoleMesh Gateway
+- each role can point at any OpenAI-compatible backend that fits that project's workflow
+
+Examples of project-specific roles:
+- `researcher`
+- `summarizer`
+- `tool-user`
+- `swe-backend`
+- `swe-frontend`
+- `test-writer`
+- `security-reviewer`
+
+If your workflow changes, update the `models:` section in config rather than treating the example roles as required.
+
 ## Quick Start
 
 This is the fastest path to a working local setup.
@@ -68,6 +88,8 @@ models:
 
 Save that as `configs/models.yaml`.
 
+You are not limited to `planner` and `writer`. Those are just placeholders for whatever roles your project needs.
+
 ### 4. Run the gateway
 
 ```bash
@@ -166,6 +188,21 @@ This repository is a **preliminary scaffold**:
 - Gateway proxying has been exercised live with Ollama and `llamafile`.
 - Node-agent managed inference has been exercised live with `llama-server` on CUDA hardware.
 
+## Availability Semantics
+
+RoleMesh Gateway now distinguishes between configured aliases and currently usable aliases.
+
+- `GET /v1/models` advertises only aliases whose upstreams are reachable right now
+- unavailable aliases are reported under `rolemesh.unavailable_models`
+- `GET /ready` returns `200` only when the configured `default_model` is currently usable
+- `GET /health` remains a lightweight process health check and does not probe upstreams
+- discovered nodes are removed from routing once they become stale
+
+This makes the API surface more truthful for clients that rely on the advertised role list.
+
+By default, registered nodes become stale after `30` seconds without a fresh heartbeat or registration event.
+You can change that with `ROLE_MESH_NODE_STALE_AFTER_S`.
+
 ## License
 
 MIT. See `LICENSE`.
diff --git a/configs/models.example.yaml b/configs/models.example.yaml
index fcd4c48..7e67a89 100644
--- a/configs/models.example.yaml
+++ b/configs/models.example.yaml
@@ -17,6 +17,8 @@ auth:
 # Models may be:
 # - type: proxy (static URL to an OpenAI-compatible upstream)
 # - type: discovered (resolved from registered nodes by role)
+# The names under "models" are project-defined role aliases, not a fixed built-in list.
+# Rename or replace planner/writer/coder/reviewer with whatever your workflow needs.
 models:
   planner:
     type: proxy
diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md
index ac907ae..bed43ec 100644
--- a/docs/ARCHITECTURE.md
+++ b/docs/ARCHITECTURE.md
@@ -36,6 +36,6 @@ This is deliberately small so you can swap it out later for something stronger:
 - Auth is optional and config-driven rather than enforced by default
 - No TTL/health polling
 - No automatic config reload
-- Round-robin selection only for discovered nodes
+- Selection strategies are intentionally simple and limited to `round_robin` and `random`
 
 These are tracked in `docs/DEPLOYMENT.md` as next steps.
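As a companion to the availability semantics documented in the README hunk above, here is a minimal client-side sketch of consuming the new `/v1/models` payload. The gateway URL and API key are placeholders; the `data` and `rolemesh.unavailable_models` shapes follow the response built later in this diff.

```python
import httpx

# Placeholder values; substitute your gateway address and client API key.
GATEWAY = "http://127.0.0.1:8000"

resp = httpx.get(f"{GATEWAY}/v1/models", headers={"X-Api-Key": "client-secret"})
resp.raise_for_status()
body = resp.json()

# Role aliases that are routable right now.
usable = {item["id"] for item in body["data"]}
# Configured aliases whose upstreams are currently unreachable.
unavailable = {item["alias"] for item in body["rolemesh"]["unavailable_models"]}

print(f"usable: {sorted(usable)}  unavailable: {sorted(unavailable)}")
```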
diff --git a/docs/CONFIG.md b/docs/CONFIG.md
index fa9bedc..8a11aa5 100644
--- a/docs/CONFIG.md
+++ b/docs/CONFIG.md
@@ -24,6 +24,33 @@ models:
 - `<alias>` is what clients pass as `model` in `/v1/chat/completions`.
 - `openai_model_name` is the model id returned by `/v1/models` (usually same as alias).
 
+## Roles are aliases, not a fixed list
+
+RoleMesh Gateway does not reserve a built-in set of roles.
+
+- The keys under `models:` are your project-specific role names
+- Clients send those keys in the OpenAI `model` field
+- You can rename or replace the sample roles entirely
+- Different projects can use different role layouts with the same gateway
+
+Example custom role set:
+
+```yaml
+models:
+  researcher:
+    type: proxy
+    openai_model_name: researcher
+    proxy_url: http://127.0.0.1:8011
+  summarizer:
+    type: proxy
+    openai_model_name: summarizer
+    proxy_url: http://127.0.0.1:8012
+  security-reviewer:
+    type: proxy
+    openai_model_name: security-reviewer
+    proxy_url: http://127.0.0.1:8013
+```
+
 ## Proxy models
 
 Route to a fixed upstream (any host reachable from the gateway):
@@ -57,6 +84,10 @@ models:
     strategy: round_robin
 ```
 
+Supported discovered-node strategies:
+- `round_robin`: rotate requests across fresh matching nodes
+- `random`: choose a fresh matching node at random for each request
+
 ### Registering nodes
 
 Nodes register to `POST /v1/nodes/register`:
@@ -78,6 +109,21 @@ Supported headers:
 - Clients: `Authorization: Bearer <key>` or `X-Api-Key: <key>`
 - Nodes: `Authorization: Bearer <key>` or `X-RoleMesh-Node-Key: <key>`
 
+## Availability behavior
+
+Configured aliases are not automatically assumed healthy.
+
+- `GET /v1/models` probes configured upstreams and only returns aliases that are currently reachable
+- unavailable aliases are included separately in `rolemesh.unavailable_models`
+- `GET /ready` returns success only when the configured `default_model` is currently usable
+- discovered nodes are only considered routable while they are fresh
+- gateway metadata marks stale registered nodes so operators can distinguish them from healthy nodes
+
+This is especially important when a config contains multiple optional roles but only some backends are up.
+
+For discovered-node freshness, the gateway uses the `ROLE_MESH_NODE_STALE_AFTER_S` environment variable.
+Default: `30`.
+
 ## Quick example
 
 ```yaml
diff --git a/docs/DEPLOYMENT.md b/docs/DEPLOYMENT.md
index 9f650b4..02c9aee 100644
--- a/docs/DEPLOYMENT.md
+++ b/docs/DEPLOYMENT.md
@@ -53,6 +53,22 @@ curl -sS -X POST http://127.0.0.1:8000/v1/chat/completions \
   }'
 ```
 
+### Readiness and model advertisement
+
+- `GET /health` only checks that the gateway process is up
+- `GET /ready` checks whether the configured default route is actually usable
+- `GET /v1/models` only lists aliases with a currently reachable upstream
+- aliases that are configured but currently unavailable are reported in `rolemesh.unavailable_models`
+- discovered nodes that have not checked in recently are marked stale and excluded from routing
+
+### Stale node timeout
+
+Registered nodes age out of discovered-role routing after a heartbeat timeout.
+
+- default timeout: `30` seconds
+- configure with `ROLE_MESH_NODE_STALE_AFTER_S`
+- stale nodes remain visible for operators in the gateway metadata, but they no longer receive traffic
+
 ## Network binding and exposure (Step 2 hardening)
 
 **Defaults are safe-by-default:** the gateway and node-agent CLIs default to binding on `127.0.0.1` (localhost).
@@ -101,6 +117,7 @@ This scaffold supports two patterns.
 - `http://10.0.0.13:8011` (planner)
 - Update `proxy_url` entries to those LAN URLs, **or** use discovery:
   - Set model to `type: discovered` with `role: writer`, etc.
+  - Choose `strategy: round_robin` or `strategy: random` per discovered alias
 - Each host registers itself with the gateway.
 
 ### Minimal registration call
 
 ```bash
 curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \
@@ -118,5 +135,5 @@
 - Configure API keys for:
   - inference endpoints via `auth.client_api_keys`
   - node registration and heartbeat via `auth.node_api_keys`
-- Add TTL and periodic health checks for registered nodes
+- Tune `ROLE_MESH_NODE_STALE_AFTER_S` for your heartbeat interval and failure tolerance
 - Consider mTLS if registration happens over untrusted networks
diff --git a/docs/NODE_AGENT.md b/docs/NODE_AGENT.md
index a9e8511..6a8e5de 100644
--- a/docs/NODE_AGENT.md
+++ b/docs/NODE_AGENT.md
@@ -16,6 +16,12 @@ Model switching is handled by **restart** in the scaffold.
 
 The agent now waits for the replacement `llama-server` to report readiness before proxying the first request.
 If startup or switching takes too long, the request fails with a `503` instead of passing through a transient upstream "Loading model" error.
 
+Device selection is still simple, but it is no longer hard-coded to the first GPU:
+
+- first preference: a device that already has the requested model loaded
+- otherwise: the device with the most free VRAM and least queue pressure
+- requests are serialized per device
+- each device has a bounded pending-request limit for backpressure
 
 ## Backends
@@ -41,10 +47,12 @@ Two config knobs control how long the node agent waits for a managed `llama-serv
 
 ```yaml
 llama_server_startup_timeout_s: 30.0
 llama_server_probe_interval_s: 0.5
+max_pending_requests_per_device: 2
 ```
 
 - `llama_server_startup_timeout_s`: maximum time to wait for a newly started or switched model
 - `llama_server_probe_interval_s`: polling interval for readiness checks
+- `max_pending_requests_per_device`: maximum in-flight plus queued requests allowed per device before new requests are rejected
 
 The readiness probe checks the managed server's local `GET /health` and `GET /v1/models` endpoints.
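The `max_pending_requests_per_device` limit documented above surfaces to callers as an HTTP `429` with error code `queue_full`, so clients talking to a node agent directly may want a small retry loop. A hedged sketch (the URL, payload, and backoff schedule are illustrative, not part of this diff):

```python
import asyncio
import httpx

async def chat_with_backoff(url: str, payload: dict, attempts: int = 4) -> httpx.Response:
    """POST a chat completion, retrying briefly while the device queue is full (HTTP 429)."""
    async with httpx.AsyncClient(timeout=60.0) as client:
        for attempt in range(attempts):
            resp = await client.post(url, json=payload)
            if resp.status_code != 429:
                return resp
            # Exponential backoff: 0.5s, 1s, 2s, ...
            await asyncio.sleep(0.5 * (2 ** attempt))
        # Still saturated after the last attempt; let the caller decide.
        return resp
```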
diff --git a/src/rolemesh_gateway/api/openai.py b/src/rolemesh_gateway/api/openai.py
index 67cd048..cd7d097 100644
--- a/src/rolemesh_gateway/api/openai.py
+++ b/src/rolemesh_gateway/api/openai.py
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
-import os
+import asyncio
 import time
 import uuid
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 from fastapi import APIRouter, Request, Depends
 from fastapi.responses import JSONResponse, StreamingResponse
@@ -30,28 +30,107 @@ def _openai_error(message: str, code: str = "upstream_error", status_code: int =
     )
 
 
+async def _model_status(
+    alias: str,
+    entry: ProxyModel | DiscoveredModel,
+    registry: Registry,
+    upstream: UpstreamClient,
+) -> Dict[str, Any]:
+    if isinstance(entry, ProxyModel):
+        base_url = str(entry.proxy_url).rstrip("/")
+        try:
+            await upstream.get_models(base_url)
+            return {"alias": alias, "available": True, "base_url": base_url}
+        except UpstreamError as exc:
+            return {"alias": alias, "available": False, "base_url": base_url, "error": str(exc)}
+
+    matching_nodes = registry.nodes_for_role(entry.role, include_stale=False)
+    if not matching_nodes:
+        stale_nodes = registry.nodes_for_role(entry.role, include_stale=True)
+        error = "no_registered_nodes"
+        if stale_nodes:
+            error = "no_fresh_registered_nodes"
+        return {"alias": alias, "available": False, "role": entry.role, "error": error}
+
+    for node in matching_nodes:
+        base_url = str(node.base_url).rstrip("/")
+        try:
+            await upstream.get_models(base_url)
+            return {
+                "alias": alias,
+                "available": True,
+                "role": entry.role,
+                "base_url": base_url,
+                "node_id": node.node_id,
+            }
+        except UpstreamError:
+            continue
+
+    return {
+        "alias": alias,
+        "available": False,
+        "role": entry.role,
+        "error": "no_healthy_registered_nodes",
+    }
+
+
+async def _collect_model_statuses(
+    cfg: Config,
+    registry: Registry,
+    upstream: UpstreamClient,
+) -> Dict[str, Dict[str, Any]]:
+    statuses = await asyncio.gather(
+        *[
+            _model_status(alias, entry, registry, upstream)
+            for alias, entry in cfg.models.items()
+        ]
+    )
+    return {status["alias"]: status for status in statuses}
+
+
 @router.get("/v1/models")
 async def list_models(request: Request, _=Depends(require_client_auth)) -> Dict[str, Any]:
     cfg: Config = request.app.state.cfg
     registry: Registry = request.app.state.registry
+    upstream: UpstreamClient = request.app.state.upstream
 
+    statuses = await _collect_model_statuses(cfg, registry, upstream)
     data = []
     for name, entry in cfg.models.items():
+        status = statuses[name]
+        if not status["available"]:
+            continue
         item = {
             "id": entry.openai_model_name,
             "object": "model",
             "owned_by": "local",
         }
         if isinstance(entry, ProxyModel):
-            item["rolemesh"] = {"type": "proxy", "proxy_url": str(entry.proxy_url)}
+            item["rolemesh"] = {"type": "proxy", "proxy_url": str(entry.proxy_url), "available": True}
         else:
-            item["rolemesh"] = {"type": "discovered", "role": entry.role, "strategy": entry.strategy}
+            item["rolemesh"] = {
+                "type": "discovered",
+                "role": entry.role,
+                "strategy": entry.strategy,
+                "available": True,
+            }
         data.append(item)
 
     # Expose currently registered nodes (informational)
-    nodes = [n.model_dump(mode="json") for n in registry.list_nodes()]
+    nodes = [
+        n.model_dump(mode="json") | {"stale": registry.is_stale(n)}
+        for n in registry.list_nodes(include_stale=True)
+    ]
+    unavailable = [status for status in statuses.values() if not status["available"]]
 
-    return {"object": "list", "data": data, "rolemesh": {"registered_nodes": nodes}}
+    return {
+        "object": "list",
+        "data": data,
+        "rolemesh": {
+            "registered_nodes": nodes,
+            "unavailable_models": unavailable,
+        },
+    }
 
 
 @router.post("/v1/chat/completions")
@@ -81,7 +160,7 @@ async def chat_completions(request: Request, _=Depends(require_client_auth)) ->
     if isinstance(entry, ProxyModel):
         base_url = str(entry.proxy_url).rstrip("/")
     elif isinstance(entry, DiscoveredModel):
-        node = registry.pick_node_for_role(entry.role)
+        node = registry.pick_node_for_role(entry.role, strategy=entry.strategy)
         if not node:
             return _openai_error(
                 f"No registered nodes available for role '{entry.role}'. "
@@ -125,12 +204,33 @@ async def health() -> Dict[str, str]:
 
 @router.get("/ready")
 async def ready(request: Request) -> JSONResponse:
-    """
-    Readiness checks for presence of config and (optionally) upstreams.
-    For now: verifies config loads and returns 200.
-    """
     cfg: Config = request.app.state.cfg
-    return JSONResponse(status_code=200, content={"status": "ready", "default_model": cfg.default_model})
+    registry: Registry = request.app.state.registry
+    upstream: UpstreamClient = request.app.state.upstream
+    statuses = await _collect_model_statuses(cfg, registry, upstream)
+    default_status = statuses.get(cfg.default_model)
+    available_aliases = [alias for alias, status in statuses.items() if status["available"]]
+
+    if default_status and default_status["available"]:
+        return JSONResponse(
+            status_code=200,
+            content={
+                "status": "ready",
+                "default_model": cfg.default_model,
+                "available_models": available_aliases,
+                "unavailable_models": [s for s in statuses.values() if not s["available"]],
+            },
+        )
+
+    return JSONResponse(
+        status_code=503,
+        content={
+            "status": "not_ready",
+            "default_model": cfg.default_model,
+            "available_models": available_aliases,
+            "unavailable_models": [s for s in statuses.values() if not s["available"]],
+        },
+    )
 
 
 @router.post("/v1/nodes/register")
diff --git a/src/rolemesh_gateway/main.py b/src/rolemesh_gateway/main.py
index 5f3e281..7f74fff 100644
--- a/src/rolemesh_gateway/main.py
+++ b/src/rolemesh_gateway/main.py
@@ -27,12 +27,19 @@ def _get_logger() -> logging.Logger:
 def create_app(
     config_path: str | Path | None = None,
     registry_path: str | Path | None = None,
+    registry_stale_after_s: float | None = None,
 ) -> FastAPI:
     cfg_path = config_path or os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml")
     cfg = load_config(cfg_path)
 
     resolved_registry_path = registry_path or os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json")
-    registry = Registry(persist_path=Path(resolved_registry_path))
+    resolved_registry_stale_after_s = registry_stale_after_s
+    if resolved_registry_stale_after_s is None:
+        resolved_registry_stale_after_s = float(os.environ.get("ROLE_MESH_NODE_STALE_AFTER_S", "30"))
+    registry = Registry(
+        persist_path=Path(resolved_registry_path),
+        stale_after_s=resolved_registry_stale_after_s,
+    )
 
     upstream = UpstreamClient(
         connect_timeout_s=float(os.environ.get("ROLE_MESH_CONNECT_TIMEOUT_S", "10")),
diff --git a/src/rolemesh_gateway/registry.py b/src/rolemesh_gateway/registry.py
index 8f623c0..5488fc2 100644
--- a/src/rolemesh_gateway/registry.py
+++ b/src/rolemesh_gateway/registry.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
 import json
+import random
 import time
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Field, HttpUrl
@@ -44,10 +45,11 @@ class Registry:
     - TTL + health checks
     """
 
-    def __init__(self, persist_path: Optional[Path] = None) -> None:
+    def __init__(self, persist_path: Optional[Path] = None, stale_after_s: float = 30.0) -> None:
         self._nodes: Dict[str, RegisteredNode] = {}
         self._rr_counters: Dict[str, int] = {}
         self._persist_path = persist_path
+        self._stale_after_s = stale_after_s
 
         if self._persist_path:
             self._load()
@@ -100,13 +102,29 @@ class Registry:
         self._save()
         return n
 
-    def list_nodes(self) -> List[RegisteredNode]:
-        return list(self._nodes.values())
+    def is_stale(self, node: RegisteredNode, now: Optional[float] = None) -> bool:
+        if self._stale_after_s <= 0:
+            return False
+        now = time.time() if now is None else now
+        return (now - node.last_seen) > self._stale_after_s
 
-    def pick_node_for_role(self, role: str) -> Optional[RegisteredNode]:
-        candidates = [n for n in self._nodes.values() if role in n.roles]
+    def list_nodes(self, *, include_stale: bool = True) -> List[RegisteredNode]:
+        nodes = list(self._nodes.values())
+        if include_stale:
+            return nodes
+        now = time.time()
+        return [node for node in nodes if not self.is_stale(node, now=now)]
+
+    def nodes_for_role(self, role: str, *, include_stale: bool = False) -> List[RegisteredNode]:
+        nodes = self.list_nodes(include_stale=include_stale)
+        return [node for node in nodes if role in node.roles]
+
+    def pick_node_for_role(self, role: str, strategy: str = "round_robin") -> Optional[RegisteredNode]:
+        candidates = self.nodes_for_role(role, include_stale=False)
         if not candidates:
             return None
+        if strategy == "random":
+            return random.choice(candidates)
         idx = self._rr_counters.get(role, 0) % len(candidates)
         self._rr_counters[role] = idx + 1
         self._save()
diff --git a/src/rolemesh_node_agent/config.py b/src/rolemesh_node_agent/config.py
index 3d454f9..a36242f 100644
--- a/src/rolemesh_node_agent/config.py
+++ b/src/rolemesh_node_agent/config.py
@@ -40,3 +40,4 @@ class NodeAgentConfig(BaseModel):
     llama_server_bin: str = "llama-server"
     llama_server_startup_timeout_s: float = 30.0
     llama_server_probe_interval_s: float = 0.5
+    max_pending_requests_per_device: int = 2
diff --git a/src/rolemesh_node_agent/main.py b/src/rolemesh_node_agent/main.py
index e8147c9..3735ee7 100644
--- a/src/rolemesh_node_agent/main.py
+++ b/src/rolemesh_node_agent/main.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import asyncio
 from contextlib import asynccontextmanager, suppress
 import time
-from typing import Any, Dict
+from typing import Any, Dict, Iterable
 
 import httpx
 from fastapi import FastAPI, Request
@@ -11,9 +11,10 @@ from fastapi.responses import JSONResponse, StreamingResponse
 
 from rolemesh_gateway.upstream import UpstreamClient, UpstreamError  # reuse gateway client
 
 from .adapters.cuda import CudaAdapter, ServerStartupError
-from .adapters.base import DeviceRef
+from .adapters.base import DeviceMetrics, DeviceRef
 from .config import NodeAgentConfig
 from .inventory import discover_gguf_models
+from .scheduler import AdmissionError, DeviceQueue
 
 
 def _error(message: str, code: str = "node_error", status_code: int = 500) -> JS
     )
 
 
+def _merge_scheduler_metrics(
+    metrics: Iterable[DeviceMetrics],
+    queues: Dict[str, DeviceQueue],
+) -> list[DeviceMetrics]:
+    out: list[DeviceMetrics] = []
+    for metric in metrics:
+        queue = queues.get(metric.device.id)
+        if queue is not None:
+            snap = queue.snapshot()
+            metric.queue_depth = snap.queue_depth
+            metric.in_flight_jobs = snap.in_flight
+        out.append(metric)
+    return out
+
+
+def _select_device(
+    devices: Iterable[DeviceRef],
+    metrics: Iterable[DeviceMetrics],
+    *,
+    model_id: str,
+) -> DeviceRef | None:
+    device_list = list(devices)
+    if not device_list:
+        return None
+
+    metrics_by_id = {metric.device.id: metric for metric in metrics}
+
+    # Reuse a device that already has the requested model loaded.
+    for device in device_list:
+        metric = metrics_by_id.get(device.id)
+        if metric and metric.loaded_model_id == model_id:
+            return device
+
+    def score(device: DeviceRef) -> tuple[float, float, float, str]:
+        metric = metrics_by_id.get(device.id)
+        if metric is None:
+            return (float("-inf"), 0.0, 0.0, device.id)
+
+        free_mem = 0.0
+        if metric.mem_total_gb is not None and metric.mem_used_gb is not None:
+            free_mem = metric.mem_total_gb - metric.mem_used_gb
+        queue_penalty = float(metric.queue_depth + metric.in_flight_jobs)
+        util_penalty = float(metric.utilization_pct or 0.0)
+        return (free_mem, -queue_penalty, -util_penalty, device.id)
+
+    # Most free VRAM wins; queue pressure and utilization break ties.
+    return max(device_list, key=score)
+
+
 def create_app(cfg: NodeAgentConfig) -> FastAPI:
     http = httpx.AsyncClient(
         timeout=httpx.Timeout(connect=5.0, read=3600.0, write=30.0, pool=30.0)
@@ -57,6 +106,7 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:
 
     # Adapters
     app.state.cuda = cuda
+    app.state.device_queues: Dict[str, DeviceQueue] = {}
 
     # State: role -> (device, model)
-    # This is intentionally simple for the scaffold: pick first GPU and first matching model.
+    # This is intentionally simple for the scaffold: score the discovered GPUs and pick the first matching model.
@@ -69,7 +119,9 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:
     @app.get("/v1/node/inventory")
     async def inventory() -> Dict[str, Any]:
         devices = await app.state.cuda.discover_devices()
-        metrics = [m.__dict__ | {"device": m.device.__dict__} for m in await app.state.cuda.get_metrics()]
+        raw_metrics = await app.state.cuda.get_metrics()
+        merged_metrics = _merge_scheduler_metrics(raw_metrics, app.state.device_queues)
+        metrics = [m.__dict__ | {"device": m.device.__dict__} for m in merged_metrics]
         models = [{"model_id": m.model_id, "path": str(m.path), "roles": m.roles} for m in cfg.models]
         discovered = discover_gguf_models(cfg.model_roots)
         return {
@@ -104,30 +156,48 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:
         devices = await app.state.cuda.discover_devices()
         if not devices:
             return _error("No CUDA GPUs discovered on this node.", code="no_device", status_code=503)
-        device = devices[0]
+        metrics = _merge_scheduler_metrics(await app.state.cuda.get_metrics(), app.state.device_queues)
+        device = _select_device(devices, metrics, model_id=model_entry.model_id)
+        if device is None:
+            return _error("No eligible CUDA GPUs discovered on this node.", code="no_device", status_code=503)
+
+        queue = app.state.device_queues.setdefault(
+            device.id,
+            DeviceQueue(max_pending=cfg.max_pending_requests_per_device),
+        )
+        try:
+            lease = await queue.acquire()
+        except AdmissionError:
+            return _error("Device queue is full.", code="queue_full", status_code=429)
 
+        # The streaming generator takes ownership of the lease once the
+        # response is handed off; every other exit path (including a failed
+        # server start on a streaming request) releases it in the outer
+        # finally below.
+        streaming_handoff = False
         try:
-            base_url = await app.state.cuda.ensure_server(
-                device,
-                model_path=str(model_entry.path),
-                model_id=model_entry.model_id,
-                server_args=model_entry.server_args,
-            )
-        except ServerStartupError as e:
-            return _error(str(e), code="server_startup_error", status_code=503)
-
-        upstream = app.state.upstream
-        try:
-            if not stream:
-                out = await upstream.chat_completions(base_url, body)
-                return JSONResponse(status_code=200, content=out)
-            else:
-                async def gen():
-                    async for chunk in upstream.stream_chat_completions(base_url, body):
-                        yield chunk
-                return StreamingResponse(gen(), media_type="text/event-stream")
-        except UpstreamError as e:
-            return _error(str(e), code="upstream_error", status_code=502)
+            try:
+                base_url = await app.state.cuda.ensure_server(
+                    device,
+                    model_path=str(model_entry.path),
+                    model_id=model_entry.model_id,
+                    server_args=model_entry.server_args,
+                )
+            except ServerStartupError as e:
+                return _error(str(e), code="server_startup_error", status_code=503)
+
+            upstream = app.state.upstream
+            try:
+                if not stream:
+                    out = await upstream.chat_completions(base_url, body)
+                    return JSONResponse(status_code=200, content=out)
+                else:
+                    async def gen():
+                        try:
+                            async for chunk in upstream.stream_chat_completions(base_url, body):
+                                yield chunk
+                        finally:
+                            await lease.release()
+                    streaming_handoff = True
+                    return StreamingResponse(gen(), media_type="text/event-stream")
+            except UpstreamError as e:
+                return _error(str(e), code="upstream_error", status_code=502)
+        finally:
+            if not streaming_handoff:
+                # lease.release() is idempotent, so this cannot double-release
+                # even if the generator already ran.
+                await lease.release()
 
     return app
diff --git a/src/rolemesh_node_agent/scheduler.py b/src/rolemesh_node_agent/scheduler.py
index ba6948f..f239f9a 100644
--- a/src/rolemesh_node_agent/scheduler.py
+++ b/src/rolemesh_node_agent/scheduler.py
@@ -1,59 +1,66 @@
 from __future__ import annotations
 
 import asyncio
-import time
 from dataclasses import dataclass
-from typing import Any, AsyncIterator, Dict, Optional
-
-import httpx
+
+
+class AdmissionError(RuntimeError):
+    pass
 
 
 @dataclass
-class Job:
-    job_id: str
-    submitted_at: float
-    base_url: str
-    body: Dict[str, Any]
-    stream: bool
+class QueueSnapshot:
+    queue_depth: int
+    in_flight: int
+
+
+class DeviceLease:
+    def __init__(self, queue: "DeviceQueue") -> None:
+        self._queue = queue
+        self._released = False
+
+    async def release(self) -> None:
+        # Idempotent: safe to call from both the streaming generator and the
+        # endpoint's finally block.
+        if self._released:
+            return
+        self._released = True
+        await self._queue._release()
+
+    async def __aenter__(self) -> "DeviceLease":
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb) -> None:
+        await self.release()
 
 
 class DeviceQueue:
-    """Simple FIFO queue per device; one worker per device for latency predictability."""
+    """Serialize work per device and enforce a bounded number of pending requests."""
 
-    def __init__(self, *, client: httpx.AsyncClient, base_url: str) -> None:
-        self._client = client
-        self._base_url = base_url.rstrip("/")
-        self._q: asyncio.Queue[Job] = asyncio.Queue()
-        self._worker_task: Optional[asyncio.Task] = None
-        self.in_flight: int = 0
-        self.last_job_started_at: Optional[float] = None
+    def __init__(self, *, max_pending: int = 2) -> None:
+        self._max_pending = max_pending
+        self._lock = asyncio.Lock()
+        self._guard = asyncio.Lock()
+        self._queued = 0
+        self._in_flight = 0
 
-    def start(self) -> None:
-        if self._worker_task is None:
-            self._worker_task = asyncio.create_task(self._worker())
-
-    async def stop(self) -> None:
-        if self._worker_task:
-            self._worker_task.cancel()
-            self._worker_task = None
-
-    async def submit(self, job: Job) -> Job:
-        await self._q.put(job)
-        return job
-
-    @property
-    def depth(self) -> int:
-        return self._q.qsize()
-
-    async def _worker(self) -> None:
-        while True:
-            job = await self._q.get()
-            self.in_flight += 1
-            self.last_job_started_at = time.time()
-            try:
-                # The actual request execution is handled by the API layer (so it can stream).
-                # This worker exists mainly to serialize jobs per device and provide queue metrics.
-                await asyncio.sleep(0)  # placeholder
-            finally:
-                self.in_flight -= 1
-                self._q.task_done()
+    async def acquire(self) -> DeviceLease:
+        async with self._guard:
+            pending = self._queued + self._in_flight
+            if self._max_pending > 0 and pending >= self._max_pending:
+                raise AdmissionError("device queue is full")
+            self._queued += 1
+
+        try:
+            await self._lock.acquire()
+        except BaseException:
+            # Undo the reservation if we are cancelled while waiting.
+            async with self._guard:
+                self._queued -= 1
+            raise
+
+        async with self._guard:
+            self._queued -= 1
+            self._in_flight += 1
+
+        return DeviceLease(self)
+
+    async def _release(self) -> None:
+        async with self._guard:
+            self._in_flight -= 1
+            self._lock.release()
+
+    def snapshot(self) -> QueueSnapshot:
+        return QueueSnapshot(queue_depth=self._queued, in_flight=self._in_flight)
diff --git a/tests/test_gateway.py b/tests/test_gateway.py
index 022997d..ebae53d 100644
--- a/tests/test_gateway.py
+++ b/tests/test_gateway.py
@@ -30,6 +30,12 @@ def _write_gateway_config(path: Path) -> Path:
                     "role": "reviewer",
                     "strategy": "round_robin",
                 },
+                "random-reviewer": {
+                    "type": "discovered",
+                    "openai_model_name": "random-reviewer",
+                    "role": "reviewer",
+                    "strategy": "random",
+                },
             },
         }
     )
@@ -43,6 +49,7 @@ def _create_gateway_app(tmp_path: Path):
     return create_app(
         _write_gateway_config(tmp_path / "models.yaml"),
         registry_path=tmp_path / "registry.json",
+        registry_stale_after_s=30.0,
     )
 
 
@@ -54,6 +61,10 @@ async def _request(app, method: str, path: str, **kwargs) -> httpx.Response:
 def test_models_requires_client_auth(tmp_path):
     app = _create_gateway_app(tmp_path)
 
+    async def fake_get_models(base_url):
+        return {"object": "list", "data": []}
+
+    app.state.upstream.get_models = fake_get_models
     unauthorized = asyncio.run(_request(app, "GET", "/v1/models"))
     assert unauthorized.status_code == 401
 
@@ -62,7 +73,8 @@ def test_models_requires_client_auth(tmp_path):
     )
     assert authorized.status_code == 200
     body = authorized.json()
-    assert {item["id"] for item in body["data"]} == {"writer", "reviewer"}
+    assert {item["id"] for item in body["data"]} == {"writer"}
+    assert {item["alias"] for item in body["rolemesh"]["unavailable_models"]} == {"reviewer", "random-reviewer"}
     asyncio.run(app.state.upstream.close())
@@ -156,3 +168,184 @@ def test_node_registration_and_heartbeat_update_registry_state(tmp_path):
     assert node["status"]["timestamp"] == 123.0
     assert node["status"]["metrics"] == [{"device": {"id": "gpu:0"}}]
     asyncio.run(app.state.upstream.close())
+
+
+def test_models_filters_unhealthy_aliases_and_reports_them(tmp_path):
+    app = _create_gateway_app(tmp_path)
+
+    from rolemesh_gateway.upstream import UpstreamError
+
+    async def fake_get_models(base_url):
+        if base_url == "http://127.0.0.1:8012":
+            raise UpstreamError("Upstream unreachable: boom")
+        return {"object": "list", "data": []}
+
+    app.state.upstream.get_models = fake_get_models
+
+    response = asyncio.run(
+        _request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
+    )
+
+    assert response.status_code == 200
+    body = response.json()
+    assert body["data"] == []
+    assert {item["alias"] for item in body["rolemesh"]["unavailable_models"]} == {"writer", "reviewer", "random-reviewer"}
+    asyncio.run(app.state.upstream.close())
+
+
+def test_ready_returns_503_when_default_model_is_unavailable(tmp_path):
+    app = _create_gateway_app(tmp_path)
+
+    from rolemesh_gateway.upstream import UpstreamError
+
+    async def fake_get_models(base_url):
+        raise UpstreamError("Upstream unreachable: boom")
+
+    app.state.upstream.get_models = fake_get_models
+
+    response = asyncio.run(_request(app, "GET", "/ready"))
+
+    assert response.status_code == 503
+    body = response.json()
+    assert body["status"] == "not_ready"
+    assert body["default_model"] == "writer"
+    asyncio.run(app.state.upstream.close())
+
+
+def test_ready_returns_200_when_default_model_is_available(tmp_path):
+    app = _create_gateway_app(tmp_path)
+
+    from rolemesh_gateway.upstream import UpstreamError
+
+    async def fake_get_models(base_url):
+        if base_url == "http://127.0.0.1:8012":
+            return {"object": "list", "data": []}
+        raise UpstreamError("no reviewer nodes")
+
+    app.state.upstream.get_models = fake_get_models
+
+    response = asyncio.run(_request(app, "GET", "/ready"))
+
+    assert response.status_code == 200
+    body = response.json()
+    assert body["status"] == "ready"
+    assert body["default_model"] == "writer"
+    assert body["available_models"] == ["writer"]
+    asyncio.run(app.state.upstream.close())
+
+
+def test_stale_discovered_nodes_are_not_advertised_as_available(tmp_path):
+    app = _create_gateway_app(tmp_path)
+
+    register = asyncio.run(
+        _request(
+            app,
+            "POST",
+            "/v1/nodes/register",
+            headers={"x-rolemesh-node-key": "node-secret"},
+            json={
+                "node_id": "node-a",
+                "base_url": "http://127.0.0.1:9001",
+                "roles": ["reviewer"],
+            },
+        )
+    )
+    assert register.status_code == 200
+
+    node = app.state.registry.list_nodes(include_stale=True)[0]
+    node.last_seen = node.last_seen - 120
+
+    async def fake_get_models(base_url):
+        return {"object": "list", "data": []}
+
+    app.state.upstream.get_models = fake_get_models
+
+    models = asyncio.run(
+        _request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
+    )
+    ready = asyncio.run(_request(app, "GET", "/ready"))
+
+    models_body = models.json()
+    assert {item["id"] for item in models_body["data"]} == {"writer"}
+    assert models_body["rolemesh"]["registered_nodes"][0]["stale"] is True
+    unavailable = {item["alias"]: item for item in models_body["rolemesh"]["unavailable_models"]}
+    assert unavailable["reviewer"]["error"] == "no_fresh_registered_nodes"
+    assert unavailable["random-reviewer"]["error"] == "no_fresh_registered_nodes"
+    assert ready.status_code == 200
+    asyncio.run(app.state.upstream.close())
+
+
+def test_discovered_model_uses_random_strategy_when_configured(tmp_path):
+    app = _create_gateway_app(tmp_path)
+    calls = {}
+
+    register_a = asyncio.run(
+        _request(
+            app,
+            "POST",
+            "/v1/nodes/register",
+            headers={"x-rolemesh-node-key": "node-secret"},
+            json={
+                "node_id": "node-a",
+                "base_url": "http://127.0.0.1:9001",
+                "roles": ["reviewer"],
+            },
+        )
+    )
+    register_b = asyncio.run(
+        _request(
+            app,
+            "POST",
+            "/v1/nodes/register",
+            headers={"x-rolemesh-node-key": "node-secret"},
+            json={
+                "node_id": "node-b",
+                "base_url": "http://127.0.0.1:9002",
+                "roles": ["reviewer"],
+            },
+        )
+    )
+    assert register_a.status_code == 200
+    assert register_b.status_code == 200
+
+    import random
+
+    original_choice = random.choice
+    random.choice = lambda candidates: candidates[-1]
+
+    async def fake_chat(base_url, payload):
+        calls["base_url"] = base_url
+        return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
+
+    async def fake_get_models(base_url):
+        return {"object": "list", "data": []}
+
+    app.state.upstream.chat_completions = fake_chat
+    app.state.upstream.get_models = fake_get_models
+
+    try:
+        response = asyncio.run(
+            _request(
+                app,
+                "POST",
+                "/v1/chat/completions",
+                headers={"x-api-key": "client-secret"},
+                json={
+                    "model": "random-reviewer",
+                    "messages": [{"role": "user", "content": "hello"}],
+                },
+            )
+        )
+    finally:
+        random.choice = original_choice
+
+    assert response.status_code == 200
+    assert calls["base_url"] == "http://127.0.0.1:9002"
+    asyncio.run(app.state.upstream.close())
diff --git a/tests/test_node_agent.py b/tests/test_node_agent.py
index dfd9bf8..9236cab 100644
--- a/tests/test_node_agent.py
+++ b/tests/test_node_agent.py
@@ -7,6 +7,8 @@ import httpx
 
 from rolemesh_node_agent.adapters.base import DeviceMetrics, DeviceRef
 from rolemesh_node_agent.config import ModelEntry, NodeAgentConfig
+from rolemesh_node_agent.main import _merge_scheduler_metrics, _select_device
+from rolemesh_node_agent.scheduler import AdmissionError, DeviceQueue
 
 
 def _node_config(tmp_path: Path) -> NodeAgentConfig:
@@ -152,3 +154,174 @@ def test_chat_completions_returns_503_when_server_startup_fails(tmp_path):
     assert response.status_code == 503
     assert response.json()["error"]["code"] == "server_startup_error"
     asyncio.run(app.state.http.aclose())
+
+
+def test_select_device_prefers_already_loaded_model():
+    devices = [
+        DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
+        DeviceRef(kind="gpu", backend="cuda", id="gpu:1"),
+    ]
+    metrics = [
+        DeviceMetrics(
+            device=devices[0],
+            loaded_model_id="other-model",
+            mem_total_gb=24.0,
+            mem_used_gb=4.0,
+        ),
+        DeviceMetrics(
+            device=devices[1],
+            loaded_model_id="target-model",
+            mem_total_gb=24.0,
+            mem_used_gb=20.0,
+        ),
+    ]
+
+    picked = _select_device(devices, metrics, model_id="target-model")
+    assert picked == devices[1]
+
+
+def test_select_device_prefers_more_free_memory_and_lower_pressure():
+    devices = [
+        DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
+        DeviceRef(kind="gpu", backend="cuda", id="gpu:1"),
+    ]
+    metrics = [
+        DeviceMetrics(
+            device=devices[0],
+            mem_total_gb=24.0,
+            mem_used_gb=18.0,
+            queue_depth=0,
+            in_flight_jobs=0,
+            utilization_pct=5.0,
+        ),
+        DeviceMetrics(
+            device=devices[1],
+            mem_total_gb=24.0,
+            mem_used_gb=4.0,
+            queue_depth=0,
+            in_flight_jobs=0,
+            utilization_pct=10.0,
+        ),
+    ]
+
+    picked = _select_device(devices, metrics, model_id="target-model")
+    assert picked == devices[1]
+
+
+def test_chat_completions_uses_selected_device_not_first_device(tmp_path):
+    from rolemesh_node_agent.main import create_app
+
+    cfg = _node_config(tmp_path)
+    app = create_app(cfg)
+    calls = {}
+
+    devices = [
+        DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
+        DeviceRef(kind="gpu", backend="cuda", id="gpu:1"),
+    ]
+
+    async def fake_discover_devices():
+        return devices
+
+    async def fake_get_metrics():
+        return [
+            DeviceMetrics(
+                device=devices[0],
+                mem_total_gb=24.0,
+                mem_used_gb=20.0,
+            ),
+            DeviceMetrics(
+                device=devices[1],
+                mem_total_gb=24.0,
+                mem_used_gb=2.0,
+            ),
+        ]
+
+    async def fake_ensure_server(device, *, model_path, model_id, server_args):
+        calls["device"] = device.id
+        return "http://127.0.0.1:9100"
+
+    async def fake_chat(base_url, payload):
+        return {"id": "node-cmpl", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
+
+    app.state.cuda.discover_devices = fake_discover_devices
+    app.state.cuda.get_metrics = fake_get_metrics
+    app.state.cuda.ensure_server = fake_ensure_server
+    app.state.upstream.chat_completions = fake_chat
+
+    response = asyncio.run(
+        _request(
+            app,
+            "POST",
+            "/v1/chat/completions",
+            json={
+                "model": "planner-gguf",
+                "messages": [{"role": "user", "content": "hello"}],
+            },
+        )
+    )
+
+    assert response.status_code == 200
+    assert calls["device"] == "gpu:1"
+    asyncio.run(app.state.http.aclose())
+
+
+def test_merge_scheduler_metrics_overlays_queue_state():
+    device = DeviceRef(kind="gpu", backend="cuda", id="gpu:0")
+    metric = DeviceMetrics(device=device)
+    queue = DeviceQueue(max_pending=2)
+    queue._queued = 1
+    queue._in_flight = 1
+
+    merged = _merge_scheduler_metrics([metric], {"gpu:0": queue})
+
+    assert merged[0].queue_depth == 1
+    assert merged[0].in_flight_jobs == 1
+
+
+def test_device_queue_rejects_when_full():
+    queue = DeviceQueue(max_pending=1)
+    queue._in_flight = 1
+
+    try:
+        asyncio.run(queue.acquire())
+    except AdmissionError as exc:
+        assert "full" in str(exc)
+    else:
+        raise AssertionError("expected AdmissionError")
+
+
+def test_chat_completions_returns_429_when_device_queue_is_full(tmp_path):
+    from rolemesh_node_agent.main import create_app
+
+    cfg = _node_config(tmp_path)
+    cfg.max_pending_requests_per_device = 1
+    app = create_app(cfg)
+
+    async def fake_discover_devices():
+        return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")]
+
+    async def fake_get_metrics():
+        return [DeviceMetrics(device=DeviceRef(kind="gpu", backend="cuda", id="gpu:0"))]
+
+    app.state.cuda.discover_devices = fake_discover_devices
+    app.state.cuda.get_metrics = fake_get_metrics
+    saturated = DeviceQueue(max_pending=1)
+    saturated._in_flight = 1
+    app.state.device_queues["gpu:0"] = saturated
+
+    response = asyncio.run(
+        _request(
+            app,
+            "POST",
+            "/v1/chat/completions",
+            json={
+                "model": "planner-gguf",
+                "messages": [{"role": "user", "content": "hello"}],
+            },
+        )
+    )
+
+    assert response.status_code == 429
+    assert response.json()["error"]["code"] == "queue_full"
+    asyncio.run(app.state.http.aclose())
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 2701424..8287bf2 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import random
+
 from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
 
 
@@ -44,3 +46,55 @@ def test_registry_persists_round_robin_and_heartbeat_state(tmp_path):
 def test_registry_heartbeat_returns_none_for_unknown_node():
     registry = Registry()
     assert registry.heartbeat(NodeHeartbeat(node_id="missing")) is None
+
+
+def test_registry_filters_stale_nodes_from_routing(tmp_path):
+    persist_path = tmp_path / "registry.json"
+    # A generous threshold keeps the "fresh" node fresh even on slow CI.
+    registry = Registry(persist_path=persist_path, stale_after_s=5.0)
+
+    fresh = registry.register(
+        NodeRegistration(
+            node_id="fresh-node",
+            base_url="http://127.0.0.1:9001",
+            roles=["reviewer"],
+        )
+    )
+    stale = registry.register(
+        NodeRegistration(
+            node_id="stale-node",
+            base_url="http://127.0.0.1:9002",
+            roles=["reviewer"],
+        )
+    )
+
+    stale.last_seen = stale.last_seen - 60
+    registry._save()
+
+    assert registry.is_stale(stale) is True
+    assert registry.is_stale(fresh) is False
+    assert [node.node_id for node in registry.nodes_for_role("reviewer", include_stale=False)] == ["fresh-node"]
+    assert registry.pick_node_for_role("reviewer").node_id == "fresh-node"
+
+
+def test_registry_supports_random_selection(monkeypatch):
+    registry = Registry()
+    registry.register(
+        NodeRegistration(
+            node_id="node-a",
+            base_url="http://127.0.0.1:9001",
+            roles=["reviewer"],
+        )
+    )
+    registry.register(
+        NodeRegistration(
+            node_id="node-b",
+            base_url="http://127.0.0.1:9002",
+            roles=["reviewer"],
+        )
+    )
+
+    monkeypatch.setattr(random, "choice", lambda candidates: candidates[-1])
+
+    picked = registry.pick_node_for_role("reviewer", strategy="random")
+    assert picked is not None
+    assert picked.node_id == "node-b"