132 lines
4.4 KiB
Python
132 lines
4.4 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import random
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from pydantic import BaseModel, Field, HttpUrl
|
|
|
|
|
|
class NodeRegistration(BaseModel):
    """Registration payload a node sends to announce itself and the roles it serves."""

    # Used as the registry key; registering again with the same id replaces the entry.
    node_id: str

    base_url: HttpUrl  # OpenAI-compatible upstream base, e.g. http://10.0.0.12:8011

    roles: List[str]  # roles served by this node, e.g. ["planner", "writer"]

    # Free-form string key/value metadata, stored as-is on the registered node.
    meta: Dict[str, str] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
class NodeHeartbeat(BaseModel):
    """Periodic liveness/status report from a node that has already registered."""

    node_id: str

    # Report time in seconds since the epoch. If the sender omits it, it defaults
    # to time.time() when the model is constructed. The registry copies it into
    # the node's status but uses its own clock for last_seen.
    timestamp: float = Field(default_factory=lambda: time.time())

    # Arbitrary status fields; their schema is defined by the sending node.
    status: Dict[str, Any] = Field(default_factory=dict)

    # List of metric records; the per-record schema is not defined here.
    metrics: List[Dict[str, Any]] = Field(default_factory=list)
|
|
|
|
|
|
class RegisteredNode(BaseModel):
    """Registry-side record of a node: built from a NodeRegistration, updated by heartbeats."""

    node_id: str

    base_url: HttpUrl

    roles: List[str]

    meta: Dict[str, str] = Field(default_factory=dict)

    # Latest heartbeat status plus its "timestamp" and "metrics" entries;
    # empty until the first heartbeat arrives.
    status: Dict[str, Any] = Field(default_factory=dict)

    # Seconds since the epoch when this record was created. Re-registration
    # creates a fresh record, so this resets.
    registered_at: float = Field(default_factory=lambda: time.time())

    # Seconds since the epoch of the last registration or heartbeat; compared
    # against the registry's staleness threshold.
    last_seen: float = Field(default_factory=lambda: time.time())
