Addressing reliability issues from Didactopus use.
parent 5befa6d7f6
commit 79923983a0

README.md | 37

@@ -19,6 +19,26 @@ deployments where **different machines host different models** (e.g., GPU box fo
- Robust proxying with **explicit httpx timeouts** (no “hang forever”)
- Structured logging with request IDs

## Roles Are Project-Defined

The role names in this repository are examples, not a fixed taxonomy.

- `planner`, `writer`, `coder`, and `reviewer` are only sample aliases
- you can add, remove, or rename roles per project
- a role is simply the `model` alias clients send to RoleMesh Gateway
- each role can point at any OpenAI-compatible backend that fits that project's workflow

Examples of project-specific roles:

- `researcher`
- `summarizer`
- `tool-user`
- `swe-backend`
- `swe-frontend`
- `test-writer`
- `security-reviewer`

If your workflow changes, update the `models:` section in config rather than treating the example roles as required.
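For example, a client selects a role purely through the OpenAI `model` field. The sketch below assumes the gateway is listening on `127.0.0.1:8000` and that your config defines a `researcher` alias:

```bash
# Call a project-defined role; only the "model" value changes per role.
curl -sS -X POST http://127.0.0.1:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "researcher",
    "messages": [{"role": "user", "content": "Summarize the design notes."}]
  }'
```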

## Quick Start

This is the fastest path to a working local setup.

@@ -68,6 +88,8 @@ models:

Save that as `configs/models.yaml`.

You are not limited to `planner` and `writer`. Those are just placeholders for whatever roles your project needs.

### 4. Run the gateway

```bash

@@ -166,6 +188,21 @@ This repository is a **preliminary scaffold**:
- Gateway proxying has been exercised live with Ollama and `llamafile`.
- Node-agent managed inference has been exercised live with `llama-server` on CUDA hardware.

## Availability Semantics

RoleMesh Gateway now distinguishes between configured aliases and currently usable aliases.

- `GET /v1/models` advertises only aliases whose upstreams are reachable right now
- unavailable aliases are reported under `rolemesh.unavailable_models`
- `GET /ready` returns `200` only when the configured `default_model` is currently usable
- `GET /health` remains a lightweight process health check and does not probe upstreams
- discovered nodes are removed from routing once they become stale

This makes the API surface more truthful for clients that rely on the advertised role list.

By default, registered nodes become stale after `30` seconds without a fresh heartbeat or registration event.
You can change that with `ROLE_MESH_NODE_STALE_AFTER_S`.
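A quick way to observe these semantics from a shell (a sketch; the gateway address and the `ROLEMESH_CLIENT_KEY` variable holding a client API key are assumptions):

```bash
# Liveness only; never probes upstreams.
curl -sS -o /dev/null -w "%{http_code}\n" http://127.0.0.1:8000/health

# 200 while the default_model's upstream is reachable, 503 otherwise.
curl -sS -o /dev/null -w "%{http_code}\n" http://127.0.0.1:8000/ready

# Advertises reachable aliases; the rest appear under rolemesh.unavailable_models.
curl -sS -H "X-Api-Key: $ROLEMESH_CLIENT_KEY" http://127.0.0.1:8000/v1/models
```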

## License

MIT. See `LICENSE`.

@@ -17,6 +17,8 @@ auth:
# Models may be:
# - type: proxy (static URL to an OpenAI-compatible upstream)
# - type: discovered (resolved from registered nodes by role)
# The names under "models" are project-defined role aliases, not a fixed built-in list.
# Rename or replace planner/writer/coder/reviewer with whatever your workflow needs.
models:
  planner:
    type: proxy

@@ -36,6 +36,6 @@ This is deliberately small so you can swap it out later for something stronger:
- Auth is optional and config-driven rather than enforced by default
- No TTL/health polling
- No automatic config reload
-- Round-robin selection only for discovered nodes
+- Selection strategies are intentionally simple and limited to `round_robin` and `random`

These are tracked in `docs/DEPLOYMENT.md` as next steps.

@@ -24,6 +24,33 @@ models:
- `<alias>` is what clients pass as `model` in `/v1/chat/completions`.
- `openai_model_name` is the model id returned by `/v1/models` (usually same as alias).

## Roles are aliases, not a fixed list

RoleMesh Gateway does not reserve a built-in set of roles.

- The keys under `models:` are your project-specific role names
- Clients send those keys in the OpenAI `model` field
- You can rename or replace the sample roles entirely
- Different projects can use different role layouts with the same gateway

Example custom role set:

```yaml
models:
  researcher:
    type: proxy
    openai_model_name: researcher
    proxy_url: http://127.0.0.1:8011
  summarizer:
    type: proxy
    openai_model_name: summarizer
    proxy_url: http://127.0.0.1:8012
  security-reviewer:
    type: proxy
    openai_model_name: security-reviewer
    proxy_url: http://127.0.0.1:8013
```

## Proxy models

Route to a fixed upstream (any host reachable from the gateway):

@@ -57,6 +84,10 @@ models:
    strategy: round_robin
```

Supported discovered-node strategies:

- `round_robin`: rotate requests across fresh matching nodes
- `random`: choose a fresh matching node at random for each request

### Registering nodes

Nodes register to `POST /v1/nodes/register`:
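The sketch below is assembled from this commit's test suite; the gateway address and the `node-secret` key are assumptions:

```bash
# Register a node offering the "reviewer" role with the gateway.
curl -sS -X POST http://127.0.0.1:8000/v1/nodes/register \
  -H "Content-Type: application/json" \
  -H "X-RoleMesh-Node-Key: node-secret" \
  -d '{
    "node_id": "node-a",
    "base_url": "http://127.0.0.1:9001",
    "roles": ["reviewer"]
  }'
```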

@@ -78,6 +109,21 @@ Supported headers:
- Clients: `Authorization: Bearer <key>` or `X-Api-Key: <key>`
- Nodes: `Authorization: Bearer <node_key>` or `X-RoleMesh-Node-Key: <node_key>`

## Availability behavior

Configured aliases are not automatically assumed healthy.

- `GET /v1/models` probes configured upstreams and only returns aliases that are currently reachable
- unavailable aliases are included separately in `rolemesh.unavailable_models`
- `GET /ready` returns success only when the configured `default_model` is currently usable
- discovered nodes are only considered routable while they are fresh
- gateway metadata marks stale registered nodes so operators can distinguish them from healthy nodes

This is especially important when a config contains multiple optional roles but only some backends are up.
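To see exactly which aliases are down and why, inspect that field directly (a sketch; assumes `jq` is installed and the `X-Api-Key` auth mode is configured):

```bash
# Each unavailable entry carries the alias and an error reason.
curl -sS -H "X-Api-Key: $ROLEMESH_CLIENT_KEY" http://127.0.0.1:8000/v1/models \
  | jq '.rolemesh.unavailable_models[] | {alias, error}'
```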

For discovered-node freshness, the gateway uses the `ROLE_MESH_NODE_STALE_AFTER_S` environment variable.
Default: `30`.

## Quick example

```yaml

@@ -53,6 +53,22 @@ curl -sS -X POST http://127.0.0.1:8000/v1/chat/completions \
  }'
```

### Readiness and model advertisement

- `GET /health` only checks that the gateway process is up
- `GET /ready` checks whether the configured default route is actually usable
- `GET /v1/models` only lists aliases with a currently reachable upstream
- aliases that are configured but currently unavailable are reported in `rolemesh.unavailable_models`
- discovered nodes that have not checked in recently are marked stale and excluded from routing
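In deployment scripts this makes `/ready` a usable startup gate (a sketch; the address and polling interval are assumptions):

```bash
# Block until the default route is actually usable, not just until the process is up.
until [ "$(curl -s -o /dev/null -w '%{http_code}' http://127.0.0.1:8000/ready)" = "200" ]; do
  sleep 2
done
```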

### Stale node timeout

Registered nodes age out of discovered-role routing after a heartbeat timeout.

- default timeout: `30` seconds
- configure with `ROLE_MESH_NODE_STALE_AFTER_S`
- stale nodes remain visible to operators in the gateway metadata, but they no longer receive traffic (see the check below)
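Each entry under `rolemesh.registered_nodes` carries a `stale` flag, so an operator can spot aged-out nodes without reading logs (a sketch; assumes `jq` and a configured client key):

```bash
# Stale nodes still appear here, flagged, but receive no traffic.
curl -sS -H "X-Api-Key: $ROLEMESH_CLIENT_KEY" http://127.0.0.1:8000/v1/models \
  | jq '.rolemesh.registered_nodes[] | {node_id, stale}'
```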

## Network binding and exposure (Step 2 hardening)

**Safe by default:** the gateway and node-agent CLIs bind to `127.0.0.1` (localhost) unless told otherwise.

@@ -101,6 +117,7 @@ This scaffold supports two patterns.
  - `http://10.0.0.13:8011` (planner)
- Update `proxy_url` entries to those LAN URLs, **or** use discovery:
  - Set model to `type: discovered` with `role: writer`, etc.
  - Choose `strategy: round_robin` or `strategy: random` per discovered alias
  - Each host registers itself with the gateway.

### Minimal registration call

@@ -118,5 +135,5 @@ curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \
- Configure API keys for:
  - inference endpoints via `auth.client_api_keys`
  - node registration and heartbeat via `auth.node_api_keys`
- Add TTL and periodic health checks for registered nodes
- Tune `ROLE_MESH_NODE_STALE_AFTER_S` for your heartbeat interval and failure tolerance (see the sketch after this list)
- Consider mTLS if registration happens over untrusted networks
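A reasonable starting point, under the assumption that a node should survive two or three missed heartbeats before being dropped:

```bash
# Assumption: nodes heartbeat every 30s; allow ~3 missed beats before staleness.
# Export before starting the gateway process.
export ROLE_MESH_NODE_STALE_AFTER_S=90
```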

@@ -16,6 +16,12 @@ Model switching is handled by **restart** in the scaffold.
The agent now waits for the replacement `llama-server` to report readiness before proxying the first request.
If startup or switching takes too long, the request fails with a `503` instead of passing through a transient upstream "Loading model" error.
Device selection is still simple, but it is no longer hard-coded to the first GPU:

- first preference: a device that already has the requested model loaded
- otherwise: the device with the most free VRAM and least queue pressure
- requests are serialized per device
- each device has a bounded pending-request limit for backpressure

## Backends

@@ -41,10 +47,12 @@ Two config knobs control how long the node agent waits for a managed `llama-server

```yaml
llama_server_startup_timeout_s: 30.0
llama_server_probe_interval_s: 0.5
max_pending_requests_per_device: 2
```

- `llama_server_startup_timeout_s`: maximum time to wait for a newly started or switched model
- `llama_server_probe_interval_s`: polling interval for readiness checks
- `max_pending_requests_per_device`: maximum in-flight plus queued requests allowed per device before new requests are rejected

The readiness probe checks the managed server's local `GET /health` and `GET /v1/models` endpoints.
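You can run the same two probes by hand against a managed server (a sketch; the port is an assumption taken from this commit's tests):

```bash
# Both endpoints must answer before the node agent proxies the first request.
curl -sS http://127.0.0.1:9100/health
curl -sS http://127.0.0.1:9100/v1/models
```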

@@ -1,9 +1,9 @@
from __future__ import annotations

import os
import asyncio
import time
import uuid
-from typing import Any, Dict
+from typing import Any, Dict, List

from fastapi import APIRouter, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse

@@ -30,28 +30,107 @@ def _openai_error(message: str, code: str = "upstream_error", status_code: int =
    )


async def _model_status(
    alias: str,
    entry: ProxyModel | DiscoveredModel,
    registry: Registry,
    upstream: UpstreamClient,
) -> Dict[str, Any]:
    if isinstance(entry, ProxyModel):
        base_url = str(entry.proxy_url).rstrip("/")
        try:
            await upstream.get_models(base_url)
            return {"alias": alias, "available": True, "base_url": base_url}
        except UpstreamError as exc:
            return {"alias": alias, "available": False, "base_url": base_url, "error": str(exc)}

    matching_nodes = registry.nodes_for_role(entry.role, include_stale=False)
    if not matching_nodes:
        stale_nodes = registry.nodes_for_role(entry.role, include_stale=True)
        error = "no_registered_nodes"
        if stale_nodes:
            error = "no_fresh_registered_nodes"
        return {"alias": alias, "available": False, "role": entry.role, "error": error}

    for node in matching_nodes:
        base_url = str(node.base_url).rstrip("/")
        try:
            await upstream.get_models(base_url)
            return {
                "alias": alias,
                "available": True,
                "role": entry.role,
                "base_url": base_url,
                "node_id": node.node_id,
            }
        except UpstreamError:
            continue

    return {
        "alias": alias,
        "available": False,
        "role": entry.role,
        "error": "no_healthy_registered_nodes",
    }


async def _collect_model_statuses(
    cfg: Config,
    registry: Registry,
    upstream: UpstreamClient,
) -> Dict[str, Dict[str, Any]]:
    statuses = await asyncio.gather(
        *[
            _model_status(alias, entry, registry, upstream)
            for alias, entry in cfg.models.items()
        ]
    )
    return {status["alias"]: status for status in statuses}


@router.get("/v1/models")
async def list_models(request: Request, _=Depends(require_client_auth)) -> Dict[str, Any]:
    cfg: Config = request.app.state.cfg
    registry: Registry = request.app.state.registry
    upstream: UpstreamClient = request.app.state.upstream
    statuses = await _collect_model_statuses(cfg, registry, upstream)

    data = []
    for name, entry in cfg.models.items():
        status = statuses[name]
        if not status["available"]:
            continue
        item = {
            "id": entry.openai_model_name,
            "object": "model",
            "owned_by": "local",
        }
        if isinstance(entry, ProxyModel):
-            item["rolemesh"] = {"type": "proxy", "proxy_url": str(entry.proxy_url)}
+            item["rolemesh"] = {"type": "proxy", "proxy_url": str(entry.proxy_url), "available": True}
        else:
-            item["rolemesh"] = {"type": "discovered", "role": entry.role, "strategy": entry.strategy}
+            item["rolemesh"] = {
+                "type": "discovered",
+                "role": entry.role,
+                "strategy": entry.strategy,
+                "available": True,
+            }
        data.append(item)

    # Expose currently registered nodes (informational)
-    nodes = [n.model_dump(mode="json") for n in registry.list_nodes()]
+    nodes = [
+        n.model_dump(mode="json") | {"stale": registry.is_stale(n)}
+        for n in registry.list_nodes(include_stale=True)
+    ]
    unavailable = [status for status in statuses.values() if not status["available"]]

-    return {"object": "list", "data": data, "rolemesh": {"registered_nodes": nodes}}
+    return {
+        "object": "list",
+        "data": data,
+        "rolemesh": {
+            "registered_nodes": nodes,
+            "unavailable_models": unavailable,
+        },
+    }


@router.post("/v1/chat/completions")

@@ -81,7 +160,7 @@ async def chat_completions(request: Request, _=Depends(require_client_auth)) ->
    if isinstance(entry, ProxyModel):
        base_url = str(entry.proxy_url).rstrip("/")
    elif isinstance(entry, DiscoveredModel):
-        node = registry.pick_node_for_role(entry.role)
+        node = registry.pick_node_for_role(entry.role, strategy=entry.strategy)
        if not node:
            return _openai_error(
                f"No registered nodes available for role '{entry.role}'. "

@@ -125,12 +204,33 @@ async def health() -> Dict[str, str]:

@router.get("/ready")
async def ready(request: Request) -> JSONResponse:
-    """
-    Readiness checks for presence of config and (optionally) upstreams.
-    For now: verifies config loads and returns 200.
-    """
    cfg: Config = request.app.state.cfg
-    return JSONResponse(status_code=200, content={"status": "ready", "default_model": cfg.default_model})
+    registry: Registry = request.app.state.registry
+    upstream: UpstreamClient = request.app.state.upstream
+    statuses = await _collect_model_statuses(cfg, registry, upstream)
+    default_status = statuses.get(cfg.default_model)
+    available_aliases = [alias for alias, status in statuses.items() if status["available"]]
+
+    if default_status and default_status["available"]:
+        return JSONResponse(
+            status_code=200,
+            content={
+                "status": "ready",
+                "default_model": cfg.default_model,
+                "available_models": available_aliases,
+                "unavailable_models": [s for s in statuses.values() if not s["available"]],
+            },
+        )
+
+    return JSONResponse(
+        status_code=503,
+        content={
+            "status": "not_ready",
+            "default_model": cfg.default_model,
+            "available_models": available_aliases,
+            "unavailable_models": [s for s in statuses.values() if not s["available"]],
+        },
+    )


@router.post("/v1/nodes/register")

@@ -27,12 +27,19 @@ def _get_logger() -> logging.Logger:
def create_app(
    config_path: str | Path | None = None,
    registry_path: str | Path | None = None,
    registry_stale_after_s: float | None = None,
) -> FastAPI:
    cfg_path = config_path or os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml")
    cfg = load_config(cfg_path)

    resolved_registry_path = registry_path or os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json")
-    registry = Registry(persist_path=Path(resolved_registry_path))
+    resolved_registry_stale_after_s = registry_stale_after_s
+    if resolved_registry_stale_after_s is None:
+        resolved_registry_stale_after_s = float(os.environ.get("ROLE_MESH_NODE_STALE_AFTER_S", "30"))
+    registry = Registry(
+        persist_path=Path(resolved_registry_path),
+        stale_after_s=resolved_registry_stale_after_s,
+    )

    upstream = UpstreamClient(
        connect_timeout_s=float(os.environ.get("ROLE_MESH_CONNECT_TIMEOUT_S", "10")),

@@ -1,9 +1,10 @@
from __future__ import annotations

import json
import random
import time
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, HttpUrl
|
@ -44,10 +45,11 @@ class Registry:
|
|||
- TTL + health checks
|
||||
"""
|
||||
|
||||
def __init__(self, persist_path: Optional[Path] = None) -> None:
|
||||
def __init__(self, persist_path: Optional[Path] = None, stale_after_s: float = 30.0) -> None:
|
||||
self._nodes: Dict[str, RegisteredNode] = {}
|
||||
self._rr_counters: Dict[str, int] = {}
|
||||
self._persist_path = persist_path
|
||||
self._stale_after_s = stale_after_s
|
||||
|
||||
if self._persist_path:
|
||||
self._load()
|
||||
|
|

@@ -100,13 +102,29 @@ class Registry:
        self._save()
        return n

-    def list_nodes(self) -> List[RegisteredNode]:
-        return list(self._nodes.values())
-
-    def pick_node_for_role(self, role: str) -> Optional[RegisteredNode]:
-        candidates = [n for n in self._nodes.values() if role in n.roles]
+    def is_stale(self, node: RegisteredNode, now: Optional[float] = None) -> bool:
+        if self._stale_after_s <= 0:
+            return False
+        now = time.time() if now is None else now
+        return (now - node.last_seen) > self._stale_after_s
+
+    def list_nodes(self, *, include_stale: bool = True) -> List[RegisteredNode]:
+        nodes = list(self._nodes.values())
+        if include_stale:
+            return nodes
+        now = time.time()
+        return [node for node in nodes if not self.is_stale(node, now=now)]
+
+    def nodes_for_role(self, role: str, *, include_stale: bool = False) -> List[RegisteredNode]:
+        nodes = self.list_nodes(include_stale=include_stale)
+        return [node for node in nodes if role in node.roles]
+
+    def pick_node_for_role(self, role: str, strategy: str = "round_robin") -> Optional[RegisteredNode]:
+        candidates = self.nodes_for_role(role, include_stale=False)
        if not candidates:
            return None
+        if strategy == "random":
+            return random.choice(candidates)
        idx = self._rr_counters.get(role, 0) % len(candidates)
        self._rr_counters[role] = idx + 1
        self._save()

@@ -40,3 +40,4 @@ class NodeAgentConfig(BaseModel):
    llama_server_bin: str = "llama-server"
    llama_server_startup_timeout_s: float = 30.0
    llama_server_probe_interval_s: float = 0.5
    max_pending_requests_per_device: int = 2

@@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager, suppress
import time
-from typing import Any, Dict
+from typing import Any, Dict, Iterable

import httpx
from fastapi import FastAPI, Request

@@ -11,9 +11,10 @@ from fastapi.responses import JSONResponse, StreamingResponse

from rolemesh_gateway.upstream import UpstreamClient, UpstreamError  # reuse gateway client
from .adapters.cuda import CudaAdapter, ServerStartupError
-from .adapters.base import DeviceRef
+from .adapters.base import DeviceMetrics, DeviceRef
from .config import NodeAgentConfig
from .inventory import discover_gguf_models
from .scheduler import AdmissionError, DeviceQueue


def _error(message: str, code: str = "node_error", status_code: int = 500) -> JSONResponse:

@@ -23,6 +24,54 @@ def _error(message: str, code: str = "node_error", status_code: int = 500) -> JS
    )


def _merge_scheduler_metrics(
    metrics: Iterable[DeviceMetrics],
    queues: Dict[str, DeviceQueue],
) -> list[DeviceMetrics]:
    out: list[DeviceMetrics] = []
    for metric in metrics:
        snapshot = queues.get(metric.device.id)
        if snapshot is not None:
            queue = snapshot.snapshot()
            metric.queue_depth = queue.queue_depth
            metric.in_flight_jobs = queue.in_flight
        out.append(metric)
    return out


def _select_device(
    devices: Iterable[DeviceRef],
    metrics: Iterable[DeviceMetrics],
    *,
    model_id: str,
) -> DeviceRef | None:
    device_list = list(devices)
    if not device_list:
        return None

    metrics_by_id = {metric.device.id: metric for metric in metrics}

    # Reuse a device that already has the requested model loaded.
    for device in device_list:
        metric = metrics_by_id.get(device.id)
        if metric and metric.loaded_model_id == model_id:
            return device

    def score(device: DeviceRef) -> tuple[float, float, float, str]:
        metric = metrics_by_id.get(device.id)
        if metric is None:
            return (float("-inf"), 0.0, 0.0, device.id)

        free_mem = 0.0
        if metric.mem_total_gb is not None and metric.mem_used_gb is not None:
            free_mem = metric.mem_total_gb - metric.mem_used_gb
        queue_penalty = float(metric.queue_depth + metric.in_flight_jobs)
        util_penalty = float(metric.utilization_pct or 0.0)
        return (free_mem, -queue_penalty, -util_penalty, device.id)

    return max(device_list, key=score)


def create_app(cfg: NodeAgentConfig) -> FastAPI:
    http = httpx.AsyncClient(
        timeout=httpx.Timeout(connect=5.0, read=3600.0, write=30.0, pool=30.0)

@@ -57,6 +106,7 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:

    # Adapters
    app.state.cuda = cuda
    app.state.device_queues: Dict[str, DeviceQueue] = {}

    # State: role -> (device, model)
    # This is intentionally simple for the scaffold: pick first GPU and first matching model.

@@ -69,7 +119,9 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:
    @app.get("/v1/node/inventory")
    async def inventory() -> Dict[str, Any]:
        devices = await app.state.cuda.discover_devices()
-        metrics = [m.__dict__ | {"device": m.device.__dict__} for m in await app.state.cuda.get_metrics()]
+        raw_metrics = await app.state.cuda.get_metrics()
+        merged_metrics = _merge_scheduler_metrics(raw_metrics, app.state.device_queues)
+        metrics = [m.__dict__ | {"device": m.device.__dict__} for m in merged_metrics]
        models = [{"model_id": m.model_id, "path": str(m.path), "roles": m.roles} for m in cfg.models]
        discovered = discover_gguf_models(cfg.model_roots)
        return {

@@ -104,30 +156,48 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:
        devices = await app.state.cuda.discover_devices()
        if not devices:
            return _error("No CUDA GPUs discovered on this node.", code="no_device", status_code=503)
-        device = devices[0]
+        metrics = _merge_scheduler_metrics(await app.state.cuda.get_metrics(), app.state.device_queues)
+        device = _select_device(devices, metrics, model_id=model_entry.model_id)
+        if device is None:
+            return _error("No eligible CUDA GPUs discovered on this node.", code="no_device", status_code=503)
+        queue = app.state.device_queues.setdefault(
+            device.id,
+            DeviceQueue(max_pending=cfg.max_pending_requests_per_device),
+        )
+        try:
+            lease = await queue.acquire()
+        except AdmissionError:
+            return _error("Device queue is full.", code="queue_full", status_code=429)

        try:
            base_url = await app.state.cuda.ensure_server(
                device,
                model_path=str(model_entry.path),
                model_id=model_entry.model_id,
                server_args=model_entry.server_args,
            )
        except ServerStartupError as e:
            return _error(str(e), code="server_startup_error", status_code=503)

-        upstream = app.state.upstream
-        try:
-            if not stream:
-                out = await upstream.chat_completions(base_url, body)
-                return JSONResponse(status_code=200, content=out)
-            else:
-                async def gen():
-                    async for chunk in upstream.stream_chat_completions(base_url, body):
-                        yield chunk
-                return StreamingResponse(gen(), media_type="text/event-stream")
-        except UpstreamError as e:
-            return _error(str(e), code="upstream_error", status_code=502)
+        upstream = app.state.upstream
+        try:
+            if not stream:
+                out = await upstream.chat_completions(base_url, body)
+                return JSONResponse(status_code=200, content=out)
+            else:
+                async def gen():
+                    try:
+                        async for chunk in upstream.stream_chat_completions(base_url, body):
+                            yield chunk
+                    finally:
+                        await lease.release()
+                return StreamingResponse(gen(), media_type="text/event-stream")
+        except UpstreamError as e:
+            return _error(str(e), code="upstream_error", status_code=502)
+        finally:
+            await lease.release()

    return app

@@ -1,59 +1,66 @@
from __future__ import annotations

import asyncio
-import time
from dataclasses import dataclass
-from typing import Any, AsyncIterator, Dict, Optional
-
-import httpx


class AdmissionError(RuntimeError):
    pass


@dataclass
-class Job:
-    job_id: str
-    submitted_at: float
-    base_url: str
-    body: Dict[str, Any]
-    stream: bool
+class QueueSnapshot:
+    queue_depth: int
+    in_flight: int


class DeviceLease:
    def __init__(self, queue: "DeviceQueue") -> None:
        self._queue = queue
        self._released = False

    async def release(self) -> None:
        if self._released:
            return
        self._released = True
        await self._queue._release()

    async def __aenter__(self) -> "DeviceLease":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.release()


class DeviceQueue:
-    """Simple FIFO queue per device; one worker per device for latency predictability."""
+    """Serialize work per device and enforce a bounded number of pending requests."""

-    def __init__(self, *, client: httpx.AsyncClient, base_url: str) -> None:
-        self._client = client
-        self._base_url = base_url.rstrip("/")
-        self._q: asyncio.Queue[Job] = asyncio.Queue()
-        self._worker_task: Optional[asyncio.Task] = None
-        self.in_flight: int = 0
-        self.last_job_started_at: Optional[float] = None
+    def __init__(self, *, max_pending: int = 2) -> None:
+        self._max_pending = max_pending
+        self._lock = asyncio.Lock()
+        self._guard = asyncio.Lock()
+        self._queued = 0
+        self._in_flight = 0

-    def start(self) -> None:
-        if self._worker_task is None:
-            self._worker_task = asyncio.create_task(self._worker())
-
-    async def stop(self) -> None:
-        if self._worker_task:
-            self._worker_task.cancel()
-            self._worker_task = None
-
-    async def submit(self, job: Job) -> Job:
-        await self._q.put(job)
-        return job
-
-    @property
-    def depth(self) -> int:
-        return self._q.qsize()
-
-    async def _worker(self) -> None:
-        while True:
-            job = await self._q.get()
-            self.in_flight += 1
-            self.last_job_started_at = time.time()
-            try:
-                # The actual request execution is handled by the API layer (so it can stream).
-                # This worker exists mainly to serialize jobs per device and provide queue metrics.
-                await asyncio.sleep(0)  # placeholder
-            finally:
-                self.in_flight -= 1
-                self._q.task_done()
+    async def acquire(self) -> DeviceLease:
+        async with self._guard:
+            pending = self._queued + self._in_flight
+            if self._max_pending > 0 and pending >= self._max_pending:
+                raise AdmissionError("device queue is full")
+            self._queued += 1
+
+        await self._lock.acquire()
+
+        async with self._guard:
+            self._queued -= 1
+            self._in_flight += 1
+
+        return DeviceLease(self)
+
+    async def _release(self) -> None:
+        async with self._guard:
+            self._in_flight -= 1
+        self._lock.release()
+
+    def snapshot(self) -> QueueSnapshot:
+        return QueueSnapshot(queue_depth=self._queued, in_flight=self._in_flight)

@@ -30,6 +30,12 @@ def _write_gateway_config(path: Path) -> Path:
                "role": "reviewer",
                "strategy": "round_robin",
            },
            "random-reviewer": {
                "type": "discovered",
                "openai_model_name": "random-reviewer",
                "role": "reviewer",
                "strategy": "random",
            },
        },
    }
)

@@ -43,6 +49,7 @@ def _create_gateway_app(tmp_path: Path):
    return create_app(
        _write_gateway_config(tmp_path / "models.yaml"),
        registry_path=tmp_path / "registry.json",
        registry_stale_after_s=30.0,
    )

@@ -54,6 +61,10 @@ async def _request(app, method: str, path: str, **kwargs) -> httpx.Response:

def test_models_requires_client_auth(tmp_path):
    app = _create_gateway_app(tmp_path)

    async def fake_get_models(base_url):
        return {"object": "list", "data": []}

    app.state.upstream.get_models = fake_get_models
    unauthorized = asyncio.run(_request(app, "GET", "/v1/models"))
    assert unauthorized.status_code == 401

@@ -62,7 +73,8 @@ def test_models_requires_client_auth(tmp_path):
    )
    assert authorized.status_code == 200
    body = authorized.json()
-    assert {item["id"] for item in body["data"]} == {"writer", "reviewer"}
+    assert {item["id"] for item in body["data"]} == {"writer"}
+    assert {item["alias"] for item in body["rolemesh"]["unavailable_models"]} == {"reviewer", "random-reviewer"}
    asyncio.run(app.state.upstream.close())

@@ -156,3 +168,184 @@ def test_node_registration_and_heartbeat_update_registry_state(tmp_path):
    assert node["status"]["timestamp"] == 123.0
    assert node["status"]["metrics"] == [{"device": {"id": "gpu:0"}}]
    asyncio.run(app.state.upstream.close())


def test_models_filters_unhealthy_aliases_and_reports_them(tmp_path):
    app = _create_gateway_app(tmp_path)

    async def fake_get_models(base_url):
        if base_url == "http://127.0.0.1:8012":
            raise Exception("unexpected")
        return {"object": "list", "data": []}

    from rolemesh_gateway.upstream import UpstreamError

    async def wrapped_get_models(base_url):
        if base_url == "http://127.0.0.1:8012":
            raise UpstreamError("Upstream unreachable: boom")
        return await fake_get_models(base_url)

    app.state.upstream.get_models = wrapped_get_models

    response = asyncio.run(
        _request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
    )

    assert response.status_code == 200
    body = response.json()
    assert body["data"] == []
    assert {item["alias"] for item in body["rolemesh"]["unavailable_models"]} == {"writer", "reviewer", "random-reviewer"}
    asyncio.run(app.state.upstream.close())


def test_ready_returns_503_when_default_model_is_unavailable(tmp_path):
    app = _create_gateway_app(tmp_path)

    from rolemesh_gateway.upstream import UpstreamError

    async def fake_get_models(base_url):
        raise UpstreamError("Upstream unreachable: boom")

    app.state.upstream.get_models = fake_get_models

    response = asyncio.run(_request(app, "GET", "/ready"))

    assert response.status_code == 503
    body = response.json()
    assert body["status"] == "not_ready"
    assert body["default_model"] == "writer"
    asyncio.run(app.state.upstream.close())


def test_ready_returns_200_when_default_model_is_available(tmp_path):
    app = _create_gateway_app(tmp_path)

    async def fake_get_models(base_url):
        if base_url == "http://127.0.0.1:8012":
            return {"object": "list", "data": []}
        raise Exception("unexpected upstream probe")

    from rolemesh_gateway.upstream import UpstreamError

    async def wrapped_get_models(base_url):
        if base_url == "http://127.0.0.1:8012":
            return {"object": "list", "data": []}
        raise UpstreamError("no reviewer nodes")

    app.state.upstream.get_models = wrapped_get_models

    response = asyncio.run(_request(app, "GET", "/ready"))

    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "ready"
    assert body["default_model"] == "writer"
    assert body["available_models"] == ["writer"]
    asyncio.run(app.state.upstream.close())


def test_stale_discovered_nodes_are_not_advertised_as_available(tmp_path):
    app = _create_gateway_app(tmp_path)

    register = asyncio.run(
        _request(
            app,
            "POST",
            "/v1/nodes/register",
            headers={"x-rolemesh-node-key": "node-secret"},
            json={
                "node_id": "node-a",
                "base_url": "http://127.0.0.1:9001",
                "roles": ["reviewer"],
            },
        )
    )
    assert register.status_code == 200

    node = app.state.registry.list_nodes(include_stale=True)[0]
    node.last_seen = node.last_seen - 120

    async def fake_get_models(base_url):
        return {"object": "list", "data": []}

    app.state.upstream.get_models = fake_get_models

    models = asyncio.run(
        _request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
    )
    ready = asyncio.run(_request(app, "GET", "/ready"))

    models_body = models.json()
    assert {item["id"] for item in models_body["data"]} == {"writer"}
    assert models_body["rolemesh"]["registered_nodes"][0]["stale"] is True
    unavailable = {item["alias"]: item for item in models_body["rolemesh"]["unavailable_models"]}
    assert unavailable["reviewer"]["error"] == "no_fresh_registered_nodes"
    assert unavailable["random-reviewer"]["error"] == "no_fresh_registered_nodes"
    assert ready.status_code == 200
    asyncio.run(app.state.upstream.close())


def test_discovered_model_uses_random_strategy_when_configured(tmp_path):
    app = _create_gateway_app(tmp_path)
    calls = {}

    register_a = asyncio.run(
        _request(
            app,
            "POST",
            "/v1/nodes/register",
            headers={"x-rolemesh-node-key": "node-secret"},
            json={
                "node_id": "node-a",
                "base_url": "http://127.0.0.1:9001",
                "roles": ["reviewer"],
            },
        )
    )
    register_b = asyncio.run(
        _request(
            app,
            "POST",
            "/v1/nodes/register",
            headers={"x-rolemesh-node-key": "node-secret"},
            json={
                "node_id": "node-b",
                "base_url": "http://127.0.0.1:9002",
                "roles": ["reviewer"],
            },
        )
    )
    assert register_a.status_code == 200
    assert register_b.status_code == 200

    import random

    original_choice = random.choice
    random.choice = lambda candidates: candidates[-1]

    async def fake_chat(base_url, payload):
        calls["base_url"] = base_url
        return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}

    app.state.upstream.chat_completions = fake_chat
    app.state.upstream.get_models = lambda base_url: {"object": "list", "data": []}

    try:
        response = asyncio.run(
            _request(
                app,
                "POST",
                "/v1/chat/completions",
                headers={"x-api-key": "client-secret"},
                json={
                    "model": "random-reviewer",
                    "messages": [{"role": "user", "content": "hello"}],
                },
            )
        )
    finally:
        random.choice = original_choice

    assert response.status_code == 200
    assert calls["base_url"] == "http://127.0.0.1:9002"
    asyncio.run(app.state.upstream.close())

@@ -7,6 +7,8 @@ import httpx

from rolemesh_node_agent.adapters.base import DeviceMetrics, DeviceRef
from rolemesh_node_agent.config import ModelEntry, NodeAgentConfig
from rolemesh_node_agent.main import _merge_scheduler_metrics, _select_device
from rolemesh_node_agent.scheduler import AdmissionError, DeviceQueue


def _node_config(tmp_path: Path) -> NodeAgentConfig:

@@ -152,3 +154,174 @@ def test_chat_completions_returns_503_when_server_startup_fails(tmp_path):
    assert response.status_code == 503
    assert response.json()["error"]["code"] == "server_startup_error"
    asyncio.run(app.state.http.aclose())


def test_select_device_prefers_already_loaded_model():
    devices = [
        DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
        DeviceRef(kind="gpu", backend="cuda", id="gpu:1"),
    ]
    metrics = [
        DeviceMetrics(
            device=devices[0],
            loaded_model_id="other-model",
            mem_total_gb=24.0,
            mem_used_gb=4.0,
        ),
        DeviceMetrics(
            device=devices[1],
            loaded_model_id="target-model",
            mem_total_gb=24.0,
            mem_used_gb=20.0,
        ),
    ]

    picked = _select_device(devices, metrics, model_id="target-model")
    assert picked == devices[1]


def test_select_device_prefers_more_free_memory_and_lower_pressure():
    devices = [
        DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
        DeviceRef(kind="gpu", backend="cuda", id="gpu:1"),
    ]
    metrics = [
        DeviceMetrics(
            device=devices[0],
            mem_total_gb=24.0,
            mem_used_gb=18.0,
            queue_depth=0,
            in_flight_jobs=0,
            utilization_pct=5.0,
        ),
        DeviceMetrics(
            device=devices[1],
            mem_total_gb=24.0,
            mem_used_gb=4.0,
            queue_depth=0,
            in_flight_jobs=0,
            utilization_pct=10.0,
        ),
    ]

    picked = _select_device(devices, metrics, model_id="target-model")
    assert picked == devices[1]


def test_chat_completions_uses_selected_device_not_first_device(tmp_path):
    from rolemesh_node_agent.main import create_app

    cfg = _node_config(tmp_path)
    app = create_app(cfg)
    calls = {}

    devices = [
        DeviceRef(kind="gpu", backend="cuda", id="gpu:0"),
        DeviceRef(kind="gpu", backend="cuda", id="gpu:1"),
    ]

    async def fake_discover_devices():
        return devices

    async def fake_get_metrics():
        return [
            DeviceMetrics(
                device=devices[0],
                mem_total_gb=24.0,
                mem_used_gb=20.0,
            ),
            DeviceMetrics(
                device=devices[1],
                mem_total_gb=24.0,
                mem_used_gb=2.0,
            ),
        ]

    async def fake_ensure_server(device, *, model_path, model_id, server_args):
        calls["device"] = device.id
        return "http://127.0.0.1:9100"

    async def fake_chat(base_url, payload):
        return {"id": "node-cmpl", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}

    app.state.cuda.discover_devices = fake_discover_devices
    app.state.cuda.get_metrics = fake_get_metrics
    app.state.cuda.ensure_server = fake_ensure_server
    app.state.upstream.chat_completions = fake_chat

    response = asyncio.run(
        _request(
            app,
            "POST",
            "/v1/chat/completions",
            json={
                "model": "planner-gguf",
                "messages": [{"role": "user", "content": "hello"}],
            },
        )
    )

    assert response.status_code == 200
    assert calls["device"] == "gpu:1"
    asyncio.run(app.state.http.aclose())


def test_merge_scheduler_metrics_overlays_queue_state():
    device = DeviceRef(kind="gpu", backend="cuda", id="gpu:0")
    metric = DeviceMetrics(device=device)
    queue = DeviceQueue(max_pending=2)
    queue._queued = 1
    queue._in_flight = 1

    merged = _merge_scheduler_metrics([metric], {"gpu:0": queue})

    assert merged[0].queue_depth == 1
    assert merged[0].in_flight_jobs == 1


def test_device_queue_rejects_when_full():
    queue = DeviceQueue(max_pending=1)
    queue._in_flight = 1

    try:
        asyncio.run(queue.acquire())
    except AdmissionError as exc:
        assert "full" in str(exc)
    else:
        raise AssertionError("expected AdmissionError")


def test_chat_completions_returns_429_when_device_queue_is_full(tmp_path):
    from rolemesh_node_agent.main import create_app

    cfg = _node_config(tmp_path)
    cfg.max_pending_requests_per_device = 1
    app = create_app(cfg)

    async def fake_discover_devices():
        return [DeviceRef(kind="gpu", backend="cuda", id="gpu:0")]

    async def fake_get_metrics():
        return [DeviceMetrics(device=DeviceRef(kind="gpu", backend="cuda", id="gpu:0"))]

    app.state.cuda.discover_devices = fake_discover_devices
    app.state.cuda.get_metrics = fake_get_metrics
    saturated = DeviceQueue(max_pending=1)
    saturated._in_flight = 1
    app.state.device_queues["gpu:0"] = saturated

    response = asyncio.run(
        _request(
            app,
            "POST",
            "/v1/chat/completions",
            json={
                "model": "planner-gguf",
                "messages": [{"role": "user", "content": "hello"}],
            },
        )
    )

    assert response.status_code == 429
    assert response.json()["error"]["code"] == "queue_full"
    asyncio.run(app.state.http.aclose())

@@ -1,5 +1,7 @@
from __future__ import annotations

import random

from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry

@@ -44,3 +46,55 @@ def test_registry_persists_round_robin_and_heartbeat_state(tmp_path):
def test_registry_heartbeat_returns_none_for_unknown_node():
    registry = Registry()
    assert registry.heartbeat(NodeHeartbeat(node_id="missing")) is None


def test_registry_filters_stale_nodes_from_routing(tmp_path):
    persist_path = tmp_path / "registry.json"
    registry = Registry(persist_path=persist_path, stale_after_s=0.01)

    fresh = registry.register(
        NodeRegistration(
            node_id="fresh-node",
            base_url="http://127.0.0.1:9001",
            roles=["reviewer"],
        )
    )
    stale = registry.register(
        NodeRegistration(
            node_id="stale-node",
            base_url="http://127.0.0.1:9002",
            roles=["reviewer"],
        )
    )

    stale.last_seen = stale.last_seen - 60
    registry._save()

    assert registry.is_stale(stale) is True
    assert registry.is_stale(fresh) is False
    assert [node.node_id for node in registry.nodes_for_role("reviewer", include_stale=False)] == ["fresh-node"]
    assert registry.pick_node_for_role("reviewer").node_id == "fresh-node"


def test_registry_supports_random_selection(monkeypatch):
    registry = Registry()
    registry.register(
        NodeRegistration(
            node_id="node-a",
            base_url="http://127.0.0.1:9001",
            roles=["reviewer"],
        )
    )
    registry.register(
        NodeRegistration(
            node_id="node-b",
            base_url="http://127.0.0.1:9002",
            roles=["reviewer"],
        )
    )

    monkeypatch.setattr(random, "choice", lambda candidates: candidates[-1])

    picked = registry.pick_node_for_role("reviewer", strategy="random")
    assert picked is not None
    assert picked.node_id == "node-b"