diff --git a/README.md b/README.md index 273f345..2be8296 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,34 @@ curl -sS -X POST http://127.0.0.1:8000/v1/chat/completions \ If you prefer the provided example file, copy `configs/models.example.yaml` and adjust the `proxy_url` values. +## Local Overrides + +Keep tracked repo config generic and put machine-specific values in a separate local override file. + +Examples of machine-specific values: + +- model weight paths +- local `llama-server` binary path +- LAN IPs and ports +- local API keys + +Supported launch patterns: + +```bash +rolemesh-gateway --config configs/models.example.yaml --config-override configs/models.local.yaml +``` + +```bash +rolemesh-node-agent --config configs/node_agent.example.yaml --config-override configs/node_agent.local.yaml +``` + +You can also use: + +- `ROLE_MESH_CONFIG_OVERRIDE` for the gateway + +Tracked examples should use placeholders such as `/path/to/model-weights`, while your local override file contains the +real values for your machine. + ## Worked Deployment Example For a concrete multi-machine example, including: @@ -240,7 +268,17 @@ If you want machines to host backends and “register” them dynamically, run a (or just call the registration endpoint from your own tooling). - Gateway endpoint: `POST /v1/nodes/register` -- Node payload describes which **roles** it serves and the base URL to reach its OpenAI-compatible backend. +- Node payload describes which concrete upstream **model IDs** it serves, which **roles** each model can satisfy, + and the base URL to reach its OpenAI-compatible backend. 
+ +In discovered mode, RoleMesh now treats these as separate concepts: + +- client alias: what the caller sends, such as `tutor` +- role: the capability used for routing, such as `tutor` or `critic` +- upstream model ID: the concrete model name served by the selected node, such as `qwen3-8b` + +That means a node can advertise one served model for multiple roles, and the gateway can rewrite the forwarded +request from a stable alias to the selected upstream model ID. See: `docs/DEPLOYMENT.md` and `docs/CONFIG.md`. diff --git a/configs/models.local.example.yaml b/configs/models.local.example.yaml new file mode 100644 index 0000000..e4db944 --- /dev/null +++ b/configs/models.local.example.yaml @@ -0,0 +1,15 @@ +gateway: + host: 192.168.1.100 + port: 8080 + +auth: + client_api_keys: + - "replace-with-local-client-key" + node_api_keys: + - "replace-with-local-node-key" + +models: + planner: + proxy_url: http://192.168.1.101:8011 + writer: + proxy_url: http://192.168.1.102:8012 diff --git a/configs/node_agent.example.yaml b/configs/node_agent.example.yaml index 11b7517..f2349c2 100644 --- a/configs/node_agent.example.yaml +++ b/configs/node_agent.example.yaml @@ -9,17 +9,17 @@ dispatcher_node_key: "change-me-node-key-1" dispatcher_roles: ["planner", "coder"] heartbeat_interval_sec: 5 -llama_server_bin: "llama-server" +llama_server_bin: "/path/to/llama-server" llama_server_startup_timeout_s: 30 llama_server_probe_interval_s: 0.5 model_roots: - - "/models" + - "/path/to/model-weights" models: - model_id: "planner-gguf" # path is the exact GGUF file that this model_id will load when requested - path: "/models/SomePlannerModel.Q5_K_M.gguf" + path: "/path/to/model-weights/SomePlannerModel.Q5_K_M.gguf" roles: ["planner"] default_ctx: 8192 # Common llama-server options can be configured as structured fields: diff --git a/configs/node_agent.local.example.yaml b/configs/node_agent.local.example.yaml new file mode 100644 index 0000000..a6eb938 --- /dev/null +++ 
b/configs/node_agent.local.example.yaml @@ -0,0 +1,16 @@ +node_id: "node-1-local" +listen_host: "192.168.1.101" +listen_port: 8091 + +dispatcher_base_url: "http://192.168.1.100:8080" +dispatcher_node_key: "replace-with-local-node-key" + +llama_server_bin: "/path/to/local/llama-server" + +model_roots: + - "/path/to/model-weights" + +models: + - model_id: "qwen3-8b" + path: "/path/to/model-weights/Qwen3-8B-Q5_K_M.gguf" + roles: ["tutor", "mentor"] diff --git a/docs/CONFIG.md b/docs/CONFIG.md index 4ad649f..273e49f 100644 --- a/docs/CONFIG.md +++ b/docs/CONFIG.md @@ -93,7 +93,7 @@ models: In that setup: - gateway alias -> discovered role -- discovered role -> registered node +- discovered role -> registered node + concrete upstream model ID - node-agent `path` -> actual weight file on disk ## Proxy models @@ -118,7 +118,7 @@ Notes: ## Discovered models -Route to a dynamically registered node that claims the role: +Route to a dynamically registered model instance that can satisfy the role: ```yaml models: @@ -141,11 +141,31 @@ Nodes register to `POST /v1/nodes/register`: { "node_id": "gpu-box-1", "base_url": "http://10.0.0.12:8014", - "roles": ["reviewer", "planner"], + "served_models": [ + { + "model_id": "qwen3-8b", + "roles": ["reviewer", "planner"], + "meta": {"family": "Qwen3", "quant": "Q5_K_M"} + }, + { + "model_id": "qwen2.5-coder-14b", + "roles": ["coder"], + "meta": {"family": "Qwen2.5-Coder", "quant": "Q5_K_M"} + } + ], "meta": {"gpu": "Tesla P40", "notes": "llama-server on GPU0"} } ``` +`served_models` is now the preferred registration schema. + +- `model_id`: concrete model name the upstream node expects in the forwarded OpenAI request +- `roles`: workflow roles that this model can satisfy +- `meta`: optional operator-facing metadata + +Legacy flat `roles` registration is still accepted for compatibility, but it is treated as a fallback where +`model_id == role`. 
+ If `auth.client_api_keys` is set (non-empty), callers of `/v1/models` and `/v1/chat/completions` must provide an API key. If `auth.node_api_keys` is set (non-empty), node agents calling `/v1/nodes/register` and `/v1/nodes/heartbeat` must provide a node key. @@ -193,3 +213,35 @@ models: temperature: 0.6 max_tokens: 256 ``` + +## Base config plus local override + +Recommended pattern: + +- keep tracked repo config generic +- keep machine-specific values in a separate local YAML +- merge the local YAML at launch + +Gateway: + +```bash +rolemesh-gateway --config configs/models.example.yaml --config-override configs/models.local.yaml +``` + +Node agent: + +```bash +rolemesh-node-agent --config configs/node_agent.example.yaml --config-override configs/node_agent.local.yaml +``` + +The merge is recursive for mappings: + +- nested dictionaries are merged +- lists and scalar values are replaced by the override file + +This is useful for separating: + +- real model weight paths +- local host IPs +- local API keys +- local `llama-server` paths diff --git a/docs/DEPLOYMENT.md b/docs/DEPLOYMENT.md index 1b0ba29..ce0589b 100644 --- a/docs/DEPLOYMENT.md +++ b/docs/DEPLOYMENT.md @@ -38,6 +38,12 @@ models: ROLE_MESH_CONFIG=configs/models.yaml uvicorn rolemesh_gateway.main:app --host 127.0.0.1 --port 8000 ``` +If you want to keep local IPs and keys out of the tracked file, use an override: + +```bash +rolemesh-gateway --config configs/models.example.yaml --config-override configs/models.local.yaml +``` + ### 4. Smoke test ```bash @@ -129,7 +135,7 @@ This scaffold supports two patterns. - Update `proxy_url` entries to those LAN URLs, **or** use discovery: - Set model to `type: discovered` with `role: writer`, etc. - Choose `strategy: round_robin` or `strategy: random` per discovered alias - - Each host registers itself with the gateway. + - Each host registers the concrete served models it can expose for those roles. 
### Minimal registration call @@ -137,9 +143,21 @@ This scaffold supports two patterns. curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \ -H 'Content-Type: application/json' \ -H 'X-RoleMesh-Node-Key: <node-key>' \ - -d '{"node_id":"gpu-box-1","base_url":"http://10.0.0.12:8012","roles":["writer"]}' + -d '{ + "node_id":"gpu-box-1", + "base_url":"http://10.0.0.12:8012", + "served_models":[ + {"model_id":"qwen3-8b","roles":["writer","mentor"]} + ] + }' ``` +In discovered mode, the gateway routes: + +- client alias -> role +- role -> eligible registered served model +- forwarded request `model` -> selected upstream `model_id` + ### Hardening checklist (recommended) - Bind gateway to localhost by default, and explicitly expose it when needed diff --git a/docs/EXAMPLE_MULTI_NODE.md index c8443ac..e9ecf79 100644 --- a/docs/EXAMPLE_MULTI_NODE.md +++ b/docs/EXAMPLE_MULTI_NODE.md @@ -195,50 +195,20 @@ On `192.168.1.102`: PYTHONPATH=src python -m rolemesh_node_agent.cli --config critic-node.yaml ``` -## 6. Register each node once +## 6. Registration behavior -The current node agent sends heartbeats automatically, but registration is still a one-time explicit step. +The current node agent now registers itself automatically on startup when `dispatcher_base_url` is set. +It also keeps heartbeating after registration. -Register planner: +For inspection, each node exposes: ```bash -curl -sS -X POST http://192.168.1.100:8080/v1/nodes/register \ - -H 'Content-Type: application/json' \ - -H 'X-RoleMesh-Node-Key: change-me-node-key' \ - -d '{ - "node_id": "gpu101-planner", - "base_url": "http://192.168.1.101:8091", - "roles": ["planner"] - }' +curl -sS http://192.168.1.101:8091/v1/node/registration ``` -Register writer: +That endpoint returns the exact `served_models` payload the node agent will send to the gateway.
-```bash -curl -sS -X POST http://192.168.1.100:8080/v1/nodes/register \ - -H 'Content-Type: application/json' \ - -H 'X-RoleMesh-Node-Key: change-me-node-key' \ - -d '{ - "node_id": "gpu101-writer", - "base_url": "http://192.168.1.101:8092", - "roles": ["writer"] - }' -``` - -Register critic: - -```bash -curl -sS -X POST http://192.168.1.100:8080/v1/nodes/register \ - -H 'Content-Type: application/json' \ - -H 'X-RoleMesh-Node-Key: change-me-node-key' \ - -d '{ - "node_id": "gpu102-critic", - "base_url": "http://192.168.1.102:8091", - "roles": ["critic"] - }' -``` - -After that, the heartbeat loop on each node agent keeps the registry entry fresh. +Manual `POST /v1/nodes/register` is still supported, but it is now mainly useful for custom tooling or debugging. ## 7. Verify the topology @@ -267,6 +237,12 @@ That endpoint shows: - device metrics - queue depth and in-flight work +You can also inspect the registration payload directly: + +```bash +curl -sS http://192.168.1.101:8091/v1/node/registration +``` + ## 8. Send requests by role Planner request through the gateway: diff --git a/docs/NODE_AGENT.md b/docs/NODE_AGENT.md index 979adb4..54ed6c6 100644 --- a/docs/NODE_AGENT.md +++ b/docs/NODE_AGENT.md @@ -90,6 +90,21 @@ pip install -e . rolemesh-node-agent --config configs/node_agent.example.yaml ``` +To keep local paths and host-specific values out of the tracked config, use an override file: + +```bash +rolemesh-node-agent \ + --config configs/node_agent.example.yaml \ + --config-override configs/node_agent.local.yaml +``` + +Tracked base config should contain placeholders such as: + +- `/path/to/model-weights` +- `/path/to/llama-server` + +Your local override should contain the real machine-specific values. 
+ ### Startup timing guards Two config knobs control how long the node agent waits for a managed `llama-server` to become ready: @@ -108,11 +123,19 @@ The readiness probe checks the managed server's local `GET /health` and `GET /v1 ## Registering -If `dispatcher_base_url` is set in the node-agent config, the node agent will periodically call: +If `dispatcher_base_url` is set in the node-agent config, the node agent will: +- register itself on startup via `POST /v1/nodes/register` - `POST /v1/nodes/heartbeat` with latest device metrics. -Registration is currently manual from the node side (or can be added as a startup step). +The registration payload is derived from local `models[]` and advertises `served_models`, where each local +`model_id` lists the roles it can satisfy. + +For operator inspection, the node agent also exposes: + +- `GET /v1/node/registration` + +That endpoint returns the exact registration payload the node would send to the dispatcher. ### Binding diff --git a/src/rolemesh_gateway/api/openai.py b/src/rolemesh_gateway/api/openai.py index cd7d097..ca7bc16 100644 --- a/src/rolemesh_gateway/api/openai.py +++ b/src/rolemesh_gateway/api/openai.py @@ -52,8 +52,8 @@ async def _model_status( error = "no_fresh_registered_nodes" return {"alias": alias, "available": False, "role": entry.role, "error": error} - for node in matching_nodes: - base_url = str(node.base_url).rstrip("/") + for candidate in registry.candidates_for_role(entry.role, include_stale=False): + base_url = str(candidate.base_url).rstrip("/") try: await upstream.get_models(base_url) return { @@ -61,7 +61,8 @@ async def _model_status( "available": True, "role": entry.role, "base_url": base_url, - "node_id": node.node_id, + "node_id": candidate.node_id, + "upstream_model_id": candidate.model_id, } except UpstreamError: continue @@ -160,15 +161,16 @@ async def chat_completions(request: Request, _=Depends(require_client_auth)) -> if isinstance(entry, ProxyModel): base_url = 
str(entry.proxy_url).rstrip("/") elif isinstance(entry, DiscoveredModel): - node = registry.pick_node_for_role(entry.role, strategy=entry.strategy) - if not node: + candidate = registry.pick_candidate_for_role(entry.role, strategy=entry.strategy) + if not candidate: return _openai_error( f"No registered nodes available for role '{entry.role}'. " f"Register a node via POST /v1/nodes/register, or use proxy mode.", code="no_upstream", status_code=503, ) - base_url = str(node.base_url).rstrip("/") + base_url = str(candidate.base_url).rstrip("/") + body["model"] = candidate.model_id else: return _openai_error("Invalid model configuration.", code="bad_config", status_code=500) diff --git a/src/rolemesh_gateway/cli.py b/src/rolemesh_gateway/cli.py index e8d46ef..7fba30b 100644 --- a/src/rolemesh_gateway/cli.py +++ b/src/rolemesh_gateway/cli.py @@ -11,11 +11,12 @@ from rolemesh_gateway.main import create_app def main() -> None: p = argparse.ArgumentParser(description="RoleMesh Gateway") p.add_argument("--config", required=True, help="Path to gateway YAML config.") + p.add_argument("--config-override", help="Optional local override YAML merged over the base config.") p.add_argument("--host", default="127.0.0.1") p.add_argument("--port", type=int, default=8080) args = p.parse_args() - app = create_app(Path(args.config)) + app = create_app(Path(args.config), config_override_path=Path(args.config_override) if args.config_override else None) uvicorn.run(app, host=args.host, port=args.port) diff --git a/src/rolemesh_gateway/config.py b/src/rolemesh_gateway/config.py index 078d353..c6ee8fd 100644 --- a/src/rolemesh_gateway/config.py +++ b/src/rolemesh_gateway/config.py @@ -54,9 +54,28 @@ class Config(BaseModel): models: Dict[str, ModelEntry] = Field(default_factory=dict) -def load_config(path: str | Path) -> Config: - path = Path(path) +def _read_yaml_mapping(path: Path) -> dict: data = yaml.safe_load(path.read_text()) + if data is None: + return {} if not isinstance(data, dict): 
raise ValueError(f"Invalid config file at {path}: expected a YAML mapping at top level.") + return data + + +def _deep_merge(base: dict, override: dict) -> dict: + merged = dict(base) + for key, value in override.items(): + if isinstance(value, dict) and isinstance(merged.get(key), dict): + merged[key] = _deep_merge(merged[key], value) + else: + merged[key] = value + return merged + + +def load_config(path: str | Path, override_path: str | Path | None = None) -> Config: + path = Path(path) + data = _read_yaml_mapping(path) + if override_path is not None: + data = _deep_merge(data, _read_yaml_mapping(Path(override_path))) return Config.model_validate(data) diff --git a/src/rolemesh_gateway/main.py b/src/rolemesh_gateway/main.py index 7f74fff..c19ea99 100644 --- a/src/rolemesh_gateway/main.py +++ b/src/rolemesh_gateway/main.py @@ -26,11 +26,13 @@ def _get_logger() -> logging.Logger: def create_app( config_path: str | Path | None = None, + config_override_path: str | Path | None = None, registry_path: str | Path | None = None, registry_stale_after_s: float | None = None, ) -> FastAPI: cfg_path = config_path or os.environ.get("ROLE_MESH_CONFIG", "configs/models.yaml") - cfg = load_config(cfg_path) + cfg_override_path = config_override_path or os.environ.get("ROLE_MESH_CONFIG_OVERRIDE") + cfg = load_config(cfg_path, override_path=cfg_override_path) resolved_registry_path = registry_path or os.environ.get("ROLE_MESH_REGISTRY_PATH", "state/registry.json") resolved_registry_stale_after_s = registry_stale_after_s diff --git a/src/rolemesh_gateway/registry.py b/src/rolemesh_gateway/registry.py index 5488fc2..fd6c3e3 100644 --- a/src/rolemesh_gateway/registry.py +++ b/src/rolemesh_gateway/registry.py @@ -9,10 +9,17 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, HttpUrl +class ServedModel(BaseModel): + model_id: str + roles: List[str] = Field(default_factory=list) + meta: Dict[str, str] = Field(default_factory=dict) + + class 
NodeRegistration(BaseModel): node_id: str base_url: HttpUrl # OpenAI-compatible upstream base, e.g. http://10.0.0.12:8011 - roles: List[str] # roles served by this node, e.g. ["planner", "writer"] + served_models: List[ServedModel] = Field(default_factory=list) + roles: List[str] = Field(default_factory=list) # legacy compatibility input meta: Dict[str, str] = Field(default_factory=dict) @@ -28,13 +35,22 @@ class NodeHeartbeat(BaseModel): class RegisteredNode(BaseModel): node_id: str base_url: HttpUrl - roles: List[str] + served_models: List[ServedModel] = Field(default_factory=list) + roles: List[str] = Field(default_factory=list) meta: Dict[str, str] = Field(default_factory=dict) status: Dict[str, Any] = Field(default_factory=dict) registered_at: float = Field(default_factory=lambda: time.time()) last_seen: float = Field(default_factory=lambda: time.time()) +class RoleCandidate(BaseModel): + node_id: str + base_url: HttpUrl + role: str + model_id: str + model_meta: Dict[str, str] = Field(default_factory=dict) + + class Registry: """ Minimal in-memory registry with optional JSON persistence. @@ -60,7 +76,15 @@ class Registry: try: raw = json.loads(self._persist_path.read_text()) for node_id, node_data in raw.get("nodes", {}).items(): - self._nodes[node_id] = RegisteredNode.model_validate(node_data) + node = RegisteredNode.model_validate(node_data) + if not node.served_models and node.roles: + node.served_models = [ + ServedModel(model_id=role, roles=[role]) + for role in node.roles + ] + if not node.roles and node.served_models: + node.roles = self._derive_roles(node.served_models, []) + self._nodes[node_id] = node self._rr_counters = dict(raw.get("rr_counters", {})) except Exception: # If persistence is corrupted, start empty (do not crash the gateway). 
@@ -77,11 +101,31 @@ class Registry: self._persist_path.parent.mkdir(parents=True, exist_ok=True) self._persist_path.write_text(json.dumps(payload, indent=2, sort_keys=True)) + @staticmethod + def _derive_served_models(reg: NodeRegistration) -> List[ServedModel]: + if reg.served_models: + return reg.served_models + return [ServedModel(model_id=role, roles=[role]) for role in reg.roles] + + @staticmethod + def _derive_roles(served_models: List[ServedModel], legacy_roles: List[str]) -> List[str]: + roles: list[str] = [] + for served_model in served_models: + for role in served_model.roles: + if role not in roles: + roles.append(role) + for role in legacy_roles: + if role not in roles: + roles.append(role) + return roles + def register(self, reg: NodeRegistration) -> RegisteredNode: + served_models = self._derive_served_models(reg) node = RegisteredNode( node_id=reg.node_id, base_url=reg.base_url, - roles=reg.roles, + served_models=served_models, + roles=self._derive_roles(served_models, reg.roles), meta=reg.meta, last_seen=time.time(), ) @@ -119,6 +163,35 @@ class Registry: nodes = self.list_nodes(include_stale=include_stale) return [node for node in nodes if role in node.roles] + def candidates_for_role(self, role: str, *, include_stale: bool = False) -> List[RoleCandidate]: + nodes = self.list_nodes(include_stale=include_stale) + candidates: list[RoleCandidate] = [] + for node in nodes: + for served_model in node.served_models: + if role not in served_model.roles: + continue + candidates.append( + RoleCandidate( + node_id=node.node_id, + base_url=node.base_url, + role=role, + model_id=served_model.model_id, + model_meta=served_model.meta, + ) + ) + return candidates + + def pick_candidate_for_role(self, role: str, strategy: str = "round_robin") -> Optional[RoleCandidate]: + candidates = self.candidates_for_role(role, include_stale=False) + if not candidates: + return None + if strategy == "random": + return random.choice(candidates) + idx = 
self._rr_counters.get(role, 0) % len(candidates) + self._rr_counters[role] = idx + 1 + self._save() + return candidates[idx] + def pick_node_for_role(self, role: str, strategy: str = "round_robin") -> Optional[RegisteredNode]: candidates = self.nodes_for_role(role, include_stale=False) if not candidates: diff --git a/src/rolemesh_node_agent/cli.py b/src/rolemesh_node_agent/cli.py index 3090e44..48df079 100644 --- a/src/rolemesh_node_agent/cli.py +++ b/src/rolemesh_node_agent/cli.py @@ -1,22 +1,20 @@ from __future__ import annotations import argparse -from pathlib import Path import uvicorn -import yaml -from .config import NodeAgentConfig +from .config import load_config from .main import create_app def main() -> None: p = argparse.ArgumentParser(description="RoleMesh Node Agent") p.add_argument("--config", required=True, help="Path to node-agent YAML config.") + p.add_argument("--config-override", help="Optional local override YAML merged over the base config.") args = p.parse_args() - cfg_path = Path(args.config) - cfg = NodeAgentConfig.model_validate(yaml.safe_load(cfg_path.read_text())) + cfg = load_config(args.config, override_path=args.config_override) app = create_app(cfg) uvicorn.run(app, host=cfg.listen_host, port=cfg.listen_port) diff --git a/src/rolemesh_node_agent/config.py b/src/rolemesh_node_agent/config.py index ae5bba7..1a5cd5e 100644 --- a/src/rolemesh_node_agent/config.py +++ b/src/rolemesh_node_agent/config.py @@ -3,6 +3,7 @@ from __future__ import annotations from pathlib import Path from typing import Any, Dict, List, Literal, Optional +import yaml from pydantic import BaseModel, Field, HttpUrl @@ -51,3 +52,30 @@ class NodeAgentConfig(BaseModel): llama_server_startup_timeout_s: float = 30.0 llama_server_probe_interval_s: float = 0.5 max_pending_requests_per_device: int = 2 + + +def _read_yaml_mapping(path: Path) -> dict: + data = yaml.safe_load(path.read_text()) + if data is None: + return {} + if not isinstance(data, dict): + raise 
ValueError(f"Invalid config file at {path}: expected a YAML mapping at top level.") + return data + + +def _deep_merge(base: dict, override: dict) -> dict: + merged = dict(base) + for key, value in override.items(): + if isinstance(value, dict) and isinstance(merged.get(key), dict): + merged[key] = _deep_merge(merged[key], value) + else: + merged[key] = value + return merged + + +def load_config(path: str | Path, override_path: str | Path | None = None) -> NodeAgentConfig: + path = Path(path) + data = _read_yaml_mapping(path) + if override_path is not None: + data = _deep_merge(data, _read_yaml_mapping(Path(override_path))) + return NodeAgentConfig.model_validate(data) diff --git a/src/rolemesh_node_agent/main.py b/src/rolemesh_node_agent/main.py index ae3fdb6..4b3bbb4 100644 --- a/src/rolemesh_node_agent/main.py +++ b/src/rolemesh_node_agent/main.py @@ -39,6 +39,46 @@ def _merge_scheduler_metrics( return out +def _effective_served_models(cfg: NodeAgentConfig) -> list[dict[str, Any]]: + allowed_roles = set(cfg.dispatcher_roles) + served_models: list[dict[str, Any]] = [] + for model in cfg.models: + roles = list(model.roles) + if allowed_roles: + roles = [role for role in roles if role in allowed_roles] + if not roles: + continue + meta: dict[str, str] = {} + if model.alias: + meta["alias"] = model.alias + served_models.append( + { + "model_id": model.model_id, + "roles": roles, + "meta": meta, + } + ) + return served_models + + +def _registration_payload(cfg: NodeAgentConfig) -> dict[str, Any]: + served_models = _effective_served_models(cfg) + roles: list[str] = [] + for served_model in served_models: + for role in served_model["roles"]: + if role not in roles: + roles.append(role) + return { + "node_id": cfg.node_id, + "base_url": f"http://{cfg.listen_host}:{cfg.listen_port}", + "served_models": served_models, + "roles": roles, + "meta": { + "node_agent": "rolemesh", + }, + } + + def _select_device( devices: Iterable[DeviceRef], metrics: Iterable[DeviceMetrics], 
@@ -86,7 +126,8 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI: @asynccontextmanager async def lifespan(app: FastAPI): heartbeat_task: asyncio.Task[None] | None = None - if cfg.dispatcher_base_url and cfg.dispatcher_roles: + if cfg.dispatcher_base_url: + await _register_with_dispatcher(app) heartbeat_task = asyncio.create_task(_heartbeat_loop(app)) try: yield @@ -116,6 +157,10 @@ def create_app(cfg: NodeAgentConfig) -> FastAPI: async def health() -> Dict[str, Any]: return {"status": "ok", "node_id": cfg.node_id, "time": time.time()} + @app.get("/v1/node/registration") + async def registration_payload() -> Dict[str, Any]: + return _registration_payload(cfg) + @app.get("/v1/node/inventory") async def inventory() -> Dict[str, Any]: devices = await app.state.cuda.discover_devices() @@ -221,3 +266,18 @@ async def _heartbeat_loop(app: FastAPI) -> None: except Exception: pass await asyncio.sleep(cfg.heartbeat_interval_sec) + + +async def _register_with_dispatcher(app: FastAPI) -> None: + cfg: NodeAgentConfig = app.state.cfg + http: httpx.AsyncClient = app.state.http + if not cfg.dispatcher_base_url: + return + try: + url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/register" + headers = {} + if cfg.dispatcher_node_key: + headers["X-RoleMesh-Node-Key"] = cfg.dispatcher_node_key + await http.post(url, json=_registration_payload(cfg), headers=headers) + except Exception: + pass diff --git a/tests/test_config_loading.py b/tests/test_config_loading.py new file mode 100644 index 0000000..9b1ecca --- /dev/null +++ b/tests/test_config_loading.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from pathlib import Path + +import yaml + +from rolemesh_gateway.config import load_config as load_gateway_config +from rolemesh_node_agent.config import load_config as load_node_config + + +def test_gateway_config_override_merges_nested_mappings(tmp_path: Path): + base = tmp_path / "models.yaml" + override = tmp_path / "models.local.yaml" + base.write_text( + 
yaml.safe_dump( + { + "version": 1, + "default_model": "writer", + "auth": { + "client_api_keys": ["base-client-key"], + "node_api_keys": ["base-node-key"], + }, + "models": { + "writer": { + "type": "proxy", + "openai_model_name": "writer", + "proxy_url": "http://127.0.0.1:8012", + "defaults": {"temperature": 0.6, "max_tokens": 256}, + } + }, + } + ) + ) + override.write_text( + yaml.safe_dump( + { + "auth": { + "client_api_keys": ["local-client-key"], + }, + "models": { + "writer": { + "proxy_url": "http://192.168.1.50:8012", + "defaults": {"temperature": 0.2}, + } + }, + } + ) + ) + + cfg = load_gateway_config(base, override_path=override) + + assert cfg.auth.client_api_keys == ["local-client-key"] + assert cfg.auth.node_api_keys == ["base-node-key"] + assert str(cfg.models["writer"].proxy_url) == "http://192.168.1.50:8012/" + assert cfg.models["writer"].defaults == {"temperature": 0.2, "max_tokens": 256} + + +def test_node_agent_config_override_replaces_local_machine_paths(tmp_path: Path): + base = tmp_path / "node_agent.yaml" + override = tmp_path / "node_agent.local.yaml" + base.write_text( + yaml.safe_dump( + { + "node_id": "node-generic", + "listen_host": "0.0.0.0", + "listen_port": 8091, + "dispatcher_base_url": "http://10.0.0.10:8080", + "llama_server_bin": "/path/to/llama-server", + "model_roots": ["/path/to/model-weights"], + "models": [ + { + "model_id": "qwen3-8b", + "path": "/path/to/model-weights/Qwen3-8B-Q5_K_M.gguf", + "roles": ["tutor"], + } + ], + } + ) + ) + override.write_text( + yaml.safe_dump( + { + "listen_host": "192.168.1.101", + "llama_server_bin": "/home/netuser/bin/llama.cpp/build/bin/llama-server", + "model_roots": ["/home/netuser/bin/models/llm"], + "models": [ + { + "model_id": "qwen3-8b", + "path": "/home/netuser/bin/models/llm/Qwen3-8B-Q5_K_M.gguf", + "roles": ["tutor", "mentor"], + } + ], + } + ) + ) + + cfg = load_node_config(base, override_path=override) + + assert cfg.listen_host == "192.168.1.101" + assert 
cfg.llama_server_bin == "/home/netuser/bin/llama.cpp/build/bin/llama-server" + assert [str(path) for path in cfg.model_roots] == ["/home/netuser/bin/models/llm"] + assert str(cfg.models[0].path) == "/home/netuser/bin/models/llm/Qwen3-8B-Q5_K_M.gguf" + assert cfg.models[0].roles == ["tutor", "mentor"] diff --git a/tests/test_gateway.py b/tests/test_gateway.py index ebae53d..998e3e5 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -325,6 +325,7 @@ def test_discovered_model_uses_random_strategy_when_configured(tmp_path): async def fake_chat(base_url, payload): calls["base_url"] = base_url + calls["payload"] = payload return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]} app.state.upstream.chat_completions = fake_chat @@ -348,4 +349,57 @@ def test_discovered_model_uses_random_strategy_when_configured(tmp_path): assert response.status_code == 200 assert calls["base_url"] == "http://127.0.0.1:9002" + assert calls["payload"]["model"] == "reviewer" + asyncio.run(app.state.upstream.close()) + + +def test_discovered_model_rewrites_alias_to_upstream_model_id(tmp_path): + app = _create_gateway_app(tmp_path) + calls = {} + + register = asyncio.run( + _request( + app, + "POST", + "/v1/nodes/register", + headers={"x-rolemesh-node-key": "node-secret"}, + json={ + "node_id": "node-a", + "base_url": "http://127.0.0.1:9001", + "served_models": [ + { + "model_id": "qwen3-8b", + "roles": ["reviewer"], + "meta": {"family": "Qwen3"}, + } + ], + }, + ) + ) + assert register.status_code == 200 + + async def fake_chat(base_url, payload): + calls["base_url"] = base_url + calls["payload"] = payload + return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]} + + app.state.upstream.chat_completions = fake_chat + app.state.upstream.get_models = lambda base_url: {"object": "list", "data": []} + + response = asyncio.run( + _request( + app, + "POST", + "/v1/chat/completions", + headers={"x-api-key": 
"client-secret"}, + json={ + "model": "reviewer", + "messages": [{"role": "user", "content": "hello"}], + }, + ) + ) + + assert response.status_code == 200 + assert calls["base_url"] == "http://127.0.0.1:9001" + assert calls["payload"]["model"] == "qwen3-8b" asyncio.run(app.state.upstream.close()) diff --git a/tests/test_node_agent.py b/tests/test_node_agent.py index a624c9a..43df0cd 100644 --- a/tests/test_node_agent.py +++ b/tests/test_node_agent.py @@ -8,7 +8,7 @@ import httpx from rolemesh_node_agent.adapters.base import DeviceMetrics, DeviceRef from rolemesh_node_agent.adapters.cuda import CudaAdapter from rolemesh_node_agent.config import ModelEntry, NodeAgentConfig -from rolemesh_node_agent.main import _merge_scheduler_metrics, _select_device +from rolemesh_node_agent.main import _merge_scheduler_metrics, _registration_payload, _select_device from rolemesh_node_agent.scheduler import AdmissionError, DeviceQueue @@ -59,6 +59,74 @@ def test_inventory_reports_models_metrics_and_discovered_gguf(tmp_path): asyncio.run(app.state.http.aclose()) +def test_registration_payload_uses_served_models_from_local_catalog(tmp_path): + model_path = tmp_path / "model.gguf" + model_path.write_bytes(b"GGUF") + cfg = NodeAgentConfig( + node_id="node-1", + listen_host="192.168.1.101", + listen_port=8091, + dispatcher_roles=["mentor", "code_tutor"], + models=[ + ModelEntry( + model_id="qwen3-8b", + path=model_path, + roles=["mentor", "tutor"], + alias="mentor", + ), + ModelEntry( + model_id="qwen2.5-coder-14b", + path=model_path, + roles=["code_tutor"], + ), + ], + ) + + payload = _registration_payload(cfg) + + assert payload["node_id"] == "node-1" + assert payload["base_url"] == "http://192.168.1.101:8091" + assert payload["roles"] == ["mentor", "code_tutor"] + assert payload["served_models"] == [ + { + "model_id": "qwen3-8b", + "roles": ["mentor"], + "meta": {"alias": "mentor"}, + }, + { + "model_id": "qwen2.5-coder-14b", + "roles": ["code_tutor"], + "meta": {}, + }, + ] + + +def 
test_registration_endpoint_returns_computed_payload(tmp_path): + from rolemesh_node_agent.main import create_app + + model_path = tmp_path / "model.gguf" + model_path.write_bytes(b"GGUF") + cfg = NodeAgentConfig( + node_id="node-1", + listen_host="192.168.1.101", + listen_port=8091, + models=[ + ModelEntry(model_id="qwen3-8b", path=model_path, roles=["mentor", "tutor"]), + ], + ) + app = create_app(cfg) + + response = asyncio.run(_request(app, "GET", "/v1/node/registration")) + + assert response.status_code == 200 + body = response.json() + assert body["node_id"] == "node-1" + assert body["base_url"] == "http://192.168.1.101:8091" + assert body["served_models"][0]["model_id"] == "qwen3-8b" + assert body["served_models"][0]["roles"] == ["mentor", "tutor"] + asyncio.run(app.state.http.aclose()) + + def test_chat_completions_routes_to_local_server_and_streams(tmp_path): from rolemesh_node_agent.main import create_app diff --git a/tests/test_registry.py b/tests/test_registry.py index 8287bf2..b185b1a 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -2,7 +2,7 @@ from __future__ import annotations import random -from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry +from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry, ServedModel def test_registry_persists_round_robin_and_heartbeat_state(tmp_path): @@ -98,3 +98,51 @@ def test_registry_supports_random_selection(monkeypatch): picked = registry.pick_node_for_role("reviewer", strategy="random") assert picked is not None assert picked.node_id == "node-b" + + +def test_registry_supports_served_models_and_candidate_selection(): + registry = Registry() + registry.register( + NodeRegistration( + node_id="node-a", + base_url="http://127.0.0.1:9001", + served_models=[ + ServedModel(model_id="qwen3-8b", roles=["tutor", "mentor"]), + ServedModel(model_id="qwen2.5-coder-14b", roles=["code_tutor"]), + ], + ) + ) + + tutor_candidates = 
registry.candidates_for_role("tutor") + assert len(tutor_candidates) == 1 + assert tutor_candidates[0].node_id == "node-a" + assert tutor_candidates[0].model_id == "qwen3-8b" + + mentor_candidates = registry.candidates_for_role("mentor") + assert len(mentor_candidates) == 1 + assert mentor_candidates[0].model_id == "qwen3-8b" + + code_candidates = registry.candidates_for_role("code_tutor") + assert len(code_candidates) == 1 + assert code_candidates[0].model_id == "qwen2.5-coder-14b" + + picked = registry.pick_candidate_for_role("mentor") + assert picked is not None + assert picked.model_id == "qwen3-8b" + + +def test_registry_loads_legacy_roles_as_served_models(tmp_path): + persist_path = tmp_path / "registry.json" + registry = Registry(persist_path=persist_path) + registry.register( + NodeRegistration( + node_id="node-a", + base_url="http://127.0.0.1:9001", + roles=["reviewer"], + ) + ) + + reloaded = Registry(persist_path=persist_path) + candidates = reloaded.candidates_for_role("reviewer") + assert len(candidates) == 1 + assert candidates[0].model_id == "reviewer"