|
|
|
|
|
|
class Registry:
    """
    Minimal in-memory registry of upstream nodes with optional JSON persistence.

    Nodes are keyed by ``node_id``. A node counts as stale once more than
    ``stale_after_s`` seconds have passed since its last registration or
    heartbeat. Stale nodes are skipped by role-based selection but are never
    evicted from the registry.

    NOTE: This is intentionally simple. For real deployments you likely want:

    - persistence (sqlite/redis)

    - auth (API key or mTLS)

    - TTL + health checks
    """

    def __init__(self, persist_path: Optional[Path] = None, stale_after_s: float = 30.0) -> None:
        """Create a registry, restoring previous state from ``persist_path`` if it exists.

        Args:
            persist_path: JSON file used to persist nodes and round-robin
                cursors. ``None`` keeps all state in memory only.
            stale_after_s: Seconds without contact after which a node is
                considered stale. A value ``<= 0`` disables staleness.
        """
        self._nodes: Dict[str, RegisteredNode] = {}
        # Per-role cursor used by round-robin selection.
        self._rr_counters: Dict[str, int] = {}
        self._persist_path = persist_path
        self._stale_after_s = stale_after_s

        if self._persist_path:
            self._load()

    def _load(self) -> None:
        """Restore nodes and round-robin cursors from ``persist_path``.

        Best-effort: a missing file is a no-op. Any error while reading,
        decoding or validating the file leaves the registry empty.
        """
        if not self._persist_path or not self._persist_path.exists():
            return
        try:
            # _save writes UTF-8; read it back explicitly instead of relying
            # on the platform's default encoding.
            raw = json.loads(self._persist_path.read_text(encoding="utf-8"))
            for node_id, node_data in raw.get("nodes", {}).items():
                self._nodes[node_id] = RegisteredNode.model_validate(node_data)
            self._rr_counters = dict(raw.get("rr_counters", {}))
        except Exception:
            # If persistence is corrupted, start empty (do not crash the gateway).
            self._nodes = {}
            self._rr_counters = {}

    def _save(self) -> None:
        """Persist nodes and round-robin cursors to ``persist_path`` (no-op without one).

        The JSON is first written to a sibling ``<name>.tmp`` file, which is
        then renamed over the target. A crash mid-write therefore leaves the
        previous file intact. Writing in place could leave a truncated file,
        which ``_load`` treats as corrupted and discards, losing every
        registration on restart.

        Filesystem errors propagate to the caller.
        """
        if not self._persist_path:
            return
        payload = {
            "nodes": {k: v.model_dump(mode="json") for k, v in self._nodes.items()},
            "rr_counters": self._rr_counters,
        }
        target = self._persist_path
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp = target.with_name(target.name + ".tmp")
        try:
            tmp.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
            # Path.replace delegates to os.replace. That is atomic on POSIX
            # within one filesystem, which holds because tmp is in the same directory.
            tmp.replace(target)
        except BaseException:
            # Don't leave a half-written temp file behind; surface the error.
            tmp.unlink(missing_ok=True)
            raise

    def register(self, reg: NodeRegistration) -> RegisteredNode:
        """Add a node, or replace an existing one with the same ``node_id``.

        Replacement is wholesale: the new record starts with an empty
        ``status`` and a fresh ``registered_at``. ``last_seen`` is set to now.
        """
        node = RegisteredNode(
            node_id=reg.node_id,
            base_url=reg.base_url,
            roles=reg.roles,
            meta=reg.meta,
            last_seen=time.time(),
        )
        self._nodes[reg.node_id] = node
        self._save()
        return node

    def heartbeat(self, hb: NodeHeartbeat) -> Optional[RegisteredNode]:
        """Record a heartbeat for an already-registered node.

        The node's ``status`` is replaced (not merged with the previous one)
        by ``hb.status`` plus ``"timestamp"`` and ``"metrics"`` keys. Those two
        keys win over same-named keys in ``hb.status``. ``last_seen`` uses this
        process's clock, not ``hb.timestamp``.

        Returns:
            The updated node, or ``None`` if ``hb.node_id`` is not registered.
        """
        n = self._nodes.get(hb.node_id)
        if not n:
            return None
        n.status = {
            **hb.status,
            "timestamp": hb.timestamp,
            "metrics": hb.metrics,
        }
        n.last_seen = time.time()
        self._save()
        return n

    def is_stale(self, node: RegisteredNode, now: Optional[float] = None) -> bool:
        """Return True if ``node`` was last seen more than ``stale_after_s`` seconds before ``now``.

        ``now`` defaults to the current time. Always False when staleness is
        disabled (``stale_after_s <= 0``).
        """
        if self._stale_after_s <= 0:
            return False
        now = time.time() if now is None else now
        return (now - node.last_seen) > self._stale_after_s

    def list_nodes(self, *, include_stale: bool = True) -> List[RegisteredNode]:
        """Return registered nodes in registration order, optionally dropping stale ones."""
        nodes = list(self._nodes.values())
        if include_stale:
            return nodes
        # Use one timestamp so every node is judged against the same instant.
        now = time.time()
        return [node for node in nodes if not self.is_stale(node, now=now)]

    def nodes_for_role(self, role: str, *, include_stale: bool = False) -> List[RegisteredNode]:
        """Return nodes advertising ``role`` (non-stale only by default)."""
        nodes = self.list_nodes(include_stale=include_stale)
        return [node for node in nodes if role in node.roles]

    def pick_node_for_role(self, role: str, strategy: str = "round_robin") -> Optional[RegisteredNode]:
        """Choose one non-stale node serving ``role``, or ``None`` if there is none.

        Args:
            role: Role the chosen node must advertise.
            strategy: ``"random"`` picks uniformly at random. Any other value
                falls back to round-robin over the current candidates.

        The round-robin cursor is kept per role and persisted on every pick.
        NOTE(review): with persistence enabled this rewrites the whole file
        on every routed request; consider batching if that becomes a hotspot.
        """
        candidates = self.nodes_for_role(role, include_stale=False)
        if not candidates:
            return None
        if strategy == "random":
            return random.choice(candidates)
        idx = self._rr_counters.get(role, 0) % len(candidates)
        # Store the next position; the modulo above keeps it bounded even if
        # the candidate set shrinks between calls.
        self._rr_counters[role] = idx + 1
        self._save()
        return candidates[idx]
|