RoleMesh-Gateway/src/rolemesh_gateway/registry.py

114 lines
3.6 KiB
Python

from __future__ import annotations

import json
import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, Field, HttpUrl
class NodeRegistration(BaseModel):
    """Registration request for a node: its id, upstream URL and served roles.

    Consumed by ``Registry.register``, which turns it into a ``RegisteredNode``.
    """

    # Registry key; registering the same id again replaces the stored node.
    node_id: str
    base_url: HttpUrl  # OpenAI-compatible upstream base, e.g. http://10.0.0.12:8011
    roles: List[str]  # roles served by this node, e.g. ["planner", "writer"]
    # Free-form string metadata, copied onto the RegisteredNode as-is.
    meta: Dict[str, str] = Field(default_factory=dict)
class NodeHeartbeat(BaseModel):
    """Status update for an already-registered node.

    ``Registry.heartbeat`` merges ``status`` with ``timestamp`` and ``metrics``
    into the stored node's ``status`` and refreshes its ``last_seen``.
    Heartbeats for unknown node ids make ``Registry.heartbeat`` return ``None``.
    """

    node_id: str
    # Epoch seconds; defaults to time.time() when the heartbeat object is built.
    timestamp: float = Field(default_factory=lambda: time.time())
    # Arbitrary status keys; "timestamp"/"metrics" keys here are overridden by
    # the fields of the same name when merged by Registry.heartbeat.
    status: Dict[str, Any] = Field(default_factory=dict)
    metrics: List[Dict[str, Any]] = Field(default_factory=list)
class RegisteredNode(BaseModel):
    """A node as stored by the ``Registry`` (and written to its JSON file).

    Carries the registration fields plus bookkeeping maintained by the registry.
    """

    node_id: str
    base_url: HttpUrl
    roles: List[str]
    meta: Dict[str, str] = Field(default_factory=dict)
    # Latest heartbeat status, merged with the heartbeat's "timestamp" and
    # "metrics" by Registry.heartbeat; empty until the first heartbeat.
    status: Dict[str, Any] = Field(default_factory=dict)
    # Epoch seconds (time.time()) at which this record was created.
    registered_at: float = Field(default_factory=lambda: time.time())
    # Epoch seconds of the last registration or heartbeat. Nothing in this
    # module expires nodes based on it.
    last_seen: float = Field(default_factory=lambda: time.time())
class Registry:
    """
    Minimal in-memory node registry with optional JSON persistence.

    Nodes are keyed by ``node_id``. Requests for a role are spread across the
    nodes advertising that role with a per-role round-robin counter. When a
    persistence path is configured, every mutation (including each round-robin
    pick) rewrites the JSON file.

    NOTE: This is intentionally simple. For real deployments you likely want:
    - persistence (sqlite/redis)
    - auth (API key or mTLS)
    - TTL + health checks
    """

    def __init__(self, persist_path: Optional[Path] = None) -> None:
        """Create a registry, loading previously persisted state if available.

        Args:
            persist_path: JSON file used to persist nodes and round-robin
                counters. ``None`` keeps the registry purely in memory.
        """
        self._nodes: Dict[str, RegisteredNode] = {}
        # Per-role round-robin position: index of the next candidate to hand
        # out, taken modulo the current candidate count.
        self._rr_counters: Dict[str, int] = {}
        self._persist_path = persist_path
        if self._persist_path:
            self._load()

    def _load(self) -> None:
        """Populate state from the persistence file, if one exists.

        A missing file is not an error. An unreadable or malformed file is
        logged and the registry starts empty instead of crashing the gateway.
        """
        if not self._persist_path or not self._persist_path.exists():
            return
        try:
            raw = json.loads(self._persist_path.read_text(encoding="utf-8"))
            nodes = {
                node_id: RegisteredNode.model_validate(node_data)
                for node_id, node_data in raw.get("nodes", {}).items()
            }
            # Counters are only a load-balancing hint: drop entries that are not
            # plain ints rather than letting them break pick_node_for_role later.
            counters = {
                role: value
                for role, value in raw.get("rr_counters", {}).items()
                if isinstance(value, int) and not isinstance(value, bool)
            }
        except Exception:
            # If persistence is corrupted, start empty (do not crash the gateway),
            # but leave a trace: the next save overwrites the unreadable file.
            logging.getLogger(__name__).warning(
                "Ignoring unreadable registry state at %s",
                self._persist_path,
                exc_info=True,
            )
            self._nodes = {}
            self._rr_counters = {}
            return
        self._nodes.update(nodes)
        self._rr_counters = counters

    def _save(self) -> None:
        """Write nodes and round-robin counters to the persistence file.

        The JSON is first written to a sibling ``.tmp`` file and then renamed
        over the target, so a crash mid-write cannot leave a truncated file
        behind (which ``_load`` would discard, losing every registered node).
        """
        if not self._persist_path:
            return
        payload = {
            "nodes": {k: v.model_dump(mode="json") for k, v in self._nodes.items()},
            "rr_counters": self._rr_counters,
        }
        text = json.dumps(payload, indent=2, sort_keys=True)
        target = self._persist_path
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = target.with_name(target.name + ".tmp")
        try:
            tmp_path.write_text(text, encoding="utf-8")
            # Path.replace overwrites an existing target atomically (os.replace).
            tmp_path.replace(target)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

    def register(self, reg: NodeRegistration) -> RegisteredNode:
        """Store (or replace) the node described by *reg* and persist it.

        Re-registering an existing ``node_id`` replaces the previous record,
        so its ``status`` is cleared and ``registered_at`` is reset.

        Returns:
            The newly stored ``RegisteredNode``.
        """
        node = RegisteredNode(
            node_id=reg.node_id,
            base_url=reg.base_url,
            roles=reg.roles,
            meta=reg.meta,
            last_seen=time.time(),
        )
        self._nodes[reg.node_id] = node
        self._save()
        return node

    def heartbeat(self, hb: NodeHeartbeat) -> Optional[RegisteredNode]:
        """Record a heartbeat for a registered node and persist it.

        The node's ``status`` becomes ``hb.status`` plus the heartbeat's
        ``timestamp`` and ``metrics`` (which win over same-named status keys),
        and ``last_seen`` is set to the current time.

        Returns:
            The updated node, or ``None`` if ``hb.node_id`` is not registered
            (nothing is changed or persisted in that case).
        """
        n = self._nodes.get(hb.node_id)
        if n is None:
            return None
        n.status = {
            **hb.status,
            "timestamp": hb.timestamp,
            "metrics": hb.metrics,
        }
        n.last_seen = time.time()
        self._save()
        return n

    def list_nodes(self) -> List[RegisteredNode]:
        """Return a new list of all registered nodes (node objects are shared)."""
        return list(self._nodes.values())

    def pick_node_for_role(self, role: str) -> Optional[RegisteredNode]:
        """Choose a node serving *role*, rotating round-robin across candidates.

        Candidates are taken in the registry's internal order; node liveness
        (``last_seen``) is not considered.

        Returns:
            The chosen node, or ``None`` if no registered node serves *role*.
        """
        candidates = [n for n in self._nodes.values() if role in n.roles]
        if not candidates:
            return None
        idx = self._rr_counters.get(role, 0) % len(candidates)
        # Store the next position; the modulo above keeps it valid even if the
        # candidate set changes between calls.
        self._rr_counters[role] = idx + 1
        # NOTE: with persistence enabled this rewrites the JSON file on every pick.
        self._save()
        return candidates[idx]