Refactored to separate role and model_id references.

parent 10926b5558
commit 00852c6c0f

README.md | 40
@@ -169,6 +169,34 @@ curl -sS -X POST http://127.0.0.1:8000/v1/chat/completions \

 If you prefer the provided example file, copy `configs/models.example.yaml` and adjust the `proxy_url` values.

+## Local Overrides
+
+Keep tracked repo config generic and put machine-specific values in a separate local override file.
+
+Examples of machine-specific values:
+
+- model weight paths
+- local `llama-server` binary path
+- LAN IPs and ports
+- local API keys
+
+Supported launch patterns:
+
+```bash
+rolemesh-gateway --config configs/models.example.yaml --config-override configs/models.local.yaml
+```
+
+```bash
+rolemesh-node-agent --config configs/node_agent.example.yaml --config-override configs/node_agent.local.yaml
+```
+
+You can also use:
+
+- `ROLE_MESH_CONFIG_OVERRIDE` for the gateway
+
+Tracked examples should use placeholders such as `/path/to/model-weights`, while your local override file contains the real values for your machine.

 ## Worked Deployment Example

 For a concrete multi-machine example, including:
@@ -240,7 +268,17 @@ If you want machines to host backends and “register” them dynamically, run a

 (or just call the registration endpoint from your own tooling).

 - Gateway endpoint: `POST /v1/nodes/register`
-- Node payload describes which **roles** it serves and the base URL to reach its OpenAI-compatible backend.
+- Node payload describes which concrete upstream **model IDs** it serves, which **roles** each model can satisfy,
+  and the base URL to reach its OpenAI-compatible backend.
+
+In discovered mode, RoleMesh now treats these as separate concepts:
+
+- client alias: what the caller sends, such as `tutor`
+- role: the capability used for routing, such as `tutor` or `critic`
+- upstream model ID: the concrete model name served by the selected node, such as `qwen3-8b`
+
+That means a node can advertise one served model for multiple roles, and the gateway can rewrite the forwarded
+request from a stable alias to the selected upstream model ID.

 See: `docs/DEPLOYMENT.md` and `docs/CONFIG.md`.
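The alias / role / upstream-model split described above can be sketched in a few lines. This is an illustrative simplification, not the actual RoleMesh API: `ServedModel` here is a plain dataclass and `resolve_upstream` is a hypothetical helper.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ServedModel:
    model_id: str  # concrete upstream model name, e.g. "qwen3-8b"
    roles: List[str] = field(default_factory=list)  # roles this model can satisfy


def resolve_upstream(alias_to_role: Dict[str, str], served: List[ServedModel], alias: str) -> Optional[str]:
    """Map a client alias to a role, then to the first served model that can satisfy it."""
    role = alias_to_role.get(alias)
    if role is None:
        return None
    for model in served:
        if role in model.roles:
            return model.model_id
    return None


# One served model advertised for two roles, as described above.
served = [ServedModel("qwen3-8b", roles=["tutor", "mentor"])]
print(resolve_upstream({"tutor": "tutor", "mentor": "mentor"}, served, "tutor"))  # -> qwen3-8b
```

The gateway would then forward the request with `model` rewritten from the stable alias to the resolved upstream model ID.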
@@ -0,0 +1,15 @@
gateway:
  host: 192.168.1.100
  port: 8080

auth:
  client_api_keys:
    - "replace-with-local-client-key"
  node_api_keys:
    - "replace-with-local-node-key"

models:
  planner:
    proxy_url: http://192.168.1.101:8011
  writer:
    proxy_url: http://192.168.1.102:8012
@@ -9,17 +9,17 @@ dispatcher_node_key: "change-me-node-key-1"

 dispatcher_roles: ["planner", "coder"]
 heartbeat_interval_sec: 5

-llama_server_bin: "llama-server"
+llama_server_bin: "/path/to/llama-server"
 llama_server_startup_timeout_s: 30
 llama_server_probe_interval_s: 0.5

 model_roots:
-  - "/models"
+  - "/path/to/model-weights"

 models:
   - model_id: "planner-gguf"
     # path is the exact GGUF file that this model_id will load when requested
-    path: "/models/SomePlannerModel.Q5_K_M.gguf"
+    path: "/path/to/model-weights/SomePlannerModel.Q5_K_M.gguf"
     roles: ["planner"]
     default_ctx: 8192
     # Common llama-server options can be configured as structured fields:
@@ -0,0 +1,16 @@
node_id: "node-1-local"
listen_host: "192.168.1.101"
listen_port: 8091

dispatcher_base_url: "http://192.168.1.100:8080"
dispatcher_node_key: "replace-with-local-node-key"

llama_server_bin: "/path/to/local/llama-server"

model_roots:
  - "/path/to/model-weights"

models:
  - model_id: "qwen3-8b"
    path: "/path/to/model-weights/Qwen3-8B-Q5_K_M.gguf"
    roles: ["tutor", "mentor"]
@@ -93,7 +93,7 @@ models:

 In that setup:

 - gateway alias -> discovered role
-- discovered role -> registered node
+- discovered role -> registered node + concrete upstream model ID
 - node-agent `path` -> actual weight file on disk

 ## Proxy models
@@ -118,7 +118,7 @@ Notes:

 ## Discovered models

-Route to a dynamically registered node that claims the role:
+Route to a dynamically registered model instance that can satisfy the role:

 ```yaml
 models:
@@ -141,11 +141,31 @@ Nodes register to `POST /v1/nodes/register`:

 {
   "node_id": "gpu-box-1",
   "base_url": "http://10.0.0.12:8014",
-  "roles": ["reviewer", "planner"],
+  "served_models": [
+    {
+      "model_id": "qwen3-8b",
+      "roles": ["reviewer", "planner"],
+      "meta": {"family": "Qwen3", "quant": "Q5_K_M"}
+    },
+    {
+      "model_id": "qwen2.5-coder-14b",
+      "roles": ["coder"],
+      "meta": {"family": "Qwen2.5-Coder", "quant": "Q5_K_M"}
+    }
+  ],
   "meta": {"gpu": "Tesla P40", "notes": "llama-server on GPU0"}
 }
 ```

+`served_models` is now the preferred registration schema.
+
+- `model_id`: concrete model name the upstream node expects in the forwarded OpenAI request
+- `roles`: workflow roles that this model can satisfy
+- `meta`: optional operator-facing metadata
+
+Legacy flat `roles` registration is still accepted for compatibility, but it is treated as a fallback where
+`model_id == role`.

 If `auth.client_api_keys` is set (non-empty), callers of `/v1/models` and `/v1/chat/completions` must provide an API key.

 If `auth.node_api_keys` is set (non-empty), node agents calling `/v1/nodes/register` and `/v1/nodes/heartbeat` must provide a node key.
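The legacy fallback described above can be illustrated with a small helper. The name `derive_served_models` and the dict-based payload are illustrative; the real logic lives inside the gateway registry.

```python
from typing import Any, Dict, List


def derive_served_models(payload: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Prefer an explicit `served_models` list; otherwise treat each legacy
    flat role as a served model whose model_id equals the role name."""
    if payload.get("served_models"):
        return payload["served_models"]
    return [{"model_id": role, "roles": [role]} for role in payload.get("roles", [])]


# A legacy registration with only flat roles:
legacy = {"node_id": "gpu-box-1", "roles": ["reviewer", "planner"]}
print(derive_served_models(legacy))
```

With the new schema present, the helper simply passes `served_models` through unchanged.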
@@ -193,3 +213,35 @@ models:

       temperature: 0.6
       max_tokens: 256
 ```

+## Base config plus local override
+
+Recommended pattern:
+
+- keep tracked repo config generic
+- keep machine-specific values in a separate local YAML
+- merge the local YAML at launch
+
+Gateway:
+
+```bash
+rolemesh-gateway --config configs/models.example.yaml --config-override configs/models.local.yaml
+```
+
+Node agent:
+
+```bash
+rolemesh-node-agent --config configs/node_agent.example.yaml --config-override configs/node_agent.local.yaml
+```
+
+The merge is recursive for mappings:
+
+- nested dictionaries are merged
+- lists and scalar values are replaced by the override file
+
+This is useful for separating:
+
+- real model weight paths
+- local host IPs
+- local API keys
+- local `llama-server` paths
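Those merge rules can be demonstrated with a minimal recursive merge. This is a sketch of the described behavior with illustrative keys, mirroring the `_deep_merge` helper that appears elsewhere in this commit:

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Merge mappings recursively; lists and scalars are replaced wholesale."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


base = {
    "auth": {"client_api_keys": ["base-key"], "node_api_keys": ["base-node-key"]},
    "model_roots": ["/path/to/model-weights"],
}
override = {
    "auth": {"client_api_keys": ["local-key"]},  # nested dict: merged key by key
    "model_roots": ["/home/me/models"],          # list: replaced wholesale
}
merged = deep_merge(base, override)
```

Note that `node_api_keys` survives from the base file because `auth` is a mapping, while `model_roots` is replaced entirely because it is a list.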
@@ -38,6 +38,12 @@ models:

 ROLE_MESH_CONFIG=configs/models.yaml uvicorn rolemesh_gateway.main:app --host 127.0.0.1 --port 8000
 ```

+If you want to keep local IPs and keys out of the tracked file, use an override:
+
+```bash
+rolemesh-gateway --config configs/models.example.yaml --config-override configs/models.local.yaml
+```

 ### 4. Smoke test

 ```bash
@@ -129,7 +135,7 @@ This scaffold supports two patterns.

 - Update `proxy_url` entries to those LAN URLs, **or** use discovery:
   - Set model to `type: discovered` with `role: writer`, etc.
   - Choose `strategy: round_robin` or `strategy: random` per discovered alias
-  - Each host registers itself with the gateway.
+  - Each host registers the concrete served models it can expose for those roles.

 ### Minimal registration call
@@ -137,9 +143,21 @@ This scaffold supports two patterns.

 curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \
   -H 'Content-Type: application/json' \
   -H 'X-RoleMesh-Node-Key: <node-key>' \
-  -d '{"node_id":"gpu-box-1","base_url":"http://10.0.0.12:8012","roles":["writer"]}'
+  -d '{
+    "node_id":"gpu-box-1",
+    "base_url":"http://10.0.0.12:8012",
+    "served_models":[
+      {"model_id":"qwen3-8b","roles":["writer","mentor"]}
+    ]
+  }'
 ```

+In discovered mode, the gateway routes:
+
+- client alias -> role
+- role -> eligible registered served model
+- forwarded request `model` -> selected upstream `model_id`

 ### Hardening checklist (recommended)

 - Bind gateway to localhost by default, and explicitly expose it when needed
@@ -195,50 +195,20 @@ On `192.168.1.102`:

 PYTHONPATH=src python -m rolemesh_node_agent.cli --config critic-node.yaml
 ```

-## 6. Register each node once
+## 6. Registration behavior

-The current node agent sends heartbeats automatically, but registration is still a one-time explicit step.
+The current node agent now registers itself automatically on startup when `dispatcher_base_url` is set.
+It also keeps heartbeating after registration.

-Register planner:
+For inspection, each node exposes:

 ```bash
-curl -sS -X POST http://192.168.1.100:8080/v1/nodes/register \
-  -H 'Content-Type: application/json' \
-  -H 'X-RoleMesh-Node-Key: change-me-node-key' \
-  -d '{
-    "node_id": "gpu101-planner",
-    "base_url": "http://192.168.1.101:8091",
-    "roles": ["planner"]
-  }'
+curl -sS http://192.168.1.101:8091/v1/node/registration
 ```

-Register writer:
+That endpoint returns the exact `served_models` payload the node agent will send to the gateway.

-```bash
-curl -sS -X POST http://192.168.1.100:8080/v1/nodes/register \
-  -H 'Content-Type: application/json' \
-  -H 'X-RoleMesh-Node-Key: change-me-node-key' \
-  -d '{
-    "node_id": "gpu101-writer",
-    "base_url": "http://192.168.1.101:8092",
-    "roles": ["writer"]
-  }'
-```
-
-Register critic:
-
-```bash
-curl -sS -X POST http://192.168.1.100:8080/v1/nodes/register \
-  -H 'Content-Type: application/json' \
-  -H 'X-RoleMesh-Node-Key: change-me-node-key' \
-  -d '{
-    "node_id": "gpu102-critic",
-    "base_url": "http://192.168.1.102:8091",
-    "roles": ["critic"]
-  }'
-```
-
-After that, the heartbeat loop on each node agent keeps the registry entry fresh.
+Manual `POST /v1/nodes/register` is still supported, but it is now mainly useful for custom tooling or debugging.

 ## 7. Verify the topology
@@ -267,6 +237,12 @@ That endpoint shows:

 - device metrics
 - queue depth and in-flight work

+You can also inspect the registration payload directly:
+
+```bash
+curl -sS http://192.168.1.101:8091/v1/node/registration
+```

 ## 8. Send requests by role

 Planner request through the gateway:
@@ -90,6 +90,21 @@ pip install -e .

 rolemesh-node-agent --config configs/node_agent.example.yaml
 ```

+To keep local paths and host-specific values out of the tracked config, use an override file:
+
+```bash
+rolemesh-node-agent \
+  --config configs/node_agent.example.yaml \
+  --config-override configs/node_agent.local.yaml
+```
+
+Tracked base config should contain placeholders such as:
+
+- `/path/to/model-weights`
+- `/path/to/llama-server`
+
+Your local override should contain the real machine-specific values.

 ### Startup timing guards

 Two config knobs control how long the node agent waits for a managed `llama-server` to become ready:
@@ -108,11 +123,19 @@ The readiness probe checks the managed server's local `GET /health` and `GET /v1

 ## Registering

-If `dispatcher_base_url` is set in the node-agent config, the node agent will periodically call:
+If `dispatcher_base_url` is set in the node-agent config, the node agent will:

+- register itself on startup via `POST <dispatcher>/v1/nodes/register`
 - `POST <dispatcher>/v1/nodes/heartbeat` with latest device metrics.

-Registration is currently manual from the node side (or can be added as a startup step).
+The registration payload is derived from local `models[]` and advertises `served_models`, where each local
+`model_id` lists the roles it can satisfy.
+
+For operator inspection, the node agent also exposes:
+
+- `GET /v1/node/registration`
+
+That endpoint returns the exact registration payload the node would send to the dispatcher.

 ### Binding
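A startup guard like the one those timing knobs control can be sketched as a simple poll loop. This is an illustrative helper, not the node agent's actual probe (which also checks the managed server's `GET /v1/models`):

```python
import time
import urllib.request


def wait_until_ready(base_url: str, timeout_s: float = 30.0, interval_s: float = 0.5) -> bool:
    """Poll the managed server's /health endpoint until it answers or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/health", timeout=2):
                return True
        except OSError:
            time.sleep(interval_s)
    return False
```

The two config knobs map onto `timeout_s` (`llama_server_startup_timeout_s`) and `interval_s` (`llama_server_probe_interval_s`).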
@@ -52,8 +52,8 @@ async def _model_status(

         error = "no_fresh_registered_nodes"
         return {"alias": alias, "available": False, "role": entry.role, "error": error}

-    for node in matching_nodes:
-        base_url = str(node.base_url).rstrip("/")
+    for candidate in registry.candidates_for_role(entry.role, include_stale=False):
+        base_url = str(candidate.base_url).rstrip("/")
         try:
             await upstream.get_models(base_url)
             return {
@@ -61,7 +61,8 @@ async def _model_status(

                 "available": True,
                 "role": entry.role,
                 "base_url": base_url,
-                "node_id": node.node_id,
+                "node_id": candidate.node_id,
+                "upstream_model_id": candidate.model_id,
             }
         except UpstreamError:
             continue
@@ -160,15 +161,16 @@ async def chat_completions(request: Request, _=Depends(require_client_auth)) ->

     if isinstance(entry, ProxyModel):
         base_url = str(entry.proxy_url).rstrip("/")
     elif isinstance(entry, DiscoveredModel):
-        node = registry.pick_node_for_role(entry.role, strategy=entry.strategy)
-        if not node:
+        candidate = registry.pick_candidate_for_role(entry.role, strategy=entry.strategy)
+        if not candidate:
             return _openai_error(
                 f"No registered nodes available for role '{entry.role}'. "
                 f"Register a node via POST /v1/nodes/register, or use proxy mode.",
                 code="no_upstream",
                 status_code=503,
             )
-        base_url = str(node.base_url).rstrip("/")
+        base_url = str(candidate.base_url).rstrip("/")
+        body["model"] = candidate.model_id
     else:
         return _openai_error("Invalid model configuration.", code="bad_config", status_code=500)
@@ -11,11 +11,12 @@ from rolemesh_gateway.main import create_app

 def main() -> None:
     p = argparse.ArgumentParser(description="RoleMesh Gateway")
     p.add_argument("--config", required=True, help="Path to gateway YAML config.")
+    p.add_argument("--config-override", help="Optional local override YAML merged over the base config.")
     p.add_argument("--host", default="127.0.0.1")
     p.add_argument("--port", type=int, default=8080)
     args = p.parse_args()

-    app = create_app(Path(args.config))
+    app = create_app(Path(args.config), config_override_path=Path(args.config_override) if args.config_override else None)
     uvicorn.run(app, host=args.host, port=args.port)
@@ -54,9 +54,28 @@ class Config(BaseModel):

     models: Dict[str, ModelEntry] = Field(default_factory=dict)


-def load_config(path: str | Path) -> Config:
-    path = Path(path)
+def _read_yaml_mapping(path: Path) -> dict:
+    data = yaml.safe_load(path.read_text())
+    if data is None:
+        return {}
+    if not isinstance(data, dict):
+        raise ValueError(f"Invalid config file at {path}: expected a YAML mapping at top level.")
+    return data
+
+
+def _deep_merge(base: dict, override: dict) -> dict:
+    merged = dict(base)
+    for key, value in override.items():
+        if isinstance(value, dict) and isinstance(merged.get(key), dict):
+            merged[key] = _deep_merge(merged[key], value)
+        else:
+            merged[key] = value
+    return merged
+
+
+def load_config(path: str | Path, override_path: str | Path | None = None) -> Config:
+    path = Path(path)
+    data = _read_yaml_mapping(path)
+    if override_path is not None:
+        data = _deep_merge(data, _read_yaml_mapping(Path(override_path)))
+    return Config.model_validate(data)
@@ -26,11 +26,13 @@ def _get_logger() -> logging.Logger:

 def create_app(
     config_path: str | Path | None = None,
+    config_override_path: str | Path | None = None,
     registry_path: str | Path | None = None,
     registry_stale_after_s: float | None = None,
 ) -> FastAPI:
     cfg_path = config_path or os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml")
-    cfg = load_config(cfg_path)
+    cfg_override_path = config_override_path or os.environ.get("ROLE_MESH_CONFIG_OVERRIDE")
+    cfg = load_config(cfg_path, override_path=cfg_override_path)

     resolved_registry_path = registry_path or os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json")
     resolved_registry_stale_after_s = registry_stale_after_s
@@ -9,10 +9,17 @@ from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Field, HttpUrl


+class ServedModel(BaseModel):
+    model_id: str
+    roles: List[str] = Field(default_factory=list)
+    meta: Dict[str, str] = Field(default_factory=dict)
+
+
 class NodeRegistration(BaseModel):
     node_id: str
     base_url: HttpUrl  # OpenAI-compatible upstream base, e.g. http://10.0.0.12:8011
-    roles: List[str]  # roles served by this node, e.g. ["planner", "writer"]
+    served_models: List[ServedModel] = Field(default_factory=list)
+    roles: List[str] = Field(default_factory=list)  # legacy compatibility input
     meta: Dict[str, str] = Field(default_factory=dict)
@@ -28,13 +35,22 @@ class NodeHeartbeat(BaseModel):

 class RegisteredNode(BaseModel):
     node_id: str
     base_url: HttpUrl
-    roles: List[str]
+    served_models: List[ServedModel] = Field(default_factory=list)
+    roles: List[str] = Field(default_factory=list)
     meta: Dict[str, str] = Field(default_factory=dict)
     status: Dict[str, Any] = Field(default_factory=dict)
     registered_at: float = Field(default_factory=lambda: time.time())
     last_seen: float = Field(default_factory=lambda: time.time())


+class RoleCandidate(BaseModel):
+    node_id: str
+    base_url: HttpUrl
+    role: str
+    model_id: str
+    model_meta: Dict[str, str] = Field(default_factory=dict)
+
+
 class Registry:
     """
     Minimal in-memory registry with optional JSON persistence.
@@ -60,7 +76,15 @@ class Registry:

         try:
             raw = json.loads(self._persist_path.read_text())
             for node_id, node_data in raw.get("nodes", {}).items():
-                self._nodes[node_id] = RegisteredNode.model_validate(node_data)
+                node = RegisteredNode.model_validate(node_data)
+                if not node.served_models and node.roles:
+                    node.served_models = [
+                        ServedModel(model_id=role, roles=[role])
+                        for role in node.roles
+                    ]
+                if not node.roles and node.served_models:
+                    node.roles = self._derive_roles(node.served_models, [])
+                self._nodes[node_id] = node
             self._rr_counters = dict(raw.get("rr_counters", {}))
         except Exception:
             # If persistence is corrupted, start empty (do not crash the gateway).
@@ -77,11 +101,31 @@ class Registry:

         self._persist_path.parent.mkdir(parents=True, exist_ok=True)
         self._persist_path.write_text(json.dumps(payload, indent=2, sort_keys=True))

+    @staticmethod
+    def _derive_served_models(reg: NodeRegistration) -> List[ServedModel]:
+        if reg.served_models:
+            return reg.served_models
+        return [ServedModel(model_id=role, roles=[role]) for role in reg.roles]
+
+    @staticmethod
+    def _derive_roles(served_models: List[ServedModel], legacy_roles: List[str]) -> List[str]:
+        roles: list[str] = []
+        for served_model in served_models:
+            for role in served_model.roles:
+                if role not in roles:
+                    roles.append(role)
+        for role in legacy_roles:
+            if role not in roles:
+                roles.append(role)
+        return roles
+
     def register(self, reg: NodeRegistration) -> RegisteredNode:
+        served_models = self._derive_served_models(reg)
         node = RegisteredNode(
             node_id=reg.node_id,
             base_url=reg.base_url,
-            roles=reg.roles,
+            served_models=served_models,
+            roles=self._derive_roles(served_models, reg.roles),
             meta=reg.meta,
             last_seen=time.time(),
         )
@@ -119,6 +163,35 @@ class Registry:

         nodes = self.list_nodes(include_stale=include_stale)
         return [node for node in nodes if role in node.roles]

+    def candidates_for_role(self, role: str, *, include_stale: bool = False) -> List[RoleCandidate]:
+        nodes = self.list_nodes(include_stale=include_stale)
+        candidates: list[RoleCandidate] = []
+        for node in nodes:
+            for served_model in node.served_models:
+                if role not in served_model.roles:
+                    continue
+                candidates.append(
+                    RoleCandidate(
+                        node_id=node.node_id,
+                        base_url=node.base_url,
+                        role=role,
+                        model_id=served_model.model_id,
+                        model_meta=served_model.meta,
+                    )
+                )
+        return candidates
+
+    def pick_candidate_for_role(self, role: str, strategy: str = "round_robin") -> Optional[RoleCandidate]:
+        candidates = self.candidates_for_role(role, include_stale=False)
+        if not candidates:
+            return None
+        if strategy == "random":
+            return random.choice(candidates)
+        idx = self._rr_counters.get(role, 0) % len(candidates)
+        self._rr_counters[role] = idx + 1
+        self._save()
+        return candidates[idx]
+
     def pick_node_for_role(self, role: str, strategy: str = "round_robin") -> Optional[RegisteredNode]:
         candidates = self.nodes_for_role(role, include_stale=False)
         if not candidates:
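The round-robin arm of `pick_candidate_for_role` above reduces to a per-role counter taken modulo the candidate count. A simplified standalone sketch, without the persistence and pydantic models of the real `Registry`:

```python
from typing import Dict, List, Optional


def pick_round_robin(counters: Dict[str, int], role: str, candidates: List[str]) -> Optional[str]:
    """Cycle through the candidates for a role, advancing a per-role counter."""
    if not candidates:
        return None
    idx = counters.get(role, 0) % len(candidates)
    counters[role] = idx + 1
    return candidates[idx]


counters: Dict[str, int] = {}
cands = ["node-a/qwen3-8b", "node-b/qwen3-8b"]
picks = [pick_round_robin(counters, "reviewer", cands) for _ in range(3)]
# cycles node-a, node-b, node-a
```

Because the counter is keyed by role, two aliases resolving to different roles rotate independently.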
@@ -1,22 +1,20 @@

 from __future__ import annotations

 import argparse
-from pathlib import Path

 import uvicorn
-import yaml

-from .config import NodeAgentConfig
+from .config import load_config
 from .main import create_app


 def main() -> None:
     p = argparse.ArgumentParser(description="RoleMesh Node Agent")
     p.add_argument("--config", required=True, help="Path to node-agent YAML config.")
+    p.add_argument("--config-override", help="Optional local override YAML merged over the base config.")
     args = p.parse_args()

-    cfg_path = Path(args.config)
-    cfg = NodeAgentConfig.model_validate(yaml.safe_load(cfg_path.read_text()))
+    cfg = load_config(args.config, override_path=args.config_override)
     app = create_app(cfg)
     uvicorn.run(app, host=cfg.listen_host, port=cfg.listen_port)
@@ -3,6 +3,7 @@ from __future__ import annotations

 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional

+import yaml
 from pydantic import BaseModel, Field, HttpUrl
@@ -51,3 +52,30 @@ class NodeAgentConfig(BaseModel):

     llama_server_startup_timeout_s: float = 30.0
     llama_server_probe_interval_s: float = 0.5
     max_pending_requests_per_device: int = 2


+def _read_yaml_mapping(path: Path) -> dict:
+    data = yaml.safe_load(path.read_text())
+    if data is None:
+        return {}
+    if not isinstance(data, dict):
+        raise ValueError(f"Invalid config file at {path}: expected a YAML mapping at top level.")
+    return data
+
+
+def _deep_merge(base: dict, override: dict) -> dict:
+    merged = dict(base)
+    for key, value in override.items():
+        if isinstance(value, dict) and isinstance(merged.get(key), dict):
+            merged[key] = _deep_merge(merged[key], value)
+        else:
+            merged[key] = value
+    return merged
+
+
+def load_config(path: str | Path, override_path: str | Path | None = None) -> NodeAgentConfig:
+    path = Path(path)
+    data = _read_yaml_mapping(path)
+    if override_path is not None:
+        data = _deep_merge(data, _read_yaml_mapping(Path(override_path)))
+    return NodeAgentConfig.model_validate(data)
@@ -39,6 +39,46 @@ def _merge_scheduler_metrics(

     return out


+def _effective_served_models(cfg: NodeAgentConfig) -> list[dict[str, Any]]:
+    allowed_roles = set(cfg.dispatcher_roles)
+    served_models: list[dict[str, Any]] = []
+    for model in cfg.models:
+        roles = list(model.roles)
+        if allowed_roles:
+            roles = [role for role in roles if role in allowed_roles]
+        if not roles:
+            continue
+        meta: dict[str, str] = {}
+        if model.alias:
+            meta["alias"] = model.alias
+        served_models.append(
+            {
+                "model_id": model.model_id,
+                "roles": roles,
+                "meta": meta,
+            }
+        )
+    return served_models
+
+
+def _registration_payload(cfg: NodeAgentConfig) -> dict[str, Any]:
+    served_models = _effective_served_models(cfg)
+    roles: list[str] = []
+    for served_model in served_models:
+        for role in served_model["roles"]:
+            if role not in roles:
+                roles.append(role)
+    return {
+        "node_id": cfg.node_id,
+        "base_url": f"http://{cfg.listen_host}:{cfg.listen_port}",
+        "served_models": served_models,
+        "roles": roles,
+        "meta": {
+            "node_agent": "rolemesh",
+        },
+    }
+
+
 def _select_device(
     devices: Iterable[DeviceRef],
     metrics: Iterable[DeviceMetrics],
@@ -86,7 +126,8 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:

     @asynccontextmanager
     async def lifespan(app: FastAPI):
         heartbeat_task: asyncio.Task[None] | None = None
-        if cfg.dispatcher_base_url and cfg.dispatcher_roles:
+        if cfg.dispatcher_base_url:
+            await _register_with_dispatcher(app)
             heartbeat_task = asyncio.create_task(_heartbeat_loop(app))
         try:
             yield
@@ -116,6 +157,10 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI:

     async def health() -> Dict[str, Any]:
         return {"status": "ok", "node_id": cfg.node_id, "time": time.time()}

+    @app.get("/v1/node/registration")
+    async def registration_payload() -> Dict[str, Any]:
+        return _registration_payload(cfg)
+
     @app.get("/v1/node/inventory")
     async def inventory() -> Dict[str, Any]:
         devices = await app.state.cuda.discover_devices()
@@ -221,3 +266,18 @@ async def _heartbeat_loop(app: FastAPI) -> None:

         except Exception:
             pass
         await asyncio.sleep(cfg.heartbeat_interval_sec)


+async def _register_with_dispatcher(app: FastAPI) -> None:
+    cfg: NodeAgentConfig = app.state.cfg
+    http: httpx.AsyncClient = app.state.http
+    if not cfg.dispatcher_base_url:
+        return
+    try:
+        url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/register"
+        headers = {}
+        if cfg.dispatcher_node_key:
+            headers["X-RoleMesh-Node-Key"] = cfg.dispatcher_node_key
+        await http.post(url, json=_registration_payload(cfg), headers=headers)
+    except Exception:
+        pass
@@ -0,0 +1,103 @@
from __future__ import annotations

from pathlib import Path

import yaml

from rolemesh_gateway.config import load_config as load_gateway_config
from rolemesh_node_agent.config import load_config as load_node_config


def test_gateway_config_override_merges_nested_mappings(tmp_path: Path):
    base = tmp_path / "models.yaml"
    override = tmp_path / "models.local.yaml"
    base.write_text(
        yaml.safe_dump(
            {
                "version": 1,
                "default_model": "writer",
                "auth": {
                    "client_api_keys": ["base-client-key"],
                    "node_api_keys": ["base-node-key"],
                },
                "models": {
                    "writer": {
                        "type": "proxy",
                        "openai_model_name": "writer",
                        "proxy_url": "http://127.0.0.1:8012",
                        "defaults": {"temperature": 0.6, "max_tokens": 256},
                    }
                },
            }
        )
    )
    override.write_text(
        yaml.safe_dump(
            {
                "auth": {
                    "client_api_keys": ["local-client-key"],
                },
                "models": {
                    "writer": {
                        "proxy_url": "http://192.168.1.50:8012",
                        "defaults": {"temperature": 0.2},
                    }
                },
            }
        )
    )

    cfg = load_gateway_config(base, override_path=override)

    assert cfg.auth.client_api_keys == ["local-client-key"]
    assert cfg.auth.node_api_keys == ["base-node-key"]
    assert str(cfg.models["writer"].proxy_url) == "http://192.168.1.50:8012/"
    assert cfg.models["writer"].defaults == {"temperature": 0.2, "max_tokens": 256}


def test_node_agent_config_override_replaces_local_machine_paths(tmp_path: Path):
    base = tmp_path / "node_agent.yaml"
    override = tmp_path / "node_agent.local.yaml"
    base.write_text(
        yaml.safe_dump(
            {
                "node_id": "node-generic",
                "listen_host": "0.0.0.0",
                "listen_port": 8091,
                "dispatcher_base_url": "http://10.0.0.10:8080",
                "llama_server_bin": "/path/to/llama-server",
                "model_roots": ["/path/to/model-weights"],
                "models": [
                    {
                        "model_id": "qwen3-8b",
                        "path": "/path/to/model-weights/Qwen3-8B-Q5_K_M.gguf",
                        "roles": ["tutor"],
                    }
                ],
            }
        )
    )
    override.write_text(
        yaml.safe_dump(
            {
                "listen_host": "192.168.1.101",
                "llama_server_bin": "/home/netuser/bin/llama.cpp/build/bin/llama-server",
                "model_roots": ["/home/netuser/bin/models/llm"],
                "models": [
                    {
                        "model_id": "qwen3-8b",
                        "path": "/home/netuser/bin/models/llm/Qwen3-8B-Q5_K_M.gguf",
                        "roles": ["tutor", "mentor"],
                    }
                ],
            }
        )
    )

    cfg = load_node_config(base, override_path=override)

    assert cfg.listen_host == "192.168.1.101"
    assert cfg.llama_server_bin == "/home/netuser/bin/llama.cpp/build/bin/llama-server"
    assert [str(path) for path in cfg.model_roots] == ["/home/netuser/bin/models/llm"]
    assert str(cfg.models[0].path) == "/home/netuser/bin/models/llm/Qwen3-8B-Q5_K_M.gguf"
    assert cfg.models[0].roles == ["tutor", "mentor"]
@@ -325,6 +325,7 @@ def test_discovered_model_uses_random_strategy_when_configured(tmp_path):

     async def fake_chat(base_url, payload):
         calls["base_url"] = base_url
+        calls["payload"] = payload
         return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}

     app.state.upstream.chat_completions = fake_chat
@@ -348,4 +349,57 @@ def test_discovered_model_uses_random_strategy_when_configured(tmp_path):

     assert response.status_code == 200
     assert calls["base_url"] == "http://127.0.0.1:9002"
+    assert calls["payload"]["model"] == "reviewer"
     asyncio.run(app.state.upstream.close())


+def test_discovered_model_rewrites_alias_to_upstream_model_id(tmp_path):
+    app = _create_gateway_app(tmp_path)
+    calls = {}
+
+    register = asyncio.run(
+        _request(
+            app,
+            "POST",
+            "/v1/nodes/register",
+            headers={"x-rolemesh-node-key": "node-secret"},
+            json={
+                "node_id": "node-a",
+                "base_url": "http://127.0.0.1:9001",
+                "served_models": [
+                    {
+                        "model_id": "qwen3-8b",
+                        "roles": ["reviewer"],
+                        "meta": {"family": "Qwen3"},
+                    }
+                ],
+            },
+        )
+    )
+    assert register.status_code == 200
+
+    async def fake_chat(base_url, payload):
+        calls["base_url"] = base_url
+        calls["payload"] = payload
+        return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}
+
+    app.state.upstream.chat_completions = fake_chat
+    app.state.upstream.get_models = lambda base_url: {"object": "list", "data": []}
+
+    response = asyncio.run(
+        _request(
+            app,
+            "POST",
+            "/v1/chat/completions",
+            headers={"x-api-key": "client-secret"},
+            json={
+                "model": "reviewer",
+                "messages": [{"role": "user", "content": "hello"}],
+            },
+        )
+    )
+
+    assert response.status_code == 200
+    assert calls["base_url"] == "http://127.0.0.1:9001"
+    assert calls["payload"]["model"] == "qwen3-8b"
+    asyncio.run(app.state.upstream.close())
|
@@ -8,7 +8,7 @@ import httpx
 from rolemesh_node_agent.adapters.base import DeviceMetrics, DeviceRef
 from rolemesh_node_agent.adapters.cuda import CudaAdapter
 from rolemesh_node_agent.config import ModelEntry, NodeAgentConfig
-from rolemesh_node_agent.main import _merge_scheduler_metrics, _select_device
+from rolemesh_node_agent.main import _merge_scheduler_metrics, _registration_payload, _select_device
 from rolemesh_node_agent.scheduler import AdmissionError, DeviceQueue
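The new gateway test above pins down the routing contract this commit introduces: clients address a role (`reviewer`), and the gateway rewrites the request's `model` field to the concrete upstream `model_id` before proxying. A minimal sketch of that rewrite step, with hypothetical names (`role_to_model`, `route_chat`) standing in for the gateway's actual internals:

```python
# Hypothetical registry mapping: role -> (node base_url, upstream model_id).
role_to_model = {
    "reviewer": ("http://127.0.0.1:9001", "qwen3-8b"),
}


def route_chat(payload: dict) -> tuple[str, dict]:
    """Resolve the client-facing role name and rewrite the payload so the
    upstream OpenAI-compatible backend sees its own model_id, not the role."""
    role = payload["model"]
    base_url, model_id = role_to_model[role]
    upstream_payload = {**payload, "model": model_id}
    return base_url, upstream_payload


base_url, upstream = route_chat(
    {"model": "reviewer", "messages": [{"role": "user", "content": "hello"}]}
)
assert base_url == "http://127.0.0.1:9001"
assert upstream["model"] == "qwen3-8b"
```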
@@ -59,6 +59,74 @@ def test_inventory_reports_models_metrics_and_discovered_gguf(tmp_path):
    asyncio.run(app.state.http.aclose())


def test_registration_payload_uses_served_models_from_local_catalog(tmp_path):
    model_path = tmp_path / "model.gguf"
    model_path.write_bytes(b"GGUF")
    cfg = NodeAgentConfig(
        node_id="node-1",
        listen_host="192.168.1.101",
        listen_port=8091,
        dispatcher_roles=["mentor", "code_tutor"],
        models=[
            ModelEntry(
                model_id="qwen3-8b",
                path=model_path,
                roles=["mentor", "tutor"],
                alias="mentor",
            ),
            ModelEntry(
                model_id="qwen2.5-coder-14b",
                path=model_path,
                roles=["code_tutor"],
            ),
        ],
    )

    payload = _registration_payload(cfg)

    assert payload["node_id"] == "node-1"
    assert payload["base_url"] == "http://192.168.1.101:8091"
    assert payload["roles"] == ["mentor", "code_tutor"]
    assert payload["served_models"] == [
        {
            "model_id": "qwen3-8b",
            "roles": ["mentor"],
            "meta": {"alias": "mentor"},
        },
        {
            "model_id": "qwen2.5-coder-14b",
            "roles": ["code_tutor"],
            "meta": {},
        },
    ]


def test_registration_endpoint_returns_computed_payload(tmp_path):
    from rolemesh_node_agent.main import create_app

    model_path = tmp_path / "model.gguf"
    model_path.write_bytes(b"GGUF")
    cfg = NodeAgentConfig(
        node_id="node-1",
        listen_host="192.168.1.101",
        listen_port=8091,
        models=[
            ModelEntry(model_id="qwen3-8b", path=model_path, roles=["mentor", "tutor"]),
        ],
    )
    app = create_app(cfg)

    response = asyncio.run(_request(app, "GET", "/v1/node/registration"))

    assert response.status_code == 200
    body = response.json()
    assert body["node_id"] == "node-1"
    assert body["base_url"] == "http://192.168.1.101:8091"
    assert body["served_models"][0]["model_id"] == "qwen3-8b"
    assert body["served_models"][0]["roles"] == ["mentor", "tutor"]
    asyncio.run(app.state.http.aclose())


def test_chat_completions_routes_to_local_server_and_streams(tmp_path):
    from rolemesh_node_agent.main import create_app
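The payload test above fixes the shape the node agent is expected to send on registration: one `served_models` entry per local model, with each model's roles filtered down to those the dispatcher asked this node to serve (`dispatcher_roles`), and any alias carried in `meta`. A rough sketch of that construction under those assumptions (the real `_registration_payload` may differ in detail):

```python
def registration_payload(node_id, host, port, dispatcher_roles, models):
    """Build a /v1/nodes/register body: advertise each local model with the
    subset of its roles that appears in the dispatcher's requested roles."""
    served = []
    for m in models:
        roles = [r for r in m["roles"] if r in dispatcher_roles]
        meta = {"alias": m["alias"]} if m.get("alias") else {}
        served.append({"model_id": m["model_id"], "roles": roles, "meta": meta})
    return {
        "node_id": node_id,
        "base_url": f"http://{host}:{port}",
        "roles": list(dispatcher_roles),
        "served_models": served,
    }


payload = registration_payload(
    "node-1",
    "192.168.1.101",
    8091,
    ["mentor", "code_tutor"],
    [
        {"model_id": "qwen3-8b", "roles": ["mentor", "tutor"], "alias": "mentor"},
        {"model_id": "qwen2.5-coder-14b", "roles": ["code_tutor"]},
    ],
)
# "tutor" is dropped because the dispatcher did not request it from this node.
assert payload["served_models"][0]["roles"] == ["mentor"]
assert payload["served_models"][1]["meta"] == {}
```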
@@ -2,7 +2,7 @@ from __future__ import annotations

 import random

-from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
+from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry, ServedModel


 def test_registry_persists_round_robin_and_heartbeat_state(tmp_path):
@@ -98,3 +98,51 @@ def test_registry_supports_random_selection(monkeypatch):
    picked = registry.pick_node_for_role("reviewer", strategy="random")
    assert picked is not None
    assert picked.node_id == "node-b"


def test_registry_supports_served_models_and_candidate_selection():
    registry = Registry()
    registry.register(
        NodeRegistration(
            node_id="node-a",
            base_url="http://127.0.0.1:9001",
            served_models=[
                ServedModel(model_id="qwen3-8b", roles=["tutor", "mentor"]),
                ServedModel(model_id="qwen2.5-coder-14b", roles=["code_tutor"]),
            ],
        )
    )

    tutor_candidates = registry.candidates_for_role("tutor")
    assert len(tutor_candidates) == 1
    assert tutor_candidates[0].node_id == "node-a"
    assert tutor_candidates[0].model_id == "qwen3-8b"

    mentor_candidates = registry.candidates_for_role("mentor")
    assert len(mentor_candidates) == 1
    assert mentor_candidates[0].model_id == "qwen3-8b"

    code_candidates = registry.candidates_for_role("code_tutor")
    assert len(code_candidates) == 1
    assert code_candidates[0].model_id == "qwen2.5-coder-14b"

    picked = registry.pick_candidate_for_role("mentor")
    assert picked is not None
    assert picked.model_id == "qwen3-8b"


def test_registry_loads_legacy_roles_as_served_models(tmp_path):
    persist_path = tmp_path / "registry.json"
    registry = Registry(persist_path=persist_path)
    registry.register(
        NodeRegistration(
            node_id="node-a",
            base_url="http://127.0.0.1:9001",
            roles=["reviewer"],
        )
    )

    reloaded = Registry(persist_path=persist_path)
    candidates = reloaded.candidates_for_role("reviewer")
    assert len(candidates) == 1
    assert candidates[0].model_id == "reviewer"
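The registry tests above treat each (node, model) pair as a routing candidate indexed by role, with legacy `roles`-only registrations falling back to `model_id == role`. A compact sketch of that candidate index, using hypothetical names (`Candidate`, `build_candidates`) rather than the registry's actual types:

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    node_id: str
    base_url: str
    model_id: str


def build_candidates(nodes):
    """Flatten registrations into role -> [Candidate]. Legacy nodes that only
    declare roles are treated as serving one model named after each role."""
    index: dict[str, list[Candidate]] = {}
    for node in nodes:
        served = node.get("served_models") or [
            {"model_id": role, "roles": [role]} for role in node.get("roles", [])
        ]
        for model in served:
            for role in model["roles"]:
                index.setdefault(role, []).append(
                    Candidate(node["node_id"], node["base_url"], model["model_id"])
                )
    return index


index = build_candidates(
    [
        {
            "node_id": "node-a",
            "base_url": "http://127.0.0.1:9001",
            "served_models": [{"model_id": "qwen3-8b", "roles": ["tutor", "mentor"]}],
        },
        # Legacy registration: roles only, no served_models.
        {"node_id": "node-b", "base_url": "http://127.0.0.1:9002", "roles": ["reviewer"]},
    ]
)
assert index["mentor"][0].model_id == "qwen3-8b"
assert index["reviewer"][0].model_id == "reviewer"
```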