from __future__ import annotations

import json
import logging
import random
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, HttpUrl

_logger = logging.getLogger(__name__)


class NodeRegistration(BaseModel):
    """Payload a node sends to announce itself to the registry."""

    node_id: str
    base_url: HttpUrl  # OpenAI-compatible upstream base, e.g. http://10.0.0.12:8011
    roles: List[str]  # roles served by this node, e.g. ["planner", "writer"]
    meta: Dict[str, str] = Field(default_factory=dict)


class NodeHeartbeat(BaseModel):
    """Periodic liveness report from an already-registered node."""

    node_id: str
    # Sender-side timestamp; defaults to the time the model is constructed.
    timestamp: float = Field(default_factory=time.time)
    status: Dict[str, Any] = Field(default_factory=dict)
    metrics: List[Dict[str, Any]] = Field(default_factory=list)


class RegisteredNode(BaseModel):
    """Registry-side record of a node: registration data plus liveness state."""

    node_id: str
    base_url: HttpUrl
    roles: List[str]
    meta: Dict[str, str] = Field(default_factory=dict)
    status: Dict[str, Any] = Field(default_factory=dict)
    registered_at: float = Field(default_factory=time.time)
    # Updated with the registry's own clock on register() and heartbeat().
    last_seen: float = Field(default_factory=time.time)


class Registry:
    """
    Minimal in-memory registry with optional JSON persistence.

    NOTE: This is intentionally simple. For real deployments you likely want:
    - persistence (sqlite/redis)
    - auth (API key or mTLS)
    - TTL + health checks
    """

    def __init__(self, persist_path: Optional[Path] = None, stale_after_s: float = 30.0) -> None:
        """Create a registry, loading saved state when a persist path is given.

        Args:
            persist_path: JSON file used to persist nodes and round-robin
                counters. A falsy value disables persistence. The value is
                passed through ``Path(...)``, so a string path is accepted too.
            stale_after_s: Seconds without a heartbeat after which a node is
                considered stale. A value <= 0 disables staleness checks.
        """
        self._nodes: Dict[str, RegisteredNode] = {}
        self._rr_counters: Dict[str, int] = {}
        self._persist_path: Optional[Path] = Path(persist_path) if persist_path else None
        self._stale_after_s = stale_after_s
        if self._persist_path is not None:
            self._load()

    def _load(self) -> None:
        """Populate state from the persistence file, if it exists.

        Any failure (unreadable file, invalid JSON, invalid node data,
        non-integer counters) leaves the registry empty instead of raising.
        """
        if self._persist_path is None or not self._persist_path.exists():
            return
        try:
            raw = json.loads(self._persist_path.read_text(encoding="utf-8"))
            # Build into locals first so a failure halfway through cannot
            # leave a partially-loaded registry behind.
            nodes = {
                node_id: RegisteredNode.model_validate(node_data)
                for node_id, node_data in raw.get("nodes", {}).items()
            }
            # Counters are used with `%` in pick_node_for_role, so they must
            # be integers; int() rejects garbage here rather than at pick time.
            rr_counters = {
                str(role): int(count)
                for role, count in raw.get("rr_counters", {}).items()
            }
        except Exception:
            # If persistence is corrupted, start empty (do not crash the gateway).
            _logger.warning(
                "Ignoring unreadable registry state at %s; starting empty",
                self._persist_path,
                exc_info=True,
            )
            self._nodes = {}
            self._rr_counters = {}
            return
        self._nodes = nodes
        self._rr_counters = rr_counters

    def _save(self) -> None:
        """Write the current state to the persistence file (no-op if disabled).

        The payload is written to a sibling ``<name>.tmp`` file and then moved
        over the target, so an interrupted write cannot leave a truncated
        state file behind. I/O errors propagate to the caller.
        """
        if self._persist_path is None:
            return
        payload = {
            "nodes": {k: v.model_dump(mode="json") for k, v in self._nodes.items()},
            "rr_counters": self._rr_counters,
        }
        target = self._persist_path
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = target.with_name(target.name + ".tmp")
        tmp_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
        tmp_path.replace(target)

    def register(self, reg: NodeRegistration) -> RegisteredNode:
        """Register a node, replacing any existing record with the same id.

        Returns:
            The newly stored ``RegisteredNode``.
        """
        node = RegisteredNode(
            node_id=reg.node_id,
            base_url=reg.base_url,
            roles=reg.roles,
            meta=reg.meta,
            last_seen=time.time(),
        )
        self._nodes[reg.node_id] = node
        self._save()
        return node

    def heartbeat(self, hb: NodeHeartbeat) -> Optional[RegisteredNode]:
        """Record a heartbeat for a known node.

        The node's status is replaced by the heartbeat's status dict, with
        ``"timestamp"`` and ``"metrics"`` keys set from the heartbeat (these
        override same-named keys in ``hb.status``). ``last_seen`` is set from
        the registry's clock, not from ``hb.timestamp``.

        Returns:
            The updated node, or ``None`` if ``hb.node_id`` is not registered.
        """
        node = self._nodes.get(hb.node_id)
        if node is None:
            return None
        node.status = {
            **hb.status,
            "timestamp": hb.timestamp,
            "metrics": hb.metrics,
        }
        node.last_seen = time.time()
        self._save()
        return node

    def is_stale(self, node: RegisteredNode, now: Optional[float] = None) -> bool:
        """Return True if ``node`` was last seen more than ``stale_after_s`` ago.

        Args:
            node: The node to check.
            now: Reference time in epoch seconds; defaults to ``time.time()``.
        """
        if self._stale_after_s <= 0:
            # Staleness checks are disabled.
            return False
        now = time.time() if now is None else now
        return (now - node.last_seen) > self._stale_after_s

    def list_nodes(self, *, include_stale: bool = True) -> List[RegisteredNode]:
        """Return registered nodes, optionally filtering out stale ones."""
        nodes = list(self._nodes.values())
        if include_stale:
            return nodes
        # Use a single reference time so all nodes are judged consistently.
        now = time.time()
        return [node for node in nodes if not self.is_stale(node, now=now)]

    def nodes_for_role(self, role: str, *, include_stale: bool = False) -> List[RegisteredNode]:
        """Return nodes that list ``role`` among their roles."""
        nodes = self.list_nodes(include_stale=include_stale)
        return [node for node in nodes if role in node.roles]

    def pick_node_for_role(self, role: str, strategy: str = "round_robin") -> Optional[RegisteredNode]:
        """Pick one non-stale node serving ``role``.

        Args:
            role: Role the node must serve.
            strategy: ``"random"`` picks uniformly at random; any other value
                uses per-role round-robin.

        Returns:
            The chosen node, or ``None`` when no non-stale node serves the role.
        """
        candidates = self.nodes_for_role(role, include_stale=False)
        if not candidates:
            return None
        if strategy == "random":
            return random.choice(candidates)
        # The modulo keeps the index valid when the candidate list has shrunk
        # since the counter was last stored.
        idx = self._rr_counters.get(role, 0) % len(candidates)
        self._rr_counters[role] = idx + 1
        # NOTE: the counter is persisted on every round-robin pick.
        self._save()
        return candidates[idx]