Change to having node agents and dispatcher structure
This commit is contained in:
parent
8537c6ef39
commit
b2e4518b3a
|
|
@ -64,3 +64,11 @@ This repository is a **preliminary scaffold**:
|
|||
## License
|
||||
|
||||
MIT. See `LICENSE`.
|
||||
|
||||
## Node Agent (per-host)
|
||||
|
||||
This repo also includes a **RoleMesh Node Agent** (`rolemesh-node-agent`) that can manage **persistent** `llama.cpp` servers (one per GPU) and report inventory/metrics back to the gateway.
|
||||
|
||||
- Sample config: `configs/node_agent.example.yaml`
|
||||
- Docs: `docs/NODE_AGENT.md`
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
node_id: "node-1"
|
||||
listen_host: "0.0.0.0"
|
||||
listen_port: 8091
|
||||
|
||||
# Set to the dispatcher gateway URL to enable periodic heartbeats (registration itself is currently a manual step).
|
||||
dispatcher_base_url: "http://127.0.0.1:8080"
|
||||
dispatcher_roles: ["planner", "coder"]
|
||||
heartbeat_interval_sec: 5
|
||||
|
||||
llama_server_bin: "llama-server"
|
||||
|
||||
model_roots:
|
||||
- "/models"
|
||||
|
||||
models:
|
||||
- model_id: "planner-gguf"
|
||||
path: "/models/SomePlannerModel.Q5_K_M.gguf"
|
||||
roles: ["planner"]
|
||||
default_ctx: 8192
|
||||
server_args:
|
||||
# Examples (llama.cpp flags differ by build/version):
|
||||
# c: 8192
|
||||
# n_gpu_layers: 60
|
||||
# threads: 8
|
||||
# parallel: 1
|
||||
# keep: true
|
||||
c: 8192
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
# Node Agent
|
||||
|
||||
The **RoleMesh Node Agent** runs on each compute host and manages **persistent** `llama.cpp` servers
|
||||
(one per device, e.g. one per GPU). It can:
|
||||
|
||||
- expose OpenAI-compatible endpoints locally (`/v1/models`, `/v1/chat/completions`)
|
||||
- register + heartbeat to the Dispatcher/Gateway (`/v1/nodes/register`, `/v1/nodes/heartbeat`)
|
||||
- report inventory + utilization (`/v1/node/inventory`)
|
||||
|
||||
## Persistent server model
|
||||
|
||||
For each GPU device, the node agent starts a dedicated `llama-server` process, pinned via
|
||||
environment variables (e.g. `CUDA_VISIBLE_DEVICES=0` for `gpu:0`) and bound to `127.0.0.1:<port>`.
|
||||
|
||||
Model switching is handled by **restart** in the scaffold.
|
||||
|
||||
## Backends
|
||||
|
||||
Adapters are implemented as runtime backends:
|
||||
|
||||
- `cuda`: scaffold implementation (NVIDIA via `nvidia-smi`)
|
||||
- `metal`, `rocm`, `sycl`, `vulkan`: stubs with placeholders for device discovery and metrics
|
||||
|
||||
The framework keeps scheduling decisions backend-agnostic by standardizing on:
|
||||
`DeviceRef` + `DeviceMetrics` + `ensure_server(...)`.
|
||||
|
||||
## Running
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
rolemesh-node-agent --config configs/node_agent.example.yaml
|
||||
```
|
||||
|
||||
## Registering
|
||||
|
||||
If `dispatcher_base_url` is set in the node-agent config, the node agent will periodically call:
|
||||
|
||||
- `POST <dispatcher>/v1/nodes/heartbeat` with latest device metrics.
|
||||
|
||||
Registration is currently manual from the node side (or can be added as a startup step).
|
||||
|
|
@ -18,6 +18,12 @@ dependencies = [
|
|||
"pyyaml>=6.0",
|
||||
]
|
||||
|
||||
|
||||
|
||||
[project.scripts]
|
||||
rolemesh-gateway = "rolemesh_gateway.cli:main"
|
||||
rolemesh-node-agent = "rolemesh_node_agent.cli:main"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"ruff>=0.4",
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from fastapi import APIRouter, Request
|
|||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel
|
||||
from rolemesh_gateway.registry import NodeRegistration, Registry
|
||||
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
|
||||
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError
|
||||
|
||||
router = APIRouter()
|
||||
|
|
@ -145,3 +145,19 @@ async def register_node(request: Request) -> Dict[str, Any]:
|
|||
reg = NodeRegistration.model_validate(payload)
|
||||
node = registry.register(reg)
|
||||
return {"status": "ok", "node": node.model_dump(mode="json")}
|
||||
|
||||
|
||||
@router.post("/v1/nodes/heartbeat")
|
||||
async def nodes_heartbeat(request: Request) -> Dict[str, Any]:
|
||||
"""
|
||||
Allow a node agent to push status/metrics updates.
|
||||
|
||||
SECURITY NOTE: Unauthenticated in scaffold. Add API-key gating or mTLS for real deployments.
|
||||
"""
|
||||
registry: Registry = request.app.state.registry
|
||||
payload = await request.json()
|
||||
hb = NodeHeartbeat.model_validate(payload)
|
||||
node = registry.heartbeat(hb)
|
||||
if not node:
|
||||
return JSONResponse(status_code=404, content={"error": "unknown_node", "node_id": hb.node_id})
|
||||
return {"status": "ok", "node": node.model_dump(mode="json")}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
|
||||
from rolemesh_gateway.main import create_app
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments, build the gateway app, serve it with uvicorn."""
    parser = argparse.ArgumentParser(description="RoleMesh Gateway")
    parser.add_argument("--config", required=True, help="Path to gateway YAML config.")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    ns = parser.parse_args()

    application = create_app(Path(ns.config))
    uvicorn.run(application, host=ns.host, port=ns.port)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
|
|
@ -15,11 +15,19 @@ class NodeRegistration(BaseModel):
|
|||
meta: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
|
||||
|
||||
class NodeHeartbeat(BaseModel):
    """Payload a node agent POSTs to the gateway's /v1/nodes/heartbeat endpoint."""
    node_id: str  # must match a previously registered node's id
    timestamp: float = Field(default_factory=lambda: time.time())  # sender-side send time (epoch seconds)
    status: Dict[str, Any] = Field(default_factory=dict)  # free-form node status blob
    metrics: List[Dict[str, Any]] = Field(default_factory=list)  # per-device metric snapshots
|
||||
class RegisteredNode(BaseModel):
    """Gateway-side record of a node that has registered itself."""
    node_id: str  # unique node identifier
    base_url: HttpUrl  # where the gateway can reach the node agent
    roles: List[str]  # roles this node advertises (e.g. 'planner', 'coder')
    meta: Dict[str, str] = Field(default_factory=dict)  # opaque key/value metadata from registration
    status: Dict[str, Any] = Field(default_factory=dict)  # last status blob pushed via heartbeat
    registered_at: float = Field(default_factory=lambda: time.time())  # epoch seconds at registration
    last_seen: float = Field(default_factory=lambda: time.time())  # epoch seconds of last heartbeat
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
__all__ = []
|
||||
|
|
@ -0,0 +1 @@
|
|||
__all__ = []
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Protocol
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DeviceRef:
    """A stable identifier for a compute target on a node.

    Frozen so instances are hashable and safe to use as dict keys.
    """
    kind: str  # 'gpu' | 'cpu' | 'tpu'
    backend: str  # 'cuda' | 'metal' | 'rocm' | 'vulkan' | 'sycl' | 'cpu' | 'edgetpu'
    id: str  # e.g. 'gpu:0', 'cpu:0'
|
||||
|
||||
|
||||
@dataclass
class DeviceMetrics:
    """Best-effort runtime snapshot for a single device; unknown fields stay None/0."""
    device: DeviceRef
    loaded_model_id: Optional[str] = None  # model currently served on this device, if any
    loaded_since: Optional[float] = None  # epoch seconds when that model's server started

    # Utilization snapshot (best-effort, backend-specific)
    utilization_pct: Optional[float] = None
    mem_total_gb: Optional[float] = None
    mem_used_gb: Optional[float] = None

    # Queueing / in-flight
    queue_depth: int = 0
    in_flight_jobs: int = 0

    # Rolling performance hints (observed)
    median_ttft_ms: Optional[float] = None  # time-to-first-token
    median_toks_per_sec: Optional[float] = None  # decode throughput
|
||||
|
||||
|
||||
class RuntimeAdapter(Protocol):
    """Backend-specific adapter for starting and managing persistent llama.cpp servers."""

    backend_name: str  # 'cuda'|'metal'|'rocm'|'vulkan'|'sycl'|'cpu'|...

    async def discover_devices(self) -> list[DeviceRef]:
        """Enumerate the devices this backend can drive on the local host."""
        ...

    async def get_metrics(self) -> list[DeviceMetrics]:
        """Return a current metrics snapshot for each discovered device."""
        ...

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        """Ensure a persistent server exists for (device, model). Returns base_url for OpenAI-compatible endpoints."""
        ...

    async def shutdown(self) -> None:
        """Stop every server this adapter started and release resources."""
        ...
|
||||
|
|
@ -0,0 +1,176 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shlex
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import DeviceMetrics, DeviceRef, RuntimeAdapter
|
||||
|
||||
|
||||
def _find_free_port() -> int:
    """Ask the OS for an ephemeral localhost port and return it.

    NOTE(review): the port is released before the caller binds it, so a rare
    race with another process grabbing the same port is possible.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("127.0.0.1", 0))
        _, port = sock.getsockname()
        return int(port)
    finally:
        sock.close()
|
||||
|
||||
|
||||
def _nvidia_smi_query() -> List[Dict[str, str]]:
    """Best-effort GPU discovery/metrics via nvidia-smi. Returns [] if unavailable.

    Returns:
        One dict per GPU with string values for keys: index, name,
        mem_total_mb, mem_used_mb, util_pct.

    FIX: the original unpacked exactly 5 comma-separated cells per row, so a
    GPU name containing a comma (or any malformed row) raised inside the loop
    and the broad except discarded ALL rows. Malformed rows are now skipped
    individually and the name cells are re-joined.
    """
    try:
        cmd = [
            "nvidia-smi",
            "--query-gpu=index,name,memory.total,memory.used,utilization.gpu",
            "--format=csv,noheader,nounits",
        ]
        out = subprocess.check_output(cmd, text=True).strip()
        if not out:
            return []
        rows: List[Dict[str, str]] = []
        for line in out.splitlines():
            cells = [c.strip() for c in line.split(",")]
            if len(cells) < 5:
                continue  # malformed row: skip it rather than abort the whole query
            # Only the name can contain commas: index is the first cell and the
            # last three cells are numeric, so re-join whatever is in between.
            idx = cells[0]
            name = ",".join(cells[1:-3])
            mem_total, mem_used, util = cells[-3:]
            rows.append(
                {
                    "index": idx,
                    "name": name,
                    "mem_total_mb": mem_total,
                    "mem_used_mb": mem_used,
                    "util_pct": util,
                }
            )
        return rows
    except Exception:
        # nvidia-smi missing or failed: report no GPUs rather than crash.
        return []
|
||||
|
||||
|
||||
@dataclass
class _ServerProc:
    """Bookkeeping record for one managed llama-server child process."""
    device: DeviceRef  # GPU this server is pinned to
    model_id: str  # logical model identifier being served
    model_path: str  # GGUF file backing the server
    port: int  # 127.0.0.1 port the server listens on
    proc: subprocess.Popen  # child process handle
    started_at: float  # epoch seconds at spawn time
|
||||
|
||||
|
||||
class CudaAdapter(RuntimeAdapter):
    """
    CUDA adapter that manages one persistent ggml-org/llama.cpp 'llama-server' per GPU.

    Assumptions:
    - 'llama-server' binary is available in PATH on the node.
    - Each server is bound to 127.0.0.1:<port>.
    - Model switching is performed by stopping/restarting the server for that GPU.
      (More advanced approaches can be added later.)

    NOTE: This is a scaffold implementation intended to be extended with:
    - more robust health checks
    - structured logging
    - graceful shutdown and model warm pools
    """

    backend_name = "cuda"

    def __init__(self, llama_server_bin: str = "llama-server") -> None:
        """
        Args:
            llama_server_bin: Name or path of the llama.cpp server binary.
        """
        self._bin = llama_server_bin
        self._servers: Dict[str, _ServerProc] = {}  # key: device.id
        self._lock = asyncio.Lock()  # serializes server start/stop across coroutines

    @staticmethod
    def _stop_proc(proc: subprocess.Popen) -> None:
        """Best-effort stop: SIGTERM, wait up to 5s, then SIGKILL. No-op if already dead."""
        if proc.poll() is not None:
            return
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except Exception:
            proc.kill()

    async def discover_devices(self) -> List[DeviceRef]:
        """Return one DeviceRef per GPU reported by nvidia-smi ([] when none/unavailable)."""
        return [
            DeviceRef(kind="gpu", backend="cuda", id=f"gpu:{g['index']}")
            for g in _nvidia_smi_query()
        ]

    async def get_metrics(self) -> List[DeviceMetrics]:
        """Snapshot utilization/memory per GPU, annotated with any model we loaded on it."""
        metrics: List[DeviceMetrics] = []
        for g in _nvidia_smi_query():
            dev = DeviceRef(kind="gpu", backend="cuda", id=f"gpu:{g['index']}")
            m = DeviceMetrics(
                device=dev,
                utilization_pct=float(g["util_pct"]) if g.get("util_pct") else None,
                mem_total_gb=float(g["mem_total_mb"]) / 1024.0 if g.get("mem_total_mb") else None,
                mem_used_gb=float(g["mem_used_mb"]) / 1024.0 if g.get("mem_used_mb") else None,
            )
            sp = self._servers.get(dev.id)
            if sp:
                m.loaded_model_id = sp.model_id
                m.loaded_since = sp.started_at
            metrics.append(m)
        return metrics

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        """
        Ensure a persistent llama-server exists for (device, model).

        Reuses a live server already serving the same model on the device;
        otherwise stops any old server and spawns a new one pinned to the GPU
        via CUDA_VISIBLE_DEVICES.

        Returns:
            base_url (http://127.0.0.1:<port>) for OpenAI-compatible endpoints.

        Raises:
            ValueError: if the device is not a CUDA GPU.
        """
        if device.backend != "cuda" or device.kind != "gpu":
            raise ValueError("CudaAdapter can only manage cuda gpu devices.")

        async with self._lock:
            existing = self._servers.get(device.id)
            if existing and existing.model_id == model_id and existing.proc.poll() is None:
                return f"http://127.0.0.1:{existing.port}"

            # Stop the old server if present (model switch, or replacing a dead process).
            if existing:
                self._stop_proc(existing.proc)

            port = _find_free_port()

            # llama-server flag conventions vary; keep args configurable.
            # Common options include: -m <model>, --port, --host, -ngl (gpu layers), -c (ctx)
            # We also pin to a specific GPU using CUDA_VISIBLE_DEVICES for that process.
            env = os.environ.copy()
            gpu_index = device.id.split(":")[1]
            env["CUDA_VISIBLE_DEVICES"] = gpu_index

            cmd = [self._bin, "-m", model_path, "--host", "127.0.0.1", "--port", str(port)]
            for k, v in (server_args or {}).items():
                if v is None or v is False:
                    continue  # omitted / disabled flags
                flag = str(k)
                if not flag.startswith("-"):
                    flag = "--" + flag
                if v is True:
                    cmd.append(flag)  # boolean switch, no value
                else:
                    cmd.extend([flag, str(v)])

            proc = subprocess.Popen(
                cmd,
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
            )

            self._servers[device.id] = _ServerProc(
                device=device,
                model_id=model_id,
                model_path=model_path,
                port=port,
                proc=proc,
                started_at=time.time(),
            )

            # TODO: replace with a real readiness probe (e.g., GET /health on llama-server)
            await asyncio.sleep(0.25)
            return f"http://127.0.0.1:{port}"

    async def shutdown(self) -> None:
        """Terminate all managed servers and clear the registry."""
        async with self._lock:
            for sp in list(self._servers.values()):
                self._stop_proc(sp.proc)
            self._servers.clear()
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .base import DeviceMetrics, DeviceRef, RuntimeAdapter
|
||||
|
||||
|
||||
class MetalAdapter(RuntimeAdapter):
    """Placeholder adapter for the 'metal' backend.

    Once implemented, it should manage persistent ggml-org/llama.cpp
    'llama-server' instances built with the Metal backend and report device
    inventory/utilization via platform-native tooling (e.g., Metal performance
    counters on macOS).
    """

    backend_name = "metal"

    async def discover_devices(self) -> List[DeviceRef]:
        """Device discovery is not available in the scaffold."""
        raise NotImplementedError("metal adapter not implemented in scaffold.")

    async def get_metrics(self) -> List[DeviceMetrics]:
        """Metrics reporting is not available in the scaffold."""
        raise NotImplementedError("metal adapter not implemented in scaffold.")

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        """Server management is not available in the scaffold."""
        raise NotImplementedError("metal adapter not implemented in scaffold.")

    async def shutdown(self) -> None:
        """Nothing to clean up in the stub."""
        return None
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .base import DeviceMetrics, DeviceRef, RuntimeAdapter
|
||||
|
||||
|
||||
class RocmAdapter(RuntimeAdapter):
    """Placeholder adapter for the 'rocm' backend.

    Once implemented, it should manage persistent ggml-org/llama.cpp
    'llama-server' instances built with the ROCm backend and report device
    inventory/utilization via platform-native tooling (e.g., rocm-smi on
    Linux).
    """

    backend_name = "rocm"

    async def discover_devices(self) -> List[DeviceRef]:
        """Device discovery is not available in the scaffold."""
        raise NotImplementedError("rocm adapter not implemented in scaffold.")

    async def get_metrics(self) -> List[DeviceMetrics]:
        """Metrics reporting is not available in the scaffold."""
        raise NotImplementedError("rocm adapter not implemented in scaffold.")

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        """Server management is not available in the scaffold."""
        raise NotImplementedError("rocm adapter not implemented in scaffold.")

    async def shutdown(self) -> None:
        """Nothing to clean up in the stub."""
        return None
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .base import DeviceMetrics, DeviceRef, RuntimeAdapter
|
||||
|
||||
|
||||
class SyclAdapter(RuntimeAdapter):
    """Placeholder adapter for the 'sycl' backend.

    Once implemented, it should manage persistent ggml-org/llama.cpp
    'llama-server' instances built with the SYCL backend and report device
    inventory/utilization via platform-native tooling.
    """

    backend_name = "sycl"

    async def discover_devices(self) -> List[DeviceRef]:
        """Device discovery is not available in the scaffold."""
        raise NotImplementedError("sycl adapter not implemented in scaffold.")

    async def get_metrics(self) -> List[DeviceMetrics]:
        """Metrics reporting is not available in the scaffold."""
        raise NotImplementedError("sycl adapter not implemented in scaffold.")

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        """Server management is not available in the scaffold."""
        raise NotImplementedError("sycl adapter not implemented in scaffold.")

    async def shutdown(self) -> None:
        """Nothing to clean up in the stub."""
        return None
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .base import DeviceMetrics, DeviceRef, RuntimeAdapter
|
||||
|
||||
|
||||
class VulkanAdapter(RuntimeAdapter):
    """Placeholder adapter for the 'vulkan' backend.

    Once implemented, it should manage persistent ggml-org/llama.cpp
    'llama-server' instances built with the Vulkan backend and report device
    inventory/utilization via platform-native tooling.
    """

    backend_name = "vulkan"

    async def discover_devices(self) -> List[DeviceRef]:
        """Device discovery is not available in the scaffold."""
        raise NotImplementedError("vulkan adapter not implemented in scaffold.")

    async def get_metrics(self) -> List[DeviceMetrics]:
        """Metrics reporting is not available in the scaffold."""
        raise NotImplementedError("vulkan adapter not implemented in scaffold.")

    async def ensure_server(self, device: DeviceRef, *, model_path: str, model_id: str, server_args: Dict[str, Any]) -> str:
        """Server management is not available in the scaffold."""
        raise NotImplementedError("vulkan adapter not implemented in scaffold.")

    async def shutdown(self) -> None:
        """Nothing to clean up in the stub."""
        return None
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
import yaml
|
||||
|
||||
from .config import NodeAgentConfig
|
||||
from .main import create_app
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: load + validate the YAML config, then serve the node agent."""
    parser = argparse.ArgumentParser(description="RoleMesh Node Agent")
    parser.add_argument("--config", required=True, help="Path to node-agent YAML config.")
    ns = parser.parse_args()

    raw = yaml.safe_load(Path(ns.config).read_text())
    config = NodeAgentConfig.model_validate(raw)
    uvicorn.run(create_app(config), host=config.listen_host, port=config.listen_port)


if __name__ == "__main__":
    main()
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
|
||||
|
||||
class ModelEntry(BaseModel):
    """One explicitly configured GGUF model the node agent can serve."""
    model_id: str  # node-local model name exposed via /v1/models
    path: Path  # filesystem path to the .gguf file
    roles: List[str] = Field(default_factory=list)  # dispatcher roles this model serves
    default_ctx: int = 8192  # default context length hint
    server_args: Dict[str, Any] = Field(default_factory=dict)  # extra llama-server flags (key -> value)
|
||||
|
||||
|
||||
class NodeAgentConfig(BaseModel):
    """Top-level YAML configuration for one node agent instance."""
    node_id: str = "node-1"  # unique identifier reported to the dispatcher
    listen_host: str = "0.0.0.0"
    listen_port: int = 8091

    # Where GGUF models live (used for inventory endpoints; not required if models are explicit)
    model_roots: List[Path] = Field(default_factory=list)

    # Runtime backends enabled on this node
    enable_backends: List[Literal["cuda", "metal", "rocm", "vulkan", "sycl", "cpu"]] = Field(default_factory=lambda: ["cuda"])

    # Explicit model catalog (recommended)
    models: List[ModelEntry] = Field(default_factory=list)

    # Optional dispatcher registration/heartbeat
    dispatcher_base_url: Optional[HttpUrl] = None  # when unset, no heartbeat loop runs
    dispatcher_roles: List[str] = Field(default_factory=list)
    heartbeat_interval_sec: float = 5.0

    # llama-server binary name/path
    llama_server_bin: str = "llama-server"
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from .config import ModelEntry
|
||||
|
||||
|
||||
def _sha256(path: Path, max_bytes: int = 2_000_000) -> str:
    """Best-effort short hash; reads only the first max_bytes to keep inventory fast."""
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        digest.update(fh.read(max_bytes))
    return digest.hexdigest()


def discover_gguf_models(roots: Iterable[Path]) -> List[Dict[str, str]]:
    """Walk each root directory and describe every *.gguf file found.

    Unreadable files and missing roots are skipped silently (best-effort
    inventory, not a hard scan).
    """
    found: List[Dict[str, str]] = []
    for root in roots:
        root = Path(root)
        if not root.exists():
            continue
        for gguf in root.rglob("*.gguf"):
            try:
                entry = {
                    "path": str(gguf),
                    "name": gguf.name,
                    "size_bytes": str(gguf.stat().st_size),
                    "sha256_head": _sha256(gguf),
                }
            except Exception:
                continue
            found.append(entry)
    return found
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError # reuse gateway client
|
||||
from .adapters.cuda import CudaAdapter
|
||||
from .adapters.base import DeviceRef
|
||||
from .config import NodeAgentConfig
|
||||
from .inventory import discover_gguf_models
|
||||
|
||||
|
||||
def _error(message: str, code: str = "node_error", status_code: int = 500) -> JSONResponse:
    """Build an OpenAI-style error JSON response with the given status code."""
    body = {"error": {"message": message, "type": "node_error", "code": code}}
    return JSONResponse(status_code=status_code, content=body)
|
||||
|
||||
|
||||
def create_app(cfg: NodeAgentConfig) -> FastAPI:
    """Build the node-agent FastAPI app: local OpenAI-compatible endpoints,
    inventory reporting, and an optional dispatcher heartbeat loop.

    SECURITY NOTE: all endpoints are unauthenticated in the scaffold.
    """
    app = FastAPI(title="RoleMesh Node Agent", version="0.1.0")

    # Shared per-app state: config, one HTTP client (long read timeout for
    # streamed completions), and the gateway's upstream client reused here.
    app.state.cfg = cfg
    app.state.http = httpx.AsyncClient(timeout=httpx.Timeout(connect=5.0, read=3600.0))
    app.state.upstream = UpstreamClient(client=app.state.http)

    # Adapters
    # NOTE(review): only the CUDA adapter is wired up, regardless of
    # cfg.enable_backends — confirm before relying on other backends.
    app.state.cuda = CudaAdapter(llama_server_bin=cfg.llama_server_bin)

    # State: role -> (device, model)
    # This is intentionally simple for the scaffold: pick first GPU and first matching model.
    app.state.role_bindings: Dict[str, Dict[str, str]] = {}

    # NOTE(review): @app.on_event is the older FastAPI startup/shutdown hook
    # style; consider lifespan handlers when modernizing.
    @app.on_event("startup")
    async def _startup() -> None:
        # optional: dispatcher registration loop
        if cfg.dispatcher_base_url and cfg.dispatcher_roles:
            asyncio.create_task(_heartbeat_loop(app))

    @app.on_event("shutdown")
    async def _shutdown() -> None:
        await app.state.http.aclose()
        await app.state.cuda.shutdown()

    @app.get("/health")
    async def health() -> Dict[str, Any]:
        # Liveness probe: identity plus server time.
        return {"status": "ok", "node_id": cfg.node_id, "time": time.time()}

    @app.get("/v1/node/inventory")
    async def inventory() -> Dict[str, Any]:
        # Full snapshot: devices, metrics, configured models, and any GGUF
        # files found under model_roots.
        devices = await app.state.cuda.discover_devices()
        metrics = [m.__dict__ | {"device": m.device.__dict__} for m in await app.state.cuda.get_metrics()]
        models = [{"model_id": m.model_id, "path": str(m.path), "roles": m.roles} for m in cfg.models]
        discovered = discover_gguf_models(cfg.model_roots)
        return {
            "node_id": cfg.node_id,
            "backends": cfg.enable_backends,
            "devices": [d.__dict__ for d in devices],
            "metrics": metrics,
            "models": models,
            "discovered_gguf": discovered,
        }

    @app.get("/v1/models")
    async def list_models() -> Dict[str, Any]:
        # Expose configured models as OpenAI models (node-local names).
        data = [{"id": m.model_id, "object": "model", "owned_by": cfg.node_id} for m in cfg.models]
        return {"object": "list", "data": data}

    @app.post("/v1/chat/completions")
    async def chat_completions(request: Request) -> Any:
        # OpenAI-compatible chat endpoint: resolve model -> device -> persistent
        # llama-server, then proxy (optionally streaming) to it.
        body = await request.json()
        stream = bool(body.get("stream", False))
        model_id = body.get("model")
        if not model_id:
            return _error("Missing 'model' in request body.", code="bad_request", status_code=400)

        # Find model entry
        model_entry = next((m for m in cfg.models if m.model_id == model_id), None)
        if not model_entry:
            return _error(f"Unknown model_id '{model_id}'.", code="unknown_model", status_code=404)

        # Select device (first CUDA GPU for now)
        devices = await app.state.cuda.discover_devices()
        if not devices:
            return _error("No CUDA GPUs discovered on this node.", code="no_device", status_code=503)
        device = devices[0]

        base_url = await app.state.cuda.ensure_server(
            device,
            model_path=str(model_entry.path),
            model_id=model_entry.model_id,
            server_args=model_entry.server_args,
        )

        upstream = app.state.upstream
        try:
            if not stream:
                out = await upstream.chat_completions(base_url, body)
                return JSONResponse(status_code=200, content=out)
            else:
                async def gen():
                    # Relay SSE chunks from the local llama-server as-is.
                    async for chunk in upstream.stream_chat_completions(base_url, body):
                        yield chunk
                return StreamingResponse(gen(), media_type="text/event-stream")
        except UpstreamError as e:
            return _error(str(e), code="upstream_error", status_code=502)

    return app
|
||||
|
||||
|
||||
async def _heartbeat_loop(app: FastAPI) -> None:
    """Background task: periodically POST device metrics to the dispatcher.

    Runs forever; each cycle is best-effort — any failure (metrics collection
    or HTTP) is deliberately swallowed so a flaky dispatcher never kills the
    node agent. Interval comes from cfg.heartbeat_interval_sec.
    """
    cfg: NodeAgentConfig = app.state.cfg
    http: httpx.AsyncClient = app.state.http
    while True:
        try:
            inv = await app.state.cuda.get_metrics()
            payload = {
                "node_id": cfg.node_id,
                "timestamp": time.time(),
                # DeviceMetrics dataclasses -> plain dicts (device nested).
                "metrics": [m.__dict__ | {"device": m.device.__dict__} for m in inv],
            }
            url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat"
            await http.post(url, json=payload)
        except Exception:
            # Best-effort by design: skip this beat and try again next cycle.
            pass
        await asyncio.sleep(cfg.heartbeat_interval_sec)
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
@dataclass
class Job:
    """A queued inference request bound for one device-local server."""

    job_id: str  # unique id assigned by the submitter
    submitted_at: float  # epoch seconds at enqueue time
    base_url: str  # base URL of the target server
    body: Dict[str, Any]  # raw OpenAI-style request body
    stream: bool  # whether the caller requested SSE streaming


class DeviceQueue:
    """Simple FIFO queue per device; one worker per device for latency predictability."""

    def __init__(self, *, client: httpx.AsyncClient, base_url: str) -> None:
        """
        Args:
            client: Shared async HTTP client (not used by the scaffold worker yet).
            base_url: Server base URL; trailing slashes are stripped.
        """
        self._client = client
        self._base_url = base_url.rstrip("/")
        self._q: asyncio.Queue[Job] = asyncio.Queue()
        self._worker_task: Optional[asyncio.Task] = None
        self.in_flight: int = 0  # jobs currently being processed
        self.last_job_started_at: Optional[float] = None  # epoch seconds

    def start(self) -> None:
        """Start the single worker task (idempotent; must run inside a loop)."""
        if self._worker_task is None:
            self._worker_task = asyncio.create_task(self._worker())

    async def stop(self) -> None:
        """Cancel the worker and wait for it to actually unwind.

        FIX: the original cancelled the task without awaiting it, so stop()
        returned before the cancellation was observed or cleaned up.
        """
        if self._worker_task:
            self._worker_task.cancel()
            try:
                await self._worker_task
            except asyncio.CancelledError:
                pass
            self._worker_task = None

    async def submit(self, job: Job) -> Job:
        """Enqueue a job; returns the same job for caller convenience."""
        await self._q.put(job)
        return job

    @property
    def depth(self) -> int:
        """Number of jobs waiting in the queue (excludes in-flight)."""
        return self._q.qsize()

    async def _worker(self) -> None:
        # The actual request execution is handled by the API layer (so it can stream).
        # This worker exists mainly to serialize jobs per device and provide queue metrics.
        while True:
            job = await self._q.get()
            self.in_flight += 1
            self.last_job_started_at = time.time()
            try:
                await asyncio.sleep(0)  # placeholder
            finally:
                self.in_flight -= 1
                self._q.task_done()
|
||||
Loading…
Reference in New Issue