# geniehive registry: SQLite-backed store for hosts, services, roles, and
# benchmark samples. (Removed pasted extraction artifact "802 lines / 33 KiB /
# Python" that preceded the imports.)
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
import sqlite3
|
|
import time
|
|
from pathlib import Path
|
|
|
|
from .models import BenchmarkSample, HostHeartbeat, HostRegistration, RegisteredService, RoleProfile, RouteMatchRequest
|
|
from .request_policy import effective_chat_request_policy, select_target_asset
|
|
|
|
|
|
def _json_dumps(value: object) -> str:
|
|
return json.dumps(value, sort_keys=True)
|
|
|
|
|
|
class Registry:
|
|
def __init__(self, db_path: str | Path) -> None:
|
|
self.db_path = Path(db_path)
|
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
self._init_db()
|
|
|
|
def _connect(self) -> sqlite3.Connection:
|
|
conn = sqlite3.connect(self.db_path)
|
|
conn.row_factory = sqlite3.Row
|
|
return conn
|
|
|
|
    def _init_db(self) -> None:
        """Create the registry schema if it does not exist yet.

        Four tables: `hosts` (one row per registered machine), `services`
        (endpoints owned by a host), `roles` (routing profiles), and
        `benchmark_samples` (empirical performance observations). All
        structured payloads are stored as JSON text columns.
        """
        # The `with` block commits the DDL on success; sqlite3 does NOT close
        # the connection on __exit__.
        with self._connect() as conn:
            conn.executescript(
                """
                CREATE TABLE IF NOT EXISTS hosts (
                    host_id TEXT PRIMARY KEY,
                    display_name TEXT,
                    address TEXT NOT NULL,
                    labels_json TEXT NOT NULL,
                    capabilities_json TEXT NOT NULL,
                    resources_json TEXT NOT NULL,
                    status_state TEXT NOT NULL DEFAULT 'online',
                    last_seen REAL NOT NULL,
                    metrics_json TEXT NOT NULL DEFAULT '{}'
                );

                CREATE TABLE IF NOT EXISTS services (
                    service_id TEXT PRIMARY KEY,
                    host_id TEXT NOT NULL,
                    kind TEXT NOT NULL,
                    protocol TEXT NOT NULL,
                    endpoint TEXT NOT NULL,
                    runtime_json TEXT NOT NULL,
                    assets_json TEXT NOT NULL,
                    state_json TEXT NOT NULL,
                    observed_json TEXT NOT NULL,
                    updated_at REAL NOT NULL,
                    FOREIGN KEY(host_id) REFERENCES hosts(host_id)
                );

                CREATE TABLE IF NOT EXISTS roles (
                    role_id TEXT PRIMARY KEY,
                    display_name TEXT,
                    description TEXT,
                    operation TEXT NOT NULL,
                    modality TEXT NOT NULL,
                    prompt_policy_json TEXT NOT NULL,
                    routing_policy_json TEXT NOT NULL,
                    updated_at REAL NOT NULL
                );

                CREATE TABLE IF NOT EXISTS benchmark_samples (
                    benchmark_id TEXT PRIMARY KEY,
                    service_id TEXT NOT NULL,
                    asset_id TEXT,
                    workload TEXT NOT NULL,
                    observed_at REAL NOT NULL,
                    results_json TEXT NOT NULL
                );
                """
            )
|
|
|
|
    def register_host(self, reg: HostRegistration) -> dict:
        """Insert or refresh a host and replace its full service list.

        Returns the stored host as a dict (re-read after commit). On conflict
        the host is forced back to `status_state='online'`; `metrics_json` is
        initialized to '{}' on first insert only and intentionally left
        untouched on re-registration (heartbeats own metrics updates).
        """
        now = time.time()
        with self._connect() as conn:
            conn.execute(
                """
                INSERT INTO hosts (
                    host_id, display_name, address, labels_json, capabilities_json,
                    resources_json, status_state, last_seen, metrics_json
                )
                VALUES (?, ?, ?, ?, ?, ?, 'online', ?, '{}')
                ON CONFLICT(host_id) DO UPDATE SET
                    display_name=excluded.display_name,
                    address=excluded.address,
                    labels_json=excluded.labels_json,
                    capabilities_json=excluded.capabilities_json,
                    resources_json=excluded.resources_json,
                    status_state='online',
                    last_seen=excluded.last_seen
                """,
                (
                    reg.host_id,
                    reg.display_name,
                    reg.address,
                    _json_dumps(reg.labels),
                    _json_dumps(reg.capabilities),
                    _json_dumps(reg.resources),
                    now,
                ),
            )
            # Registration always carries the complete service set; anything
            # not listed is dropped (see _replace_services).
            self._replace_services(conn, reg.host_id, reg.services, now)
        # Read back outside the `with` so the new connection opened by
        # get_host() sees the committed row.
        return self.get_host(reg.host_id)
|
|
|
|
    def heartbeat_host(self, hb: HostHeartbeat) -> dict | None:
        """Update liveness state/metrics for a known host.

        Returns the refreshed host dict, or None when the host was never
        registered (heartbeats do not implicitly register). The service list
        is replaced only when the heartbeat actually carries services.
        """
        now = time.time()
        with self._connect() as conn:
            cur = conn.execute(
                "SELECT host_id FROM hosts WHERE host_id = ?",
                (hb.host_id,),
            )
            if cur.fetchone() is None:
                # Unknown host: refuse silently so callers can 404.
                return None
            conn.execute(
                """
                UPDATE hosts
                SET status_state = ?, last_seen = ?, metrics_json = ?
                WHERE host_id = ?
                """,
                (
                    hb.status.state,
                    now,
                    _json_dumps(hb.metrics),
                    hb.host_id,
                ),
            )
            if hb.services:
                self._replace_services(conn, hb.host_id, hb.services, now)
        # Re-read after commit so the fresh connection sees the update.
        return self.get_host(hb.host_id)
|
|
|
|
    def _replace_services(
        self,
        conn: sqlite3.Connection,
        host_id: str,
        services: list[RegisteredService],
        now: float,
    ) -> None:
        """Replace the full service set for *host_id* inside the caller's transaction.

        Delete-then-insert semantics: any service omitted from *services* is
        dropped, so callers must always pass the complete current set. Uses
        the caller's connection so the swap commits atomically with the host
        row change.
        """
        conn.execute("DELETE FROM services WHERE host_id = ?", (host_id,))
        for service in services:
            conn.execute(
                """
                INSERT INTO services (
                    service_id, host_id, kind, protocol, endpoint,
                    runtime_json, assets_json, state_json, observed_json, updated_at
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    service.service_id,
                    host_id,
                    service.kind,
                    service.protocol,
                    service.endpoint,
                    # Nested models are flattened to canonical JSON text.
                    _json_dumps(service.runtime.model_dump()),
                    _json_dumps([asset.model_dump() for asset in service.assets]),
                    _json_dumps(service.state.model_dump()),
                    _json_dumps(service.observed.model_dump()),
                    now,
                ),
            )
|
|
|
|
def get_host(self, host_id: str) -> dict | None:
|
|
with self._connect() as conn:
|
|
row = conn.execute("SELECT * FROM hosts WHERE host_id = ?", (host_id,)).fetchone()
|
|
if row is None:
|
|
return None
|
|
return self._host_row_to_dict(row)
|
|
|
|
    def upsert_roles(self, roles: list[RoleProfile]) -> list[dict]:
        """Insert or update role profiles; returns the full role list after commit.

        All roles in one call share the same `updated_at` timestamp. Policies
        are serialized to canonical JSON text columns.
        """
        now = time.time()
        with self._connect() as conn:
            for role in roles:
                conn.execute(
                    """
                    INSERT INTO roles (
                        role_id, display_name, description, operation, modality,
                        prompt_policy_json, routing_policy_json, updated_at
                    )
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    ON CONFLICT(role_id) DO UPDATE SET
                        display_name=excluded.display_name,
                        description=excluded.description,
                        operation=excluded.operation,
                        modality=excluded.modality,
                        prompt_policy_json=excluded.prompt_policy_json,
                        routing_policy_json=excluded.routing_policy_json,
                        updated_at=excluded.updated_at
                    """,
                    (
                        role.role_id,
                        role.display_name,
                        role.description,
                        role.operation,
                        role.modality,
                        _json_dumps(role.prompt_policy.model_dump()),
                        _json_dumps(role.routing_policy.model_dump()),
                        now,
                    ),
                )
        # Re-read after commit so the listing reflects the stored rows.
        return self.list_roles()
|
|
|
|
def get_role(self, role_id: str) -> dict | None:
|
|
with self._connect() as conn:
|
|
row = conn.execute("SELECT * FROM roles WHERE role_id = ?", (role_id,)).fetchone()
|
|
if row is None:
|
|
return None
|
|
return self._role_row_to_dict(row)
|
|
|
|
def list_roles(self) -> list[dict]:
|
|
with self._connect() as conn:
|
|
rows = conn.execute("SELECT * FROM roles ORDER BY role_id").fetchall()
|
|
return [self._role_row_to_dict(row) for row in rows]
|
|
|
|
def list_hosts(self) -> list[dict]:
|
|
with self._connect() as conn:
|
|
rows = conn.execute("SELECT * FROM hosts ORDER BY host_id").fetchall()
|
|
return [self._host_row_to_dict(row) for row in rows]
|
|
|
|
def list_services(self) -> list[dict]:
|
|
with self._connect() as conn:
|
|
rows = conn.execute("SELECT * FROM services ORDER BY host_id, service_id").fetchall()
|
|
return [self._service_row_to_dict(row) for row in rows]
|
|
|
|
    def upsert_benchmark_samples(self, samples: list[BenchmarkSample]) -> list[dict]:
        """Insert or overwrite benchmark samples keyed by benchmark_id.

        Re-submitting the same benchmark_id fully replaces the earlier sample.
        Returns the complete (unfiltered) sample list after commit.
        """
        with self._connect() as conn:
            for sample in samples:
                conn.execute(
                    """
                    INSERT INTO benchmark_samples (
                        benchmark_id, service_id, asset_id, workload, observed_at, results_json
                    )
                    VALUES (?, ?, ?, ?, ?, ?)
                    ON CONFLICT(benchmark_id) DO UPDATE SET
                        service_id=excluded.service_id,
                        asset_id=excluded.asset_id,
                        workload=excluded.workload,
                        observed_at=excluded.observed_at,
                        results_json=excluded.results_json
                    """,
                    (
                        sample.benchmark_id,
                        sample.service_id,
                        sample.asset_id,
                        sample.workload,
                        # Caller-supplied timestamp (unlike hosts/roles which stamp now()).
                        sample.observed_at,
                        _json_dumps(sample.results),
                    ),
                )
        # Re-read after commit so callers get exactly what was stored.
        return self.list_benchmark_samples()
|
|
|
|
def list_benchmark_samples(self, *, service_id: str | None = None, workload: str | None = None) -> list[dict]:
|
|
query = "SELECT * FROM benchmark_samples"
|
|
clauses = []
|
|
params: list[object] = []
|
|
if service_id:
|
|
clauses.append("service_id = ?")
|
|
params.append(service_id)
|
|
if workload:
|
|
clauses.append("workload = ?")
|
|
params.append(workload)
|
|
if clauses:
|
|
query += " WHERE " + " AND ".join(clauses)
|
|
query += " ORDER BY observed_at DESC, benchmark_id"
|
|
with self._connect() as conn:
|
|
rows = conn.execute(query, params).fetchall()
|
|
return [self._benchmark_row_to_dict(row) for row in rows]
|
|
|
|
    def list_client_models(self) -> list[dict]:
        """Build an OpenAI-style `/v1/models` listing of routable targets.

        Three entry kinds share one id namespace:
          * one entry per healthy, accepting service (route_type "service"),
          * one entry per named asset on such a service (route_type "asset"),
          * one entry per role (route_type "role") with aggregate stats.
        Duplicate ids are deduplicated last-wins, so role entries shadow
        service/asset entries with the same id; output is sorted by id.
        """
        services = self.list_services()
        roles = self.list_roles()
        items: list[dict] = []

        for service in services:
            # Skip targets that cannot take traffic right now.
            if not service["state"].get("accept_requests", True):
                continue
            if service["state"].get("health") != "healthy":
                continue
            item = {
                "id": service["service_id"],
                "object": "model",
                "owned_by": service["host_id"],
                "geniehive": self._service_metadata(service),
            }
            items.append(item)
            for asset in service["assets"]:
                asset_id = asset.get("asset_id")
                if not asset_id:
                    continue
                items.append(
                    {
                        "id": asset_id,
                        "object": "model",
                        "owned_by": service["host_id"],
                        # dict `|` merge (3.9+): the right side overrides
                        # route_type to "asset" and pins the asset id.
                        "geniehive": self._service_metadata(service, requested_model=asset_id) | {"route_type": "asset", "asset_id": asset_id},
                    }
                )

        for role in roles:
            # Same health/acceptance filter as above, restricted to the
            # role's operation kind.
            matching_services = [
                service
                for service in services
                if service["kind"] == role["operation"]
                and service["state"].get("accept_requests", True)
                and service["state"].get("health") == "healthy"
            ]
            loaded_count = sum(1 for service in matching_services if any(asset.get("loaded") for asset in service["assets"]))
            latencies = [
                service["observed"].get("p50_latency_ms")
                for service in matching_services
                if service["observed"].get("p50_latency_ms") is not None
            ]
            best_latency_ms = min(latencies) if latencies else None
            items.append(
                {
                    "id": role["role_id"],
                    "object": "model",
                    "owned_by": "geniehive-role",
                    "geniehive": {
                        "route_type": "role",
                        "role_id": role["role_id"],
                        "display_name": role["display_name"],
                        "operation": role["operation"],
                        "modality": role["modality"],
                        "healthy_target_count": len(matching_services),
                        "loaded_target_count": loaded_count,
                        "best_p50_latency_ms": best_latency_ms,
                        "offload_hint": self._offload_hint(
                            operation=role["operation"],
                            loaded_count=loaded_count,
                            best_latency_ms=best_latency_ms,
                        ),
                        "routing_policy": role["routing_policy"],
                        # Resolve a concrete target only when at least one
                        # matching service exists (resolve_route cannot
                        # return None here since the role exists).
                        "effective_request_policy": self._effective_request_policy(
                            requested_model=role["role_id"],
                            role=role,
                            service=self.resolve_route(role["role_id"], kind=role["operation"]).get("service") if matching_services else None,
                        ),
                    },
                }
            )

        # Last-wins dedup: roles were appended after services/assets, so a
        # role entry wins any id collision.
        deduped: dict[str, dict] = {}
        for item in items:
            deduped[item["id"]] = item
        return [deduped[key] for key in sorted(deduped)]
|
|
|
|
    def resolve_route(self, requested_model: str, *, kind: str | None = None) -> dict | None:
        """Resolve a requested model name to a concrete service.

        Resolution order: (1) direct match on a service id or asset id;
        (2) role lookup, ranking that role's healthy services. Returns
        None when the name matches neither a service/asset nor a role;
        returns a role match with `service: None` when the role exists but
        no healthy service serves its operation.
        """
        direct = self._resolve_direct(requested_model, kind=kind)
        if direct is not None:
            return {"match_type": "direct", **direct}

        role = self.get_role(requested_model)
        if role is None:
            return None

        # Explicit kind (if given) overrides the role's own operation.
        matched_kind = kind or role["operation"]
        candidates = [
            service
            for service in self.list_services()
            if service["kind"] == matched_kind
            and service["state"].get("accept_requests", True)
            and service["state"].get("health") == "healthy"
        ]
        if not candidates:
            return {"match_type": "role", "role": role, "service": None}

        preferred_families = [family.lower() for family in role["routing_policy"].get("preferred_families", [])]

        def score(service: dict) -> tuple[int, int, float, str]:
            # Ranking tuple for max(): preferred-family match first, then
            # "has a loaded asset", then lowest latency (negated so max()
            # prefers smaller latency; missing latency ranks worst), then
            # service_id as a deterministic tie-break.
            loaded = 1 if any(asset.get("loaded") for asset in service["assets"]) else 0
            family_match = 0
            if preferred_families:
                asset_names = " ".join(asset.get("asset_id", "") for asset in service["assets"]).lower()
                family_match = 1 if any(family in asset_names for family in preferred_families) else 0
            latency = service["observed"].get("p50_latency_ms")
            latency_score = float(latency) if latency is not None else float("inf")
            return (family_match, loaded, -latency_score, service["service_id"])

        if role["routing_policy"].get("require_loaded"):
            # Soft constraint: narrow to loaded services only when at least
            # one exists; otherwise fall back to the full candidate set.
            loaded_candidates = [service for service in candidates if any(asset.get("loaded") for asset in service["assets"])]
            if loaded_candidates:
                candidates = loaded_candidates

        service = max(candidates, key=score)
        return {"match_type": "role", "role": role, "service": service}
|
|
|
|
    def match_routes(self, request: RouteMatchRequest) -> dict:
        """Score and rank roles (and optionally services) against a task request.

        Singular and plural request fields (`task`/`tasks`, `workload`/
        `workloads`) are merged, trimmed, and blank entries dropped. Every
        surviving role is scored; direct services are scored too when
        `include_direct_services` is set. Results are sorted best-first with
        roles preferred over services on score ties, then truncated to
        `limit` (clamped to at least 1).
        """
        tasks = [task.strip() for task in ([request.task] if request.task else []) + request.tasks if task and task.strip()]
        workloads = [value.strip() for value in ([request.workload] if request.workload else []) + request.workloads if value and value.strip()]
        kind = request.kind
        modality = request.modality
        # Healthy, accepting services, optionally restricted by kind.
        services = [
            service
            for service in self.list_services()
            if (kind is None or service["kind"] == kind)
            and service["state"].get("accept_requests", True)
            and service["state"].get("health") == "healthy"
        ]
        # Roles filtered by kind/modality only — role health is reflected in
        # scoring via the resolved service (or lack of one).
        roles = [
            role
            for role in self.list_roles()
            if (kind is None or role["operation"] == kind)
            and (modality is None or role["modality"] == modality)
        ]

        candidates: list[dict] = []
        for role in roles:
            resolved = self.resolve_route(role["role_id"], kind=role["operation"])
            service = resolved["service"] if resolved is not None else None
            candidate = self._score_role_candidate(role, service, tasks, workloads)
            candidates.append(candidate)

        if request.include_direct_services:
            for service in services:
                candidates.append(self._score_service_candidate(service, tasks, workloads))

        candidates.sort(
            key=lambda item: (
                -item["score"],                        # highest score first
                item["candidate_type"] != "role",      # roles before services on ties
                item["candidate_id"],                  # stable alphabetical tie-break
            )
        )
        limit = max(1, request.limit)
        return {
            "status": "ok",
            "task_count": len(tasks),
            "tasks": tasks,
            "workloads": workloads,
            "kind": kind,
            "modality": modality,
            "candidates": candidates[:limit],
        }
|
|
|
|
def _resolve_direct(self, requested_model: str, *, kind: str | None = None) -> dict | None:
|
|
candidates = []
|
|
for service in self.list_services():
|
|
if kind is not None and service["kind"] != kind:
|
|
continue
|
|
if not service["state"].get("accept_requests", True):
|
|
continue
|
|
if service["state"].get("health") != "healthy":
|
|
continue
|
|
asset_ids = {asset.get("asset_id") for asset in service["assets"]}
|
|
if service["service_id"] == requested_model or requested_model in asset_ids:
|
|
candidates.append(service)
|
|
if not candidates:
|
|
return None
|
|
|
|
def score(service: dict) -> tuple[int, float, str]:
|
|
loaded = 1 if any(asset.get("loaded") for asset in service["assets"]) else 0
|
|
latency = service["observed"].get("p50_latency_ms")
|
|
latency_score = float(latency) if latency is not None else float("inf")
|
|
return (loaded, -latency_score, service["service_id"])
|
|
|
|
service = max(candidates, key=score)
|
|
return {"service": service}
|
|
|
|
def _score_role_candidate(self, role: dict, service: dict | None, tasks: list[str], workloads: list[str]) -> dict:
|
|
task_tokens = _tokenize_tasks(tasks)
|
|
text_parts = [
|
|
role.get("role_id", ""),
|
|
role.get("display_name", "") or "",
|
|
role.get("description", "") or "",
|
|
(role.get("prompt_policy") or {}).get("system_prompt", "") or "",
|
|
" ".join((role.get("routing_policy") or {}).get("preferred_families", [])),
|
|
" ".join((role.get("routing_policy") or {}).get("preferred_labels", [])),
|
|
]
|
|
text_score = _overlap_score(task_tokens, _tokenize_text(" ".join(text_parts)))
|
|
preferred_families = [family.lower() for family in (role.get("routing_policy") or {}).get("preferred_families", [])]
|
|
family_score = 0.0
|
|
if service is not None and preferred_families:
|
|
asset_names = " ".join(asset.get("asset_id", "") for asset in service["assets"]).lower()
|
|
if any(family in asset_names for family in preferred_families):
|
|
family_score = 1.0
|
|
runtime_score, runtime_reasons, runtime_signals = self._runtime_signals(service)
|
|
benchmark_score, benchmark_reasons, benchmark_signals = self._benchmark_signals(service, tasks, workloads)
|
|
|
|
score = min(1.0, 0.30 * text_score + 0.15 * family_score + 0.30 * runtime_score + 0.25 * benchmark_score)
|
|
reasons = []
|
|
if text_score > 0:
|
|
reasons.append("task text overlaps role description or policy")
|
|
if family_score > 0:
|
|
reasons.append("resolved service matches role preferred model family")
|
|
reasons.extend(runtime_reasons)
|
|
reasons.extend(benchmark_reasons)
|
|
if service is None:
|
|
reasons.append("no healthy service currently resolves for this role")
|
|
return {
|
|
"candidate_type": "role",
|
|
"candidate_id": role["role_id"],
|
|
"operation": role["operation"],
|
|
"score": round(score, 4),
|
|
"reasons": reasons,
|
|
"signals": {
|
|
"task_overlap": round(text_score, 4),
|
|
"preferred_family_match": family_score,
|
|
**runtime_signals,
|
|
**benchmark_signals,
|
|
},
|
|
"role": role,
|
|
"service": service,
|
|
}
|
|
|
|
def _score_service_candidate(self, service: dict, tasks: list[str], workloads: list[str]) -> dict:
|
|
task_tokens = _tokenize_tasks(tasks)
|
|
service_text = " ".join(
|
|
[
|
|
service.get("service_id", ""),
|
|
service.get("host_id", ""),
|
|
" ".join(asset.get("asset_id", "") for asset in service.get("assets", [])),
|
|
" ".join(f"{key} {value}" for key, value in (service.get("runtime") or {}).items() if value),
|
|
]
|
|
)
|
|
text_score = _overlap_score(task_tokens, _tokenize_text(service_text))
|
|
runtime_score, runtime_reasons, runtime_signals = self._runtime_signals(service)
|
|
benchmark_score, benchmark_reasons, benchmark_signals = self._benchmark_signals(service, tasks, workloads)
|
|
score = min(1.0, 0.20 * text_score + 0.45 * runtime_score + 0.35 * benchmark_score)
|
|
reasons = []
|
|
if text_score > 0:
|
|
reasons.append("task text overlaps service or asset metadata")
|
|
reasons.extend(runtime_reasons)
|
|
reasons.extend(benchmark_reasons)
|
|
return {
|
|
"candidate_type": "service",
|
|
"candidate_id": service["service_id"],
|
|
"operation": service["kind"],
|
|
"score": round(score, 4),
|
|
"reasons": reasons,
|
|
"signals": {
|
|
"task_overlap": round(text_score, 4),
|
|
**runtime_signals,
|
|
**benchmark_signals,
|
|
},
|
|
"role": None,
|
|
"service": service,
|
|
}
|
|
|
|
@staticmethod
|
|
def _runtime_signals(service: dict | None) -> tuple[float, list[str], dict[str, object]]:
|
|
if service is None:
|
|
return 0.0, [], {"loaded": False, "p50_latency_ms": None, "tokens_per_sec": None}
|
|
loaded = any(asset.get("loaded") for asset in service.get("assets", []))
|
|
latency = service["observed"].get("p50_latency_ms")
|
|
tokens_per_sec = service["observed"].get("tokens_per_sec")
|
|
queue_depth = service["observed"].get("queue_depth")
|
|
|
|
score = 0.0
|
|
reasons: list[str] = []
|
|
if loaded:
|
|
score += 0.35
|
|
reasons.append("service already has a loaded asset")
|
|
if latency is not None:
|
|
if latency <= 1500:
|
|
score += 0.30
|
|
reasons.append("low observed latency")
|
|
elif latency <= 4000:
|
|
score += 0.18
|
|
reasons.append("moderate observed latency")
|
|
else:
|
|
score += 0.05
|
|
reasons.append("high but usable latency")
|
|
if tokens_per_sec is not None:
|
|
if tokens_per_sec >= 20:
|
|
score += 0.20
|
|
reasons.append("good observed throughput")
|
|
elif tokens_per_sec >= 8:
|
|
score += 0.10
|
|
reasons.append("usable observed throughput")
|
|
if queue_depth is not None and queue_depth > 0:
|
|
score -= min(0.15, 0.03 * queue_depth)
|
|
reasons.append("current queue depth reduces suitability")
|
|
return max(0.0, min(1.0, score)), reasons, {
|
|
"loaded": loaded,
|
|
"p50_latency_ms": latency,
|
|
"tokens_per_sec": tokens_per_sec,
|
|
"queue_depth": queue_depth,
|
|
}
|
|
|
|
def _benchmark_signals(self, service: dict | None, tasks: list[str], workloads: list[str]) -> tuple[float, list[str], dict[str, object]]:
|
|
if service is None:
|
|
return 0.0, [], {"benchmark_match_count": 0, "best_workload_overlap": 0.0, "benchmark_quality_score": None}
|
|
samples = self.list_benchmark_samples(service_id=service["service_id"])
|
|
if not samples:
|
|
return 0.0, [], {"benchmark_match_count": 0, "best_workload_overlap": 0.0, "benchmark_quality_score": None}
|
|
|
|
query_tokens = _tokenize_tasks(tasks + workloads)
|
|
best_overlap = 0.0
|
|
best_quality = 0.0
|
|
matched_count = 0
|
|
for sample in samples:
|
|
workload_tokens = _tokenize_text(sample["workload"])
|
|
overlap = _overlap_score(query_tokens, workload_tokens) if query_tokens else 0.0
|
|
if workloads and sample["workload"] in workloads:
|
|
overlap = max(overlap, 1.0)
|
|
quality = _benchmark_quality_score(sample["results"])
|
|
if overlap > 0 or not query_tokens:
|
|
matched_count += 1
|
|
best_overlap = max(best_overlap, overlap)
|
|
best_quality = max(best_quality, quality)
|
|
|
|
score = 0.55 * best_overlap + 0.45 * best_quality if matched_count else 0.0
|
|
reasons: list[str] = []
|
|
if matched_count:
|
|
reasons.append("recent benchmark sample matches requested workload or task shape")
|
|
if best_quality >= 0.6:
|
|
reasons.append("benchmark results indicate strong empirical fit")
|
|
elif matched_count:
|
|
reasons.append("benchmark results indicate limited but relevant empirical fit")
|
|
return min(1.0, score), reasons, {
|
|
"benchmark_match_count": matched_count,
|
|
"best_workload_overlap": round(best_overlap, 4),
|
|
"benchmark_quality_score": round(best_quality, 4) if matched_count else None,
|
|
}
|
|
|
|
def cluster_health(self, stale_after_s: float) -> dict:
|
|
hosts = self.list_hosts()
|
|
services = self.list_services()
|
|
now = time.time()
|
|
online = 0
|
|
stale = 0
|
|
for host in hosts:
|
|
is_stale = (now - host["status"]["last_seen"]) > stale_after_s
|
|
if is_stale:
|
|
stale += 1
|
|
elif host["status"]["state"] == "online":
|
|
online += 1
|
|
healthy_services = sum(1 for service in services if service["state"].get("health") == "healthy")
|
|
return {
|
|
"status": "ok",
|
|
"host_count": len(hosts),
|
|
"online_host_count": online,
|
|
"stale_host_count": stale,
|
|
"service_count": len(services),
|
|
"healthy_service_count": healthy_services,
|
|
}
|
|
|
|
@staticmethod
|
|
def _offload_hint(*, operation: str, loaded_count: int, best_latency_ms: float | None) -> dict:
|
|
if loaded_count <= 0:
|
|
suitability = "cold_only"
|
|
elif best_latency_ms is not None and best_latency_ms <= 1500:
|
|
suitability = "good_for_low_complexity"
|
|
elif best_latency_ms is not None and best_latency_ms <= 4000:
|
|
suitability = "usable_for_background_tasks"
|
|
else:
|
|
suitability = "available_but_slow"
|
|
return {
|
|
"operation": operation,
|
|
"suitability": suitability,
|
|
"recommended_for": "lower-complexity offload" if operation == "chat" else f"{operation} offload",
|
|
"inference_basis": {
|
|
"loaded_target_count": loaded_count,
|
|
"best_p50_latency_ms": best_latency_ms,
|
|
},
|
|
}
|
|
|
|
def _service_metadata(self, service: dict, *, requested_model: str | None = None) -> dict:
|
|
lat = service["observed"].get("p50_latency_ms")
|
|
loaded_count = 1 if any(asset.get("loaded") for asset in service["assets"]) else 0
|
|
effective_requested_model = requested_model or service["service_id"]
|
|
return {
|
|
"route_type": "service",
|
|
"service_id": service["service_id"],
|
|
"host_id": service["host_id"],
|
|
"operation": service["kind"],
|
|
"protocol": service["protocol"],
|
|
"endpoint": service["endpoint"],
|
|
"health": service["state"].get("health"),
|
|
"loaded_asset_count": loaded_count,
|
|
"assets": service["assets"],
|
|
"runtime": service["runtime"],
|
|
"observed": service["observed"],
|
|
"effective_request_policy": self._effective_request_policy(
|
|
requested_model=effective_requested_model,
|
|
service=service,
|
|
),
|
|
"offload_hint": self._offload_hint(
|
|
operation=service["kind"],
|
|
loaded_count=loaded_count,
|
|
best_latency_ms=lat,
|
|
),
|
|
}
|
|
|
|
@staticmethod
|
|
def _effective_request_policy(
|
|
*,
|
|
requested_model: str,
|
|
service: dict | None,
|
|
role: dict | None = None,
|
|
) -> dict | None:
|
|
if service is None or service.get("kind") != "chat":
|
|
return None
|
|
asset = select_target_asset(service, requested_model)
|
|
return effective_chat_request_policy(
|
|
requested_model=requested_model,
|
|
service=service,
|
|
role=role,
|
|
asset=asset,
|
|
)
|
|
|
|
@staticmethod
|
|
def _host_row_to_dict(row: sqlite3.Row) -> dict:
|
|
return {
|
|
"host_id": row["host_id"],
|
|
"display_name": row["display_name"],
|
|
"address": row["address"],
|
|
"labels": json.loads(row["labels_json"]),
|
|
"capabilities": json.loads(row["capabilities_json"]),
|
|
"resources": json.loads(row["resources_json"]),
|
|
"status": {
|
|
"state": row["status_state"],
|
|
"last_seen": row["last_seen"],
|
|
},
|
|
"metrics": json.loads(row["metrics_json"]),
|
|
}
|
|
|
|
@staticmethod
|
|
def _service_row_to_dict(row: sqlite3.Row) -> dict:
|
|
return {
|
|
"service_id": row["service_id"],
|
|
"host_id": row["host_id"],
|
|
"kind": row["kind"],
|
|
"protocol": row["protocol"],
|
|
"endpoint": row["endpoint"],
|
|
"runtime": json.loads(row["runtime_json"]),
|
|
"assets": json.loads(row["assets_json"]),
|
|
"state": json.loads(row["state_json"]),
|
|
"observed": json.loads(row["observed_json"]),
|
|
"updated_at": row["updated_at"],
|
|
}
|
|
|
|
@staticmethod
|
|
def _role_row_to_dict(row: sqlite3.Row) -> dict:
|
|
return {
|
|
"role_id": row["role_id"],
|
|
"display_name": row["display_name"],
|
|
"description": row["description"],
|
|
"operation": row["operation"],
|
|
"modality": row["modality"],
|
|
"prompt_policy": json.loads(row["prompt_policy_json"]),
|
|
"routing_policy": json.loads(row["routing_policy_json"]),
|
|
"updated_at": row["updated_at"],
|
|
}
|
|
|
|
@staticmethod
|
|
def _benchmark_row_to_dict(row: sqlite3.Row) -> dict:
|
|
return {
|
|
"benchmark_id": row["benchmark_id"],
|
|
"service_id": row["service_id"],
|
|
"asset_id": row["asset_id"],
|
|
"workload": row["workload"],
|
|
"observed_at": row["observed_at"],
|
|
"results": json.loads(row["results_json"]),
|
|
}
|
|
|
|
|
|
def _tokenize_text(value: str) -> set[str]:
|
|
return {token for token in re.split(r"[^a-z0-9]+", value.lower()) if token}
|
|
|
|
|
|
def _tokenize_tasks(tasks: list[str]) -> set[str]:
    """Tokenize a list of task strings into one combined token set."""
    combined = " ".join(tasks)
    return _tokenize_text(combined)
|
|
|
|
|
|
def _overlap_score(task_tokens: set[str], candidate_tokens: set[str]) -> float:
|
|
if not task_tokens or not candidate_tokens:
|
|
return 0.0
|
|
overlap = len(task_tokens & candidate_tokens) / max(1, len(task_tokens))
|
|
return min(1.0, overlap)
|
|
|
|
|
|
def _benchmark_quality_score(results: dict) -> float:
|
|
if not results:
|
|
return 0.0
|
|
quality = 0.0
|
|
tokens_per_sec = results.get("tokens_per_sec")
|
|
ttft_ms = results.get("ttft_ms")
|
|
pass_rate = results.get("pass_rate")
|
|
quality_score = results.get("quality_score")
|
|
if isinstance(quality_score, (int, float)):
|
|
quality = max(quality, max(0.0, min(1.0, float(quality_score))))
|
|
if isinstance(pass_rate, (int, float)):
|
|
quality = max(quality, max(0.0, min(1.0, float(pass_rate))))
|
|
if isinstance(tokens_per_sec, (int, float)):
|
|
quality += min(0.35, float(tokens_per_sec) / 100.0)
|
|
if isinstance(ttft_ms, (int, float)):
|
|
if float(ttft_ms) <= 1000:
|
|
quality += 0.25
|
|
elif float(ttft_ms) <= 2500:
|
|
quality += 0.15
|
|
else:
|
|
quality += 0.05
|
|
return min(1.0, quality)
|