Change to having node agents and dispatcher structure
This commit is contained in:
parent
8537c6ef39
commit
b2e4518b3a
|
|
@ -64,3 +64,11 @@ This repository is a **preliminary scaffold**:
|
||||||
## License
|
## License
|
||||||
|
|
||||||
MIT. See `LICENSE`.
|
MIT. See `LICENSE`.
|
||||||
|
|
||||||
|
## Node Agent (per-host)
|
||||||
|
|
||||||
|
This repo also includes a **RoleMesh Node Agent** (`rolemesh-node-agent`) that can manage **persistent** `llama.cpp` servers (one per GPU) and report inventory/metrics back to the gateway.
|
||||||
|
|
||||||
|
- Sample config: `configs/node_agent.example.yaml`
|
||||||
|
- Docs: `docs/NODE_AGENT.md`
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
node_id: "node-1"
|
||||||
|
listen_host: "0.0.0.0"
|
||||||
|
listen_port: 8091
|
||||||
|
|
||||||
|
# Set to the dispatcher gateway URL if you want auto-registration/heartbeat.
|
||||||
|
dispatcher_base_url: "http://127.0.0.1:8080"
|
||||||
|
dispatcher_roles: ["planner", "coder"]
|
||||||
|
heartbeat_interval_sec: 5
|
||||||
|
|
||||||
|
llama_server_bin: "llama-server"
|
||||||
|
|
||||||
|
model_roots:
|
||||||
|
- "/models"
|
||||||
|
|
||||||
|
models:
|
||||||
|
- model_id: "planner-gguf"
|
||||||
|
path: "/models/SomePlannerModel.Q5_K_M.gguf"
|
||||||
|
roles: ["planner"]
|
||||||
|
default_ctx: 8192
|
||||||
|
server_args:
|
||||||
|
# Examples (llama.cpp flags differ by build/version):
|
||||||
|
# c: 8192
|
||||||
|
# n_gpu_layers: 60
|
||||||
|
# threads: 8
|
||||||
|
# parallel: 1
|
||||||
|
# keep: true
|
||||||
|
c: 8192
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Node Agent
|
||||||
|
|
||||||
|
The **RoleMesh Node Agent** runs on each compute host and manages **persistent** `llama.cpp` servers
|
||||||
|
(one per device, e.g. one per GPU). It can:
|
||||||
|
|
||||||
|
- expose OpenAI-compatible endpoints locally (`/v1/models`, `/v1/chat/completions`)
|
||||||
|
- register + heartbeat to the Dispatcher/Gateway (`/v1/nodes/register`, `/v1/nodes/heartbeat`)
|
||||||
|
- report inventory + utilization (`/v1/node/inventory`)
|
||||||
|
|
||||||
|
## Persistent server model
|
||||||
|
|
||||||
|
For each GPU device, the node agent starts a dedicated `llama-server` process, pinned via
|
||||||
|
environment variables (e.g. `CUDA_VISIBLE_DEVICES=0` for `gpu:0`) and bound to `127.0.0.1:<port>`.
|
||||||
|
|
||||||
|
Model switching is handled by **restart** in the scaffold.
|
||||||
|
|
||||||
|
## Backends
|
||||||
|
|
||||||
|
Adapters are implemented as runtime backends:
|
||||||
|
|
||||||
|
- `cuda`: scaffold implementation (NVIDIA via `nvidia-smi`)
|
||||||
|
- `metal`, `rocm`, `sycl`, `vulkan`: stubs with placeholders for device discovery and metrics
|
||||||
|
|
||||||
|
The framework keeps scheduling decisions backend-agnostic by standardizing on:
|
||||||
|
`DeviceRef` + `DeviceMetrics` + `ensure_server(...)`.
|
||||||
|
|
||||||
|
## Running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e .
|
||||||
|
rolemesh-node-agent --config configs/node_agent.example.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
## Registering
|
||||||
|
|
||||||
|
If `dispatcher_base_url` is set in the node-agent config, the node agent will periodically call:
|
||||||
|
|
||||||
|
- `POST <dispatcher>/v1/nodes/heartbeat` with latest device metrics.
|
||||||
|
|
||||||
|
Registration is currently manual from the node side (or can be added as a startup step).
|
||||||
|
|
@ -18,6 +18,12 @@ dependencies = [
|
||||||
"pyyaml>=6.0",
|
"pyyaml>=6.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
rolemesh-gateway = "rolemesh_gateway.cli:main"
|
||||||
|
rolemesh-node-agent = "rolemesh_node_agent.cli:main"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = [
|
dev = [
|
||||||
"ruff>=0.4",
|
"ruff>=0.4",
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel
|
from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel
|
||||||
from rolemesh_gateway.registry import NodeRegistration, Registry
|
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
|
||||||
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError
|
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
@ -145,3 +145,19 @@ async def register_node(request: Request) -> Dict[str, Any]:
|
||||||
reg = NodeRegistration.model_validate(payload)
|
reg = NodeRegistration.model_validate(payload)
|
||||||
node = registry.register(reg)
|
node = registry.register(reg)
|
||||||
return {"status": "ok", "node": node.model_dump(mode="json")}
|
return {"status": "ok", "node": node.model_dump(mode="json")}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/nodes/heartbeat")
async def nodes_heartbeat(request: Request) -> Dict[str, Any]:
    """
    Allow a node agent to push status/metrics updates.

    The payload is validated as a NodeHeartbeat; unknown node_ids get a 404 so
    agents can detect that they need to (re-)register first.

    SECURITY NOTE: Unauthenticated in scaffold. Add API-key gating or mTLS for real deployments.
    """
    registry: Registry = request.app.state.registry
    payload = await request.json()
    hb = NodeHeartbeat.model_validate(payload)
    node = registry.heartbeat(hb)
    if not node:
        # Heartbeat for a node that never registered (or was dropped).
        return JSONResponse(status_code=404, content={"error": "unknown_node", "node_id": hb.node_id})
    return {"status": "ok", "node": node.model_dump(mode="json")}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from rolemesh_gateway.main import create_app
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """CLI entry point for the `rolemesh-gateway` console script.

    Parses --config/--host/--port, builds the FastAPI app from the YAML
    config, and serves it with uvicorn.
    """
    parser = argparse.ArgumentParser(description="RoleMesh Gateway")
    parser.add_argument("--config", required=True, help="Path to gateway YAML config.")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    ns = parser.parse_args()

    application = create_app(Path(ns.config))
    uvicorn.run(application, host=ns.host, port=ns.port)


if __name__ == "__main__":
    main()
|
||||||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, HttpUrl
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
|
|
@ -15,11 +15,19 @@ class NodeRegistration(BaseModel):
|
||||||
meta: Dict[str, str] = Field(default_factory=dict)
|
meta: Dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class NodeHeartbeat(BaseModel):
    """Payload a node agent POSTs to the gateway's /v1/nodes/heartbeat endpoint."""

    # Identifier of a previously registered node.
    node_id: str
    # Epoch seconds; defaults to the time the model instance is constructed.
    timestamp: float = Field(default_factory=lambda: time.time())
    # Free-form status fields reported by the node.
    status: Dict[str, Any] = Field(default_factory=dict)
    # Per-device metric snapshots (dicts as serialized by the node agent).
    metrics: List[Dict[str, Any]] = Field(default_factory=list)
|
||||||
class RegisteredNode(BaseModel):
|
class RegisteredNode(BaseModel):
|
||||||
node_id: str
|
node_id: str
|
||||||
base_url: HttpUrl
|
base_url: HttpUrl
|
||||||
roles: List[str]
|
roles: List[str]
|
||||||
meta: Dict[str, str] = Field(default_factory=dict)
|
meta: Dict[str, str] = Field(default_factory=dict)
|
||||||
|
status: Dict[str, Any] = Field(default_factory=dict)
|
||||||
registered_at: float = Field(default_factory=lambda: time.time())
|
registered_at: float = Field(default_factory=lambda: time.time())
|
||||||
last_seen: float = Field(default_factory=lambda: time.time())
|
last_seen: float = Field(default_factory=lambda: time.time())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
__all__ = []
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
__all__ = []
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, AsyncIterator, Dict, Optional, Protocol
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class DeviceRef:
    """A stable identifier for a compute target on a node.

    Frozen so instances are hashable and safe to use as dict keys.
    """

    kind: str     # 'gpu' | 'cpu' | 'tpu'
    backend: str  # 'cuda' | 'metal' | 'rocm' | 'vulkan' | 'sycl' | 'cpu' | 'edgetpu'
    id: str       # e.g. 'gpu:0', 'cpu:0'
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DeviceMetrics:
    """Best-effort load/identity snapshot for one device.

    Optional fields stay None whenever the backend cannot report them.
    """

    device: DeviceRef
    # Model currently served on this device (if a server is running).
    loaded_model_id: Optional[str] = None
    loaded_since: Optional[float] = None

    # Utilization snapshot (best-effort, backend-specific)
    utilization_pct: Optional[float] = None
    mem_total_gb: Optional[float] = None
    mem_used_gb: Optional[float] = None

    # Queueing / in-flight
    queue_depth: int = 0
    in_flight_jobs: int = 0

    # Rolling performance hints (observed)
    median_ttft_ms: Optional[float] = None
    median_toks_per_sec: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeAdapter(Protocol):
    """Backend-specific adapter for starting and managing persistent llama.cpp servers.

    Structural (duck-typed) interface: concrete adapters need not inherit it.
    """

    backend_name: str  # 'cuda'|'metal'|'rocm'|'vulkan'|'sycl'|'cpu'|...

    async def discover_devices(self) -> list[DeviceRef]:
        """Enumerate the compute devices this backend can drive on the node."""
        ...

    async def get_metrics(self) -> list[DeviceMetrics]:
        """Return a best-effort utilization snapshot for each device."""
        ...

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        """Ensure a persistent server exists for (device, model). Returns base_url for OpenAI-compatible endpoints."""
        ...

    async def shutdown(self) -> None:
        """Stop every server this adapter manages."""
        ...
|
||||||
|
|
@ -0,0 +1,176 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from .base import DeviceMetrics, DeviceRef, RuntimeAdapter
|
||||||
|
|
||||||
|
|
||||||
|
def _find_free_port() -> int:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("127.0.0.1", 0))
|
||||||
|
return int(s.getsockname()[1])
|
||||||
|
|
||||||
|
|
||||||
|
def _nvidia_smi_query() -> List[Dict[str, str]]:
    """Best-effort GPU discovery/metrics via nvidia-smi. Returns [] if unavailable.

    Fixes over the naive version:
    - `timeout=` so a wedged driver/binary cannot stall callers (this function
      is polled from the heartbeat loop and request handlers);
    - stderr suppressed, keeping a missing/broken nvidia-smi truly silent;
    - tolerant row parsing: a malformed line (or a GPU name containing commas)
      no longer aborts the whole snapshot.
    """
    cmd = [
        "nvidia-smi",
        "--query-gpu=index,name,memory.total,memory.used,utilization.gpu",
        "--format=csv,noheader,nounits",
    ]
    try:
        out = subprocess.check_output(
            cmd, text=True, stderr=subprocess.DEVNULL, timeout=10
        ).strip()
    except Exception:
        # Best-effort by design: no nvidia-smi / no driver / timeout => no GPUs.
        return []

    rows: List[Dict[str, str]] = []
    for line in out.splitlines():
        cells = [c.strip() for c in line.split(",")]
        if len(cells) < 5:
            continue  # malformed row; skip it rather than drop every GPU
        # 'name' is the only variable-width field: take index from the front,
        # the three numeric fields from the back, and rejoin the rest as name.
        rows.append(
            {
                "index": cells[0],
                "name": ", ".join(cells[1:-3]),
                "mem_total_mb": cells[-3],
                "mem_used_mb": cells[-2],
                "util_pct": cells[-1],
            }
        )
    return rows
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _ServerProc:
    """Bookkeeping for one managed llama-server child process."""

    device: DeviceRef       # GPU this server is pinned to
    model_id: str           # logical model identifier served by this process
    model_path: str         # GGUF file path passed to the server
    port: int               # local TCP port the server was bound to (127.0.0.1)
    proc: subprocess.Popen  # child process handle
    started_at: float       # epoch seconds when the process was launched
|
||||||
|
|
||||||
|
|
||||||
|
class CudaAdapter(RuntimeAdapter):
    """
    CUDA adapter that manages one persistent ggml-org/llama.cpp 'llama-server' per GPU.

    Assumptions:
    - 'llama-server' binary is available in PATH on the node.
    - Each server is bound to 127.0.0.1:<port>.
    - Model switching is performed by stopping/restarting the server for that GPU.
      (More advanced approaches can be added later.)

    NOTE: This is a scaffold implementation intended to be extended with:
    - more robust health checks
    - structured logging
    - graceful shutdown and model warm pools
    """

    backend_name = "cuda"

    def __init__(self, llama_server_bin: str = "llama-server") -> None:
        self._bin = llama_server_bin
        self._servers: Dict[str, _ServerProc] = {}  # key: device.id
        # Serializes ensure_server/shutdown so concurrent requests cannot race
        # to start or stop servers on the same GPU.
        self._lock = asyncio.Lock()

    async def discover_devices(self) -> List[DeviceRef]:
        """One DeviceRef per GPU reported by nvidia-smi; [] when unavailable."""
        gpus = _nvidia_smi_query()
        devices: List[DeviceRef] = []
        for g in gpus:
            devices.append(DeviceRef(kind="gpu", backend="cuda", id=f"gpu:{g['index']}"))
        return devices

    async def get_metrics(self) -> List[DeviceMetrics]:
        """Per-GPU utilization/memory snapshot, annotated with any model we manage on it."""
        gpus = _nvidia_smi_query()
        metrics: List[DeviceMetrics] = []
        for g in gpus:
            dev = DeviceRef(kind="gpu", backend="cuda", id=f"gpu:{g['index']}")
            # Divide by 1024 to convert the *_mb fields (presumably MiB from
            # nvidia-smi's 'nounits' output — TODO confirm) into the *_gb fields.
            m = DeviceMetrics(
                device=dev,
                utilization_pct=float(g["util_pct"]) if g.get("util_pct") else None,
                mem_total_gb=float(g["mem_total_mb"]) / 1024.0 if g.get("mem_total_mb") else None,
                mem_used_gb=float(g["mem_used_mb"]) / 1024.0 if g.get("mem_used_mb") else None,
            )
            sp = self._servers.get(dev.id)
            if sp:
                m.loaded_model_id = sp.model_id
                m.loaded_since = sp.started_at
            metrics.append(m)
        return metrics

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        """Ensure a persistent server exists for (device, model); return its base URL.

        Reuses a live server already running the same model; otherwise stops
        any old server on the device and launches a fresh one.
        """
        if device.backend != "cuda" or device.kind != "gpu":
            raise ValueError("CudaAdapter can only manage cuda gpu devices.")

        async with self._lock:
            existing = self._servers.get(device.id)
            # Fast path: same model already running (poll() is None => alive).
            if existing and existing.model_id == model_id and existing.proc.poll() is None:
                return f"http://127.0.0.1:{existing.port}"

            # Stop old server if present
            if existing and existing.proc.poll() is None:
                existing.proc.terminate()
                try:
                    existing.proc.wait(timeout=5)
                except Exception:
                    # Did not exit within the grace period: force-kill.
                    existing.proc.kill()

            port = _find_free_port()

            # llama-server flag conventions vary; keep args configurable.
            # Common options include: -m <model>, --port, --host, -ngl (gpu layers), -c (ctx)
            # We also pin to a specific GPU using CUDA_VISIBLE_DEVICES for that process.
            env = os.environ.copy()
            gpu_index = device.id.split(":")[1]
            env["CUDA_VISIBLE_DEVICES"] = gpu_index

            cmd = [self._bin, "-m", model_path, "--host", "127.0.0.1", "--port", str(port)]
            # server_args -> CLI flags: True becomes a bare flag, None/False is
            # omitted, anything else becomes "--key value"; keys that already
            # start with '-' are passed through unchanged.
            for k, v in (server_args or {}).items():
                if v is None or v is False:
                    continue
                flag = str(k)
                if not flag.startswith("-"):
                    flag = "--" + flag
                if v is True:
                    cmd.append(flag)
                else:
                    cmd.extend([flag, str(v)])

            # NOTE(review): stdout is piped but never drained; a chatty server can
            # block once the pipe buffer fills — consider a reader task or DEVNULL.
            proc = subprocess.Popen(
                cmd,
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
            )

            self._servers[device.id] = _ServerProc(
                device=device,
                model_id=model_id,
                model_path=model_path,
                port=port,
                proc=proc,
                started_at=time.time(),
            )

            # TODO: replace with a real readiness probe (e.g., GET /health on llama-server)
            await asyncio.sleep(0.25)
            return f"http://127.0.0.1:{port}"

    async def shutdown(self) -> None:
        """Terminate every managed server (force-kill after 5s) and clear state."""
        async with self._lock:
            for sp in list(self._servers.values()):
                if sp.proc.poll() is None:
                    sp.proc.terminate()
                    try:
                        sp.proc.wait(timeout=5)
                    except Exception:
                        sp.proc.kill()
            self._servers.clear()
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from .base import DeviceMetrics, DeviceRef, RuntimeAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class MetalAdapter(RuntimeAdapter):
    """
    Stub adapter for backend 'metal'.

    A real implementation should launch and supervise *persistent*
    ggml-org/llama.cpp 'llama-server' processes built with Metal support, and
    fill in device discovery plus utilization reporting via the platform's
    native tooling. Until then every method except `shutdown` raises
    NotImplementedError.
    """

    backend_name = "metal"

    # Single source for the placeholder message raised by every stubbed method.
    _TODO = "metal adapter not implemented in scaffold."

    async def discover_devices(self) -> List[DeviceRef]:
        raise NotImplementedError(self._TODO)

    async def get_metrics(self) -> List[DeviceMetrics]:
        raise NotImplementedError(self._TODO)

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        raise NotImplementedError(self._TODO)

    async def shutdown(self) -> None:
        # Nothing to clean up: no servers are ever started.
        return
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from .base import DeviceMetrics, DeviceRef, RuntimeAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class RocmAdapter(RuntimeAdapter):
    """
    Stub adapter for backend 'rocm'.

    A real implementation should launch and supervise *persistent*
    ggml-org/llama.cpp 'llama-server' processes built with ROCm/HIP support,
    and fill in device discovery plus utilization reporting via native tooling
    (e.g. rocm-smi). Until then every method except `shutdown` raises
    NotImplementedError.
    """

    backend_name = "rocm"

    # Single source for the placeholder message raised by every stubbed method.
    _TODO = "rocm adapter not implemented in scaffold."

    async def discover_devices(self) -> List[DeviceRef]:
        raise NotImplementedError(self._TODO)

    async def get_metrics(self) -> List[DeviceMetrics]:
        raise NotImplementedError(self._TODO)

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        raise NotImplementedError(self._TODO)

    async def shutdown(self) -> None:
        # Nothing to clean up: no servers are ever started.
        return
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from .base import DeviceMetrics, DeviceRef, RuntimeAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class SyclAdapter(RuntimeAdapter):
    """
    Stub adapter for backend 'sycl'.

    A real implementation should launch and supervise *persistent*
    ggml-org/llama.cpp 'llama-server' processes built with SYCL support, and
    fill in device discovery plus utilization reporting via native tooling.
    Until then every method except `shutdown` raises NotImplementedError.
    """

    backend_name = "sycl"

    # Single source for the placeholder message raised by every stubbed method.
    _TODO = "sycl adapter not implemented in scaffold."

    async def discover_devices(self) -> List[DeviceRef]:
        raise NotImplementedError(self._TODO)

    async def get_metrics(self) -> List[DeviceMetrics]:
        raise NotImplementedError(self._TODO)

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        raise NotImplementedError(self._TODO)

    async def shutdown(self) -> None:
        # Nothing to clean up: no servers are ever started.
        return
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from .base import DeviceMetrics, DeviceRef, RuntimeAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class VulkanAdapter(RuntimeAdapter):
    """
    Stub adapter for backend 'vulkan'.

    A real implementation should launch and supervise *persistent*
    ggml-org/llama.cpp 'llama-server' processes built with Vulkan support, and
    fill in device discovery plus utilization reporting via native tooling.
    Until then every method except `shutdown` raises NotImplementedError.
    """

    backend_name = "vulkan"

    # Single source for the placeholder message raised by every stubbed method.
    _TODO = "vulkan adapter not implemented in scaffold."

    async def discover_devices(self) -> List[DeviceRef]:
        raise NotImplementedError(self._TODO)

    async def get_metrics(self) -> List[DeviceMetrics]:
        raise NotImplementedError(self._TODO)

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        raise NotImplementedError(self._TODO)

    async def shutdown(self) -> None:
        # Nothing to clean up: no servers are ever started.
        return
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from .config import NodeAgentConfig
|
||||||
|
from .main import create_app
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """CLI entry point for the `rolemesh-node-agent` console script.

    Loads the YAML config, validates it into a NodeAgentConfig, and serves the
    node-agent FastAPI app with uvicorn on the configured host/port.
    """
    parser = argparse.ArgumentParser(description="RoleMesh Node Agent")
    parser.add_argument("--config", required=True, help="Path to node-agent YAML config.")
    ns = parser.parse_args()

    raw = yaml.safe_load(Path(ns.config).read_text())
    cfg = NodeAgentConfig.model_validate(raw)
    uvicorn.run(create_app(cfg), host=cfg.listen_host, port=cfg.listen_port)


if __name__ == "__main__":
    main()
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEntry(BaseModel):
    """One GGUF model this node can serve, plus how to launch it."""

    # Logical identifier exposed via /v1/models and matched against request 'model'.
    model_id: str
    # Filesystem path to the .gguf file (handed to the runtime adapter).
    path: Path
    # Dispatcher roles this model can fulfil (e.g. 'planner', 'coder').
    roles: List[str] = Field(default_factory=list)
    # Default context length — presumably maps to the server's ctx flag; confirm.
    default_ctx: int = 8192
    # Extra llama-server CLI flags (key -> value; True => bare flag).
    server_args: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class NodeAgentConfig(BaseModel):
    """Top-level YAML configuration for the node agent (see configs/node_agent.example.yaml)."""

    # Stable identifier reported in inventory and heartbeats.
    node_id: str = "node-1"
    listen_host: str = "0.0.0.0"
    listen_port: int = 8091

    # Where GGUF models live (used for inventory endpoints; not required if models are explicit)
    model_roots: List[Path] = Field(default_factory=list)

    # Runtime backends enabled on this node
    enable_backends: List[Literal["cuda", "metal", "rocm", "vulkan", "sycl", "cpu"]] = Field(default_factory=lambda: ["cuda"])

    # Explicit model catalog (recommended)
    models: List[ModelEntry] = Field(default_factory=list)

    # Optional dispatcher registration/heartbeat
    # The heartbeat loop only starts when BOTH dispatcher_base_url and
    # dispatcher_roles are set (see the app's startup hook).
    dispatcher_base_url: Optional[HttpUrl] = None
    dispatcher_roles: List[str] = Field(default_factory=list)
    # Seconds between heartbeat POSTs to the dispatcher.
    heartbeat_interval_sec: float = 5.0

    # llama-server binary name/path
    llama_server_bin: str = "llama-server"
|
||||||
|
|
@ -0,0 +1,37 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
from .config import ModelEntry
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256(path: Path, max_bytes: int = 2_000_000) -> str:
|
||||||
|
"""Best-effort short hash; reads only the first max_bytes to keep inventory fast."""
|
||||||
|
h = hashlib.sha256()
|
||||||
|
with path.open("rb") as f:
|
||||||
|
h.update(f.read(max_bytes))
|
||||||
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def discover_gguf_models(roots: Iterable[Path]) -> List[Dict[str, str]]:
    """Recursively describe every *.gguf file under the given root directories.

    Missing roots and unreadable files are skipped silently (best-effort scan).
    Each entry records path, filename, size, and a head-of-file hash.
    """
    found: List[Dict[str, str]] = []
    for root in roots:
        root = Path(root)
        if not root.exists():
            continue
        for gguf in root.rglob("*.gguf"):
            try:
                entry = {
                    "path": str(gguf),
                    "name": gguf.name,
                    "size_bytes": str(gguf.stat().st_size),
                    "sha256_head": _sha256(gguf),
                }
            except Exception:
                # stat/open can fail on races or permissions; drop this file only.
                continue
            found.append(entry)
    return found
|
||||||
|
|
@ -0,0 +1,133 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
|
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError # reuse gateway client
|
||||||
|
from .adapters.cuda import CudaAdapter
|
||||||
|
from .adapters.base import DeviceRef
|
||||||
|
from .config import NodeAgentConfig
|
||||||
|
from .inventory import discover_gguf_models
|
||||||
|
|
||||||
|
|
||||||
|
def _error(message: str, code: str = "node_error", status_code: int = 500) -> JSONResponse:
    """Build an OpenAI-style error JSON response with the given HTTP status."""
    body = {"error": {"message": message, "type": "node_error", "code": code}}
    return JSONResponse(status_code=status_code, content=body)
|
||||||
|
|
||||||
|
|
||||||
|
def create_app(cfg: NodeAgentConfig) -> FastAPI:
    """Build the node-agent FastAPI app: local OpenAI-compatible endpoints plus
    inventory reporting and an optional dispatcher heartbeat loop."""
    app = FastAPI(title="RoleMesh Node Agent", version="0.1.0")

    app.state.cfg = cfg
    # Long read timeout: generation requests can stream for a long time.
    app.state.http = httpx.AsyncClient(timeout=httpx.Timeout(connect=5.0, read=3600.0))
    app.state.upstream = UpstreamClient(client=app.state.http)

    # Adapters
    app.state.cuda = CudaAdapter(llama_server_bin=cfg.llama_server_bin)

    # State: role -> (device, model)
    # This is intentionally simple for the scaffold: pick first GPU and first matching model.
    app.state.role_bindings: Dict[str, Dict[str, str]] = {}

    @app.on_event("startup")
    async def _startup() -> None:
        # optional: dispatcher registration loop
        # NOTE(review): the created task is not stored; asyncio keeps only weak
        # references to tasks, so it could be garbage-collected — keep a reference.
        if cfg.dispatcher_base_url and cfg.dispatcher_roles:
            asyncio.create_task(_heartbeat_loop(app))

    @app.on_event("shutdown")
    async def _shutdown() -> None:
        # Close the outbound HTTP client, then stop all managed llama-servers.
        await app.state.http.aclose()
        await app.state.cuda.shutdown()

    @app.get("/health")
    async def health() -> Dict[str, Any]:
        return {"status": "ok", "node_id": cfg.node_id, "time": time.time()}

    @app.get("/v1/node/inventory")
    async def inventory() -> Dict[str, Any]:
        # Combined view: live devices + metrics, the configured catalog, and
        # any GGUF files found under model_roots.
        devices = await app.state.cuda.discover_devices()
        metrics = [m.__dict__ | {"device": m.device.__dict__} for m in await app.state.cuda.get_metrics()]
        models = [{"model_id": m.model_id, "path": str(m.path), "roles": m.roles} for m in cfg.models]
        discovered = discover_gguf_models(cfg.model_roots)
        return {
            "node_id": cfg.node_id,
            "backends": cfg.enable_backends,
            "devices": [d.__dict__ for d in devices],
            "metrics": metrics,
            "models": models,
            "discovered_gguf": discovered,
        }

    @app.get("/v1/models")
    async def list_models() -> Dict[str, Any]:
        # Expose configured models as OpenAI models (node-local names).
        data = [{"id": m.model_id, "object": "model", "owned_by": cfg.node_id} for m in cfg.models]
        return {"object": "list", "data": data}

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request) -> Any:
        # Proxy an OpenAI chat request to the llama-server for the requested model,
        # starting/switching that server on the first GPU if needed.
        body = await request.json()
        stream = bool(body.get("stream", False))
        model_id = body.get("model")
        if not model_id:
            return _error("Missing 'model' in request body.", code="bad_request", status_code=400)

        # Find model entry
        model_entry = next((m for m in cfg.models if m.model_id == model_id), None)
        if not model_entry:
            return _error(f"Unknown model_id '{model_id}'.", code="unknown_model", status_code=404)

        # Select device (first CUDA GPU for now)
        devices = await app.state.cuda.discover_devices()
        if not devices:
            return _error("No CUDA GPUs discovered on this node.", code="no_device", status_code=503)
        device = devices[0]

        base_url = await app.state.cuda.ensure_server(
            device,
            model_path=str(model_entry.path),
            model_id=model_entry.model_id,
            server_args=model_entry.server_args,
        )

        upstream = app.state.upstream
        try:
            if not stream:
                out = await upstream.chat_completions(base_url, body)
                return JSONResponse(status_code=200, content=out)
            else:
                # Re-yield the upstream SSE stream chunk-for-chunk.
                async def gen():
                    async for chunk in upstream.stream_chat_completions(base_url, body):
                        yield chunk
                return StreamingResponse(gen(), media_type="text/event-stream")
        except UpstreamError as e:
            return _error(str(e), code="upstream_error", status_code=502)

    return app
|
||||||
|
|
||||||
|
|
||||||
|
async def _heartbeat_loop(app: FastAPI) -> None:
    """Periodically POST device metrics to the dispatcher's /v1/nodes/heartbeat.

    Runs forever; every individual failure is swallowed so a flaky dispatcher
    or metrics source cannot crash the node agent (best-effort by design).
    """
    cfg: NodeAgentConfig = app.state.cfg
    http: httpx.AsyncClient = app.state.http
    while True:
        try:
            inv = await app.state.cuda.get_metrics()
            payload = {
                "node_id": cfg.node_id,
                "timestamp": time.time(),
                # Flatten each DeviceMetrics into a plain dict for JSON transport.
                "metrics": [m.__dict__ | {"device": m.device.__dict__} for m in inv],
            }
            url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat"
            await http.post(url, json=payload)
        except Exception:
            # Drop this heartbeat and retry on the next tick.
            pass
        await asyncio.sleep(cfg.heartbeat_interval_sec)
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, AsyncIterator, Dict, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Job:
    """One queued request bound for a specific device's llama-server."""

    job_id: str           # unique identifier for this job
    submitted_at: float   # epoch seconds when the job was created
    base_url: str         # upstream server base URL that will serve the job
    body: Dict[str, Any]  # raw OpenAI-style request body to forward
    stream: bool          # whether the caller requested streaming output
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceQueue:
    """Simple FIFO queue per device; one worker per device for latency predictability."""

    def __init__(self, *, client: httpx.AsyncClient, base_url: str) -> None:
        self._client = client
        self._base_url = base_url.rstrip("/")
        self._q: asyncio.Queue[Job] = asyncio.Queue()
        self._worker_task: Optional[asyncio.Task] = None
        # Observability counters (read by schedulers/metrics).
        self.in_flight: int = 0
        self.last_job_started_at: Optional[float] = None

    def start(self) -> None:
        """Start the single background worker (idempotent)."""
        if self._worker_task is None:
            self._worker_task = asyncio.create_task(self._worker())

    async def stop(self) -> None:
        """Cancel the worker; any still-queued jobs are left unprocessed."""
        if self._worker_task:
            self._worker_task.cancel()
            self._worker_task = None

    async def submit(self, job: Job) -> Job:
        """Enqueue *job* (FIFO) and return it unchanged."""
        await self._q.put(job)
        return job

    @property
    def depth(self) -> int:
        # Jobs waiting in the queue (not counting the one currently in flight).
        return self._q.qsize()

    async def _worker(self) -> None:
        # Single consumer loop: take jobs one at a time, tracking in-flight count.
        while True:
            job = await self._q.get()
            self.in_flight += 1
            self.last_job_started_at = time.time()
            try:
                # The actual request execution is handled by the API layer (so it can stream).
                # This worker exists mainly to serialize jobs per device and provide queue metrics.
                await asyncio.sleep(0)  # placeholder
            finally:
                self.in_flight -= 1
                self._q.task_done()
|
||||||
Loading…
Reference in New Issue