Add benchmarked route matching and request shaping

welberr 2026-04-07 14:45:32 -04:00
parent b9270df3e8
commit e36650a017
15 changed files with 1532 additions and 7 deletions


@@ -53,6 +53,21 @@ make smoke
make health
```
Benchmark workflow:
```bash
PYTHONPATH=src python scripts/run_benchmark_workload.py \
  --base-url http://127.0.0.1:8800 \
  --api-key change-me-client-key \
  --model general_assistant \
  --workload chat.short_reasoning \
  --output /tmp/geniehive-bench.json
PYTHONPATH=src python scripts/ingest_benchmark_report.py /tmp/geniehive-bench.json \
  --base-url http://127.0.0.1:8800 \
  --api-key change-me-client-key
```
Repository conventions:
- local runtime state lives under `state/` and should not be committed


@@ -45,6 +45,10 @@ service:
  assets:
    - asset_id: "qwen3-8b-q4km"
      loaded: true
      request_policy:
        body_defaults:
          chat_template_kwargs:
            enable_thinking: false
  state:
    health: "healthy"
    load_state: "loaded"
@@ -84,6 +88,9 @@ role:
  prompt_policy:
    system_prompt: "You guide without doing the user's work for them."
    user_template: "{{ user_input }}"
    request_policy:
      body_defaults:
        temperature: 0.2
  routing_policy:
    preferred_families: ["Qwen3", "Mistral"]
    preferred_labels: ["instruction", "stable"]
@@ -92,6 +99,35 @@ role:
    fallback_roles: ["general_assistant"]
```
## Request Shape Policy

This is a general representation for model- or route-specific request shaping.

```yaml
request_shape_policy:
  body_defaults:
    chat_template_kwargs:
      enable_thinking: false
    temperature: 0.2
  system_prompt: "Return only visible final answer text."
  system_prompt_position: "prepend"
```

Use it for:

- model-specific request flags such as `chat_template_kwargs.enable_thinking`
- default OpenAI-compatible body fields that are applied only when the caller has not already set them
- model-specific prompt instructions that should be prepended, appended, or replace an existing system message

GenieHive currently supports this policy on:

- `service.assets[].request_policy`
- `role.prompt_policy.request_policy`

The control plane may also infer built-in request policies from model family metadata. For example, Qwen3/Qwen3.5 chat routes default to `chat_template_kwargs.enable_thinking: false` unless the caller explicitly sets a different value.

`GET /v1/models` exposes the merged result as `geniehive.effective_request_policy` on service, asset, and role-backed model entries so clients can discover what GenieHive will apply by default.
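A minimal sketch of the merge semantics, using `apply_request_policy` from `geniehive_control.request_policy` (added in this commit); caller-set fields always win over `body_defaults`:

```python
from geniehive_control.request_policy import apply_request_policy

policy = {
    "body_defaults": {
        "temperature": 0.2,
        "chat_template_kwargs": {"enable_thinking": False},
    },
    "system_prompt": "Return only visible final answer text.",
    "system_prompt_position": "prepend",
}
body = {
    "model": "mentor",
    "temperature": 0.9,  # explicitly set by the caller
    "messages": [{"role": "user", "content": "hello"}],
}

shaped = apply_request_policy(body, policy)
assert shaped["temperature"] == 0.9  # caller value preserved
assert shaped["chat_template_kwargs"] == {"enable_thinking": False}  # default filled in
assert shaped["messages"][0]["role"] == "system"  # policy prompt prepended
```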
## Health Sample

```yaml
@@ -126,3 +162,109 @@ benchmark_sample:
    ttft_ms: 780
    tokens_per_sec: 44
```
## Route Match Request

```yaml
route_match_request:
  task: "fast technical reasoning for an interactive assistant"
  tasks:
    - "interactive debugging help"
    - "concise technical explanations"
  workload: "chat.short_reasoning"
  workloads:
    - "chat.short_reasoning"
    - "chat.concise_support"
  kind: "chat"
  modality: "text"
  include_direct_services: true
  limit: 5
```

This request is meant to answer:

- which role-backed route is the best current fit for this task or task suite
- which direct services also look suitable right now

V1 matching is metadata- and runtime-driven. It uses:

- role text and routing policy overlap
- service asset and runtime metadata overlap
- loaded state
- observed latency
- observed throughput
- current queue depth when available
- recent benchmark sample workload overlap and empirical quality/performance hints

If benchmark samples exist for a candidate service, workload hints such as `chat.short_reasoning` can boost routes with recent empirical fit.
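A minimal client sketch against `POST /v1/cluster/routes/match` (the base URL and API key below are the local defaults used elsewhere in this repo):

```python
import httpx

response = httpx.post(
    "http://127.0.0.1:8800/v1/cluster/routes/match",
    headers={"X-Api-Key": "change-me-client-key"},
    json={
        "task": "fast technical reasoning for an interactive assistant",
        "workloads": ["chat.short_reasoning"],
        "kind": "chat",
        "limit": 5,
    },
    timeout=30.0,
)
response.raise_for_status()
for candidate in response.json()["candidates"]:
    print(candidate["candidate_type"], candidate["candidate_id"], candidate["score"])
```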
## Route Match Candidate

```yaml
route_match_candidate:
  candidate_type: "role"
  candidate_id: "general_assistant"
  operation: "chat"
  score: 0.86
  reasons:
    - "task text overlaps role description or policy"
    - "resolved service matches role preferred model family"
    - "service already has a loaded asset"
    - "low observed latency"
    - "good observed throughput"
  signals:
    task_overlap: 0.33
    preferred_family_match: 1.0
    loaded: true
    p50_latency_ms: 1100
    tokens_per_sec: 28
    queue_depth: 0
    benchmark_match_count: 2
    best_workload_overlap: 1.0
    benchmark_quality_score: 0.9
  role:
    role_id: "general_assistant"
  service:
    service_id: "p40-box/chat/gpu1-secondary"
```
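For reference, the v1 score is a clipped linear blend of these signals, as implemented in `Registry._score_role_candidate` and `Registry._score_service_candidate` later in this diff:

```python
def role_score(text: float, family: float, runtime: float, benchmark: float) -> float:
    # Weights from Registry._score_role_candidate in this commit.
    return min(1.0, 0.30 * text + 0.15 * family + 0.30 * runtime + 0.25 * benchmark)


def service_score(text: float, runtime: float, benchmark: float) -> float:
    # Direct services drop the family term; weights from Registry._score_service_candidate.
    return min(1.0, 0.20 * text + 0.45 * runtime + 0.35 * benchmark)
```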
## Benchmark Ingest Request

```yaml
benchmark_ingest_request:
  samples:
    - benchmark_id: "bench-qwen-1"
      service_id: "p40-box/chat/gpu1-secondary"
      asset_id: "Qwen3.5-9B-Q5_K_M"
      workload: "chat.short_reasoning"
      observed_at: 1775582000.0
      results:
        ttft_ms: 900
        tokens_per_sec: 30
        quality_score: 0.9
```
## Benchmark Report File

This is a file-oriented format meant for repeatable benchmark runs before ingestion into GenieHive.

```yaml
benchmark_report:
  report_id: "p40-short-reasoning"
  observed_at: 1775583000.0
  source: "local-smoke"
  samples:
    - service_id: "p40-box/chat/gpu1-secondary"
      asset_id: "Qwen3.5-9B-Q5_K_M"
      workload: "chat.short_reasoning"
      results:
        ttft_ms: 900
        tokens_per_sec: 30
        quality_score: 0.9
```

Notes:

- `observed_at` may be set once at the report level or per sample
- `benchmark_id` is optional in the file format; GenieHive tooling can generate a stable ID during conversion
- the helper script `scripts/ingest_benchmark_report.py` loads this format and posts the expanded samples to `POST /v1/cluster/benchmarks`
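A minimal sketch of the expansion step, using `load_benchmark_report` and `report_to_samples` from `geniehive_control.benchmarks` (added in this commit); missing `benchmark_id` and `observed_at` values are filled from the report:

```python
from geniehive_control.benchmarks import load_benchmark_report, report_to_samples

report = load_benchmark_report("/tmp/geniehive-bench.json")
samples = report_to_samples(report)
for sample in samples:
    # Generated IDs look like "bench_<16 hex chars>" and are stable for a given
    # report, so re-ingesting the same file upserts instead of duplicating rows.
    print(sample.benchmark_id, sample.service_id, sample.workload)
```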


@@ -0,0 +1,46 @@
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os

import httpx

from geniehive_control.benchmarks import load_benchmark_report, report_to_samples


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Load a benchmark report JSON file and ingest its samples into GenieHive.")
    parser.add_argument("input", help="Path to benchmark report JSON")
    parser.add_argument(
        "--base-url",
        default=os.environ.get("GENIEHIVE_CONTROL_BASE_URL", "http://127.0.0.1:8800"),
        help="GenieHive control base URL",
    )
    parser.add_argument(
        "--api-key",
        default=os.environ.get("GENIEHIVE_CLIENT_API_KEY", "change-me-client-key"),
        help="GenieHive client API key",
    )
    return parser


def main() -> int:
    args = build_parser().parse_args()
    report = load_benchmark_report(args.input)
    samples = report_to_samples(report)
    payload = {"samples": [sample.model_dump() for sample in samples]}
    response = httpx.post(
        args.base_url.rstrip("/") + "/v1/cluster/benchmarks",
        json=payload,
        headers={"X-Api-Key": args.api_key},
        timeout=30.0,
    )
    response.raise_for_status()
    print(json.dumps(response.json(), indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())


@@ -0,0 +1,44 @@
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
from pathlib import Path

from geniehive_control.benchmark_runner import built_in_chat_workloads, run_chat_benchmark


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run a built-in chat benchmark workload against a GenieHive route or model.")
    parser.add_argument("--base-url", default="http://127.0.0.1:8800", help="GenieHive control base URL")
    parser.add_argument("--api-key", default="change-me-client-key", help="GenieHive client API key")
    parser.add_argument("--model", required=True, help="Role, service, or direct asset/model id to benchmark")
    parser.add_argument(
        "--workload",
        required=True,
        choices=sorted(built_in_chat_workloads().keys()),
        help="Built-in benchmark workload to run",
    )
    parser.add_argument("--output", help="Optional path to write the benchmark report JSON")
    return parser


def main() -> int:
    args = build_parser().parse_args()
    workload = built_in_chat_workloads()[args.workload]
    report = run_chat_benchmark(
        base_url=args.base_url,
        api_key=args.api_key,
        model=args.model,
        workload=workload,
    )
    rendered = json.dumps(report.model_dump(), indent=2, sort_keys=True)
    if args.output:
        Path(args.output).write_text(rendered + "\n", encoding="utf-8")
    else:
        print(rendered)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())


@@ -0,0 +1,172 @@
from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Any, Callable

import httpx

from .benchmarks import BenchmarkReport, BenchmarkReportSample


@dataclass(slots=True)
class ChatBenchmarkCase:
    name: str
    prompt: str
    max_completion_tokens: int = 120


@dataclass(slots=True)
class ChatBenchmarkWorkload:
    workload: str
    system_prompt: str
    cases: list[ChatBenchmarkCase]
    chat_template_kwargs: dict[str, Any] | None = None


def built_in_chat_workloads() -> dict[str, ChatBenchmarkWorkload]:
    return {
        "chat.short_reasoning": ChatBenchmarkWorkload(
            workload="chat.short_reasoning",
            system_prompt=(
                "You are a concise and careful reasoning assistant. "
                "Return only a visible final answer. "
                "Do not spend tokens on hidden reasoning or an internal monologue."
            ),
            chat_template_kwargs={"enable_thinking": False},
            cases=[
                ChatBenchmarkCase(
                    name="short_reasoning_1",
                    prompt="In two short paragraphs, explain why a loaded healthy route should be preferred over a cold route. Return only the final answer text.",
                ),
                ChatBenchmarkCase(
                    name="short_reasoning_2",
                    prompt="Give a compact tradeoff summary for a service with lower latency but worse throughput than another. Return only the final answer text.",
                ),
            ],
        ),
        "chat.concise_support": ChatBenchmarkWorkload(
            workload="chat.concise_support",
            system_prompt="You are a concise support assistant. Return only a visible final answer.",
            cases=[
                ChatBenchmarkCase(
                    name="concise_support_1",
                    prompt="Reply with a short troubleshooting checklist for a local API endpoint returning 404. Return only the final answer text.",
                ),
                ChatBenchmarkCase(
                    name="concise_support_2",
                    prompt="Reply with a short checklist for checking a tmux-managed service that exited at startup. Return only the final answer text.",
                ),
            ],
        ),
    }


def run_chat_benchmark(
    *,
    base_url: str,
    api_key: str,
    model: str,
    workload: ChatBenchmarkWorkload,
    request_fn: Callable[[str, dict[str, str], dict[str, Any]], dict[str, Any]] | None = None,
    observed_at: float | None = None,
) -> BenchmarkReport:
    request = request_fn or _default_chat_request
    latencies_ms: list[float] = []
    ttfts_ms: list[float] = []
    tokens_per_sec_values: list[float] = []
    completion_tokens: list[int] = []
    prompt_tokens: list[int] = []
    passed = 0
    responses_received = 0
    empty_visible_responses = 0
    for case in workload.cases:
        start = time.perf_counter()
        request_payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": workload.system_prompt},
                {"role": "user", "content": case.prompt},
            ],
            "max_tokens": case.max_completion_tokens,
        }
        if workload.chat_template_kwargs:
            request_payload["chat_template_kwargs"] = dict(workload.chat_template_kwargs)
        payload = request(
            base_url.rstrip("/") + "/v1/chat/completions",
            {
                "X-Api-Key": api_key,
                "Content-Type": "application/json",
            },
            request_payload,
        )
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        latencies_ms.append(elapsed_ms)
        usage = payload.get("usage", {})
        timings = payload.get("timings", {})
        prompt_tokens.append(int(usage.get("prompt_tokens", 0) or 0))
        completion_tokens.append(int(usage.get("completion_tokens", 0) or 0))
        responses_received += 1
        if isinstance(timings.get("prompt_ms"), (int, float)):
            ttfts_ms.append(float(timings["prompt_ms"]))
        if isinstance(timings.get("predicted_per_second"), (int, float)):
            tokens_per_sec_values.append(float(timings["predicted_per_second"]))
        elif isinstance(timings.get("predicted_ms"), (int, float)) and completion_tokens[-1] > 0 and float(timings["predicted_ms"]) > 0:
            # Derive throughput from decode time when the server reports only predicted_ms.
            tokens_per_sec_values.append((completion_tokens[-1] * 1000.0) / float(timings["predicted_ms"]))
        if _has_nonempty_content(payload):
            passed += 1
        else:
            empty_visible_responses += 1
    sample = BenchmarkReportSample(
        service_id=model,
        workload=workload.workload,
        observed_at=observed_at or time.time(),
        results={
            "case_count": len(workload.cases),
            "pass_rate": passed / max(1, len(workload.cases)),
            "response_rate": responses_received / max(1, len(workload.cases)),
            "empty_visible_response_rate": empty_visible_responses / max(1, len(workload.cases)),
            "p50_latency_ms": _median(latencies_ms),
            "ttft_ms": _median(ttfts_ms) if ttfts_ms else None,
            "tokens_per_sec": _median(tokens_per_sec_values) if tokens_per_sec_values else None,
            "prompt_tokens": int(sum(prompt_tokens) / max(1, len(prompt_tokens))),
            "completion_tokens": int(sum(completion_tokens) / max(1, len(completion_tokens))),
        },
    )
    return BenchmarkReport(
        report_id=f"{model}-{workload.workload}",
        observed_at=sample.observed_at,
        source="geniehive-benchmark-runner",
        samples=[sample],
    )


def _default_chat_request(url: str, headers: dict[str, str], payload: dict[str, Any]) -> dict[str, Any]:
    with httpx.Client(timeout=120.0) as client:
        response = client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        return response.json()


def _median(values: list[float]) -> float | None:
    if not values:
        return None
    ordered = sorted(values)
    middle = len(ordered) // 2
    if len(ordered) % 2:
        return ordered[middle]
    return (ordered[middle - 1] + ordered[middle]) / 2.0


def _has_nonempty_content(payload: dict[str, Any]) -> bool:
    # A case passes if the model produced any visible content; reasoning-only
    # output also counts as a pass, just not as a visible response.
    choices = payload.get("choices", [])
    if not choices:
        return False
    message = choices[0].get("message", {})
    if str(message.get("content", "")).strip():
        return True
    return bool(str(message.get("reasoning_content", "")).strip())


@@ -0,0 +1,69 @@
from __future__ import annotations

import hashlib
import json
from pathlib import Path

from pydantic import BaseModel, Field

from .models import BenchmarkSample


class BenchmarkReportSample(BaseModel):
    service_id: str
    asset_id: str | None = None
    workload: str
    observed_at: float | None = None
    benchmark_id: str | None = None
    results: dict[str, object] = Field(default_factory=dict)


class BenchmarkReport(BaseModel):
    report_id: str | None = None
    observed_at: float | None = None
    source: str | None = None
    samples: list[BenchmarkReportSample] = Field(default_factory=list)


def load_benchmark_report(path: str | Path) -> BenchmarkReport:
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    return BenchmarkReport.model_validate(raw)


def report_to_samples(report: BenchmarkReport) -> list[BenchmarkSample]:
    samples: list[BenchmarkSample] = []
    for index, sample in enumerate(report.samples):
        observed_at = sample.observed_at if sample.observed_at is not None else report.observed_at
        if observed_at is None:
            raise ValueError("Benchmark report sample is missing observed_at and report has no default observed_at.")
        benchmark_id = sample.benchmark_id or _make_benchmark_id(report.report_id, sample, index, observed_at)
        payload = dict(sample.results)
        if report.source and "source" not in payload:
            payload["source"] = report.source
        samples.append(
            BenchmarkSample(
                benchmark_id=benchmark_id,
                service_id=sample.service_id,
                asset_id=sample.asset_id,
                workload=sample.workload,
                observed_at=observed_at,
                results=payload,
            )
        )
    return samples


def _make_benchmark_id(report_id: str | None, sample: BenchmarkReportSample, index: int, observed_at: float) -> str:
    # Deterministic short ID so re-ingesting the same report upserts rather than duplicates.
    digest = hashlib.sha1(
        "|".join(
            [
                report_id or "report",
                sample.service_id,
                sample.asset_id or "",
                sample.workload,
                str(observed_at),
                str(index),
            ]
        ).encode("utf-8")
    ).hexdigest()[:16]
    return f"bench_{digest}"


@@ -2,6 +2,7 @@ from __future__ import annotations

from typing import Any

from .request_policy import apply_request_policy, effective_chat_request_policy, select_target_asset
from .registry import Registry
from .routing import choose_upstream_model_id
from .upstream import UpstreamClient
@@ -26,7 +27,6 @@ def _strip_reasoning_fields(payload: Any) -> Any:
            cleaned[key] = _strip_reasoning_fields(value)
        return cleaned


async def proxy_chat_completion(
    body: dict[str, Any],
    *,
@@ -45,7 +45,16 @@ async def proxy_chat_completion(
    if service is None:
        raise ProxyError(f"No healthy chat target available for '{requested_model}'.", status_code=503)
    asset = select_target_asset(service, requested_model)
    role = resolved.get("role")
    combined_policy = effective_chat_request_policy(
        requested_model=requested_model,
        service=service,
        role=role,
        asset=asset,
    )
    upstream_body = apply_request_policy(dict(body), combined_policy)
    upstream_body["model"] = choose_upstream_model_id(requested_model, service)
    response = await upstream.chat_completions(service["endpoint"], upstream_body)
    return _strip_reasoning_fields(response)


@@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse

from .auth import require_client_auth, require_node_auth
from .chat import ProxyError, proxy_chat_completion, proxy_embeddings
from .config import ControlConfig, load_config
from .models import BenchmarkIngestRequest, HostHeartbeat, HostRegistration, RouteMatchRequest, RouteMatchResponse
from .roles import load_role_catalog
from .registry import Registry
from .upstream import UpstreamClient, UpstreamError
@@ -105,6 +105,23 @@ def create_app(
    async def list_services(request: Request, _=Depends(require_client_auth)) -> dict:
        return {"object": "list", "data": request.app.state.registry.list_services()}

    @app.get("/v1/cluster/benchmarks")
    async def list_benchmarks(
        request: Request,
        service_id: str | None = None,
        workload: str | None = None,
        _=Depends(require_client_auth),
    ) -> dict:
        return {
            "object": "list",
            "data": request.app.state.registry.list_benchmark_samples(service_id=service_id, workload=workload),
        }

    @app.post("/v1/cluster/benchmarks")
    async def ingest_benchmarks(payload: BenchmarkIngestRequest, request: Request, _=Depends(require_client_auth)) -> dict:
        samples = request.app.state.registry.upsert_benchmark_samples(payload.samples)
        return {"status": "ok", "count": len(payload.samples), "data": samples}

    @app.get("/v1/cluster/roles")
    async def list_roles(request: Request, _=Depends(require_client_auth)) -> dict:
        return {"object": "list", "data": request.app.state.registry.list_roles()}
@@ -121,6 +138,11 @@ def create_app(
            return JSONResponse(status_code=404, content={"error": "no_route", "model": model, "kind": kind})
        return {"status": "ok", "resolution": resolved}

    @app.post("/v1/cluster/routes/match")
    async def match_routes(payload: RouteMatchRequest, request: Request, _=Depends(require_client_auth)) -> dict:
        response = request.app.state.registry.match_routes(payload)
        return RouteMatchResponse.model_validate(response).model_dump()

    return app


@@ -5,9 +5,16 @@ from typing import Any, Literal

from pydantic import BaseModel, Field


class RequestShapePolicy(BaseModel):
    body_defaults: dict[str, Any] = Field(default_factory=dict)
    system_prompt: str | None = None
    system_prompt_position: Literal["prepend", "append", "replace"] = "prepend"


class ServiceAsset(BaseModel):
    asset_id: str
    loaded: bool = False
    request_policy: RequestShapePolicy = Field(default_factory=RequestShapePolicy)


class ServiceRuntime(BaseModel):
@@ -66,6 +73,7 @@ class HostHeartbeat(BaseModel):
class PromptPolicy(BaseModel):
    system_prompt: str | None = None
    user_template: str | None = None
    request_policy: RequestShapePolicy = Field(default_factory=RequestShapePolicy)


class RoutingPolicy(BaseModel):
@@ -88,3 +96,48 @@ class RoleProfile(BaseModel):

class RoleCatalog(BaseModel):
    roles: list[RoleProfile] = Field(default_factory=list)


class BenchmarkSample(BaseModel):
    benchmark_id: str
    service_id: str
    asset_id: str | None = None
    workload: str
    observed_at: float
    results: dict[str, Any] = Field(default_factory=dict)


class BenchmarkIngestRequest(BaseModel):
    samples: list[BenchmarkSample] = Field(default_factory=list)


class RouteMatchRequest(BaseModel):
    task: str | None = None
    tasks: list[str] = Field(default_factory=list)
    workload: str | None = None
    workloads: list[str] = Field(default_factory=list)
    kind: Literal["chat", "embeddings", "transcription"] | None = None
    modality: str | None = None
    include_direct_services: bool = True
    limit: int = 10


class RouteMatchCandidate(BaseModel):
    candidate_type: Literal["role", "service"]
    candidate_id: str
    operation: Literal["chat", "embeddings", "transcription"]
    score: float
    reasons: list[str] = Field(default_factory=list)
    signals: dict[str, Any] = Field(default_factory=dict)
    role: dict[str, Any] | None = None
    service: dict[str, Any] | None = None


class RouteMatchResponse(BaseModel):
    status: str = "ok"
    task_count: int
    tasks: list[str]
    workloads: list[str] = Field(default_factory=list)
    kind: Literal["chat", "embeddings", "transcription"] | None = None
    modality: str | None = None
    candidates: list[RouteMatchCandidate] = Field(default_factory=list)


@@ -1,11 +1,13 @@
from __future__ import annotations

import json
import re
import sqlite3
import time
from pathlib import Path

from .models import BenchmarkSample, HostHeartbeat, HostRegistration, RegisteredService, RoleProfile, RouteMatchRequest
from .request_policy import effective_chat_request_policy, select_target_asset


def _json_dumps(value: object) -> str:
@@ -63,6 +65,15 @@ class Registry:
                routing_policy_json TEXT NOT NULL,
                updated_at REAL NOT NULL
            );

            CREATE TABLE IF NOT EXISTS benchmark_samples (
                benchmark_id TEXT PRIMARY KEY,
                service_id TEXT NOT NULL,
                asset_id TEXT,
                workload TEXT NOT NULL,
                observed_at REAL NOT NULL,
                results_json TEXT NOT NULL
            );
            """
        )
@@ -217,6 +228,50 @@ class Registry:
            rows = conn.execute("SELECT * FROM services ORDER BY host_id, service_id").fetchall()
        return [self._service_row_to_dict(row) for row in rows]

    def upsert_benchmark_samples(self, samples: list[BenchmarkSample]) -> list[dict]:
        with self._connect() as conn:
            for sample in samples:
                conn.execute(
                    """
                    INSERT INTO benchmark_samples (
                        benchmark_id, service_id, asset_id, workload, observed_at, results_json
                    )
                    VALUES (?, ?, ?, ?, ?, ?)
                    ON CONFLICT(benchmark_id) DO UPDATE SET
                        service_id=excluded.service_id,
                        asset_id=excluded.asset_id,
                        workload=excluded.workload,
                        observed_at=excluded.observed_at,
                        results_json=excluded.results_json
                    """,
                    (
                        sample.benchmark_id,
                        sample.service_id,
                        sample.asset_id,
                        sample.workload,
                        sample.observed_at,
                        _json_dumps(sample.results),
                    ),
                )
        return self.list_benchmark_samples()

    def list_benchmark_samples(self, *, service_id: str | None = None, workload: str | None = None) -> list[dict]:
        query = "SELECT * FROM benchmark_samples"
        clauses = []
        params: list[object] = []
        if service_id:
            clauses.append("service_id = ?")
            params.append(service_id)
        if workload:
            clauses.append("workload = ?")
            params.append(workload)
        if clauses:
            query += " WHERE " + " AND ".join(clauses)
        query += " ORDER BY observed_at DESC, benchmark_id"
        with self._connect() as conn:
            rows = conn.execute(query, params).fetchall()
        return [self._benchmark_row_to_dict(row) for row in rows]
    def list_client_models(self) -> list[dict]:
        services = self.list_services()
        roles = self.list_roles()
@@ -243,7 +298,7 @@ class Registry:
                    "id": asset_id,
                    "object": "model",
                    "owned_by": service["host_id"],
                    "geniehive": self._service_metadata(service, requested_model=asset_id) | {"route_type": "asset", "asset_id": asset_id},
                }
            )
@@ -282,6 +337,11 @@ class Registry:
                            best_latency_ms=best_latency_ms,
                        ),
                        "routing_policy": role["routing_policy"],
                        "effective_request_policy": self._effective_request_policy(
                            requested_model=role["role_id"],
                            role=role,
                            service=self.resolve_route(role["role_id"], kind=role["operation"]).get("service") if matching_services else None,
                        ),
                    },
                }
            )
@@ -331,6 +391,54 @@ class Registry:
        service = max(candidates, key=score)
        return {"match_type": "role", "role": role, "service": service}

    def match_routes(self, request: RouteMatchRequest) -> dict:
        tasks = [task.strip() for task in ([request.task] if request.task else []) + request.tasks if task and task.strip()]
        workloads = [value.strip() for value in ([request.workload] if request.workload else []) + request.workloads if value and value.strip()]
        kind = request.kind
        modality = request.modality
        services = [
            service
            for service in self.list_services()
            if (kind is None or service["kind"] == kind)
            and service["state"].get("accept_requests", True)
            and service["state"].get("health") == "healthy"
        ]
        roles = [
            role
            for role in self.list_roles()
            if (kind is None or role["operation"] == kind)
            and (modality is None or role["modality"] == modality)
        ]
        candidates: list[dict] = []
        for role in roles:
            resolved = self.resolve_route(role["role_id"], kind=role["operation"])
            service = resolved["service"] if resolved is not None else None
            candidate = self._score_role_candidate(role, service, tasks, workloads)
            candidates.append(candidate)
        if request.include_direct_services:
            for service in services:
                candidates.append(self._score_service_candidate(service, tasks, workloads))
        candidates.sort(
            key=lambda item: (
                -item["score"],
                item["candidate_type"] != "role",
                item["candidate_id"],
            )
        )
        limit = max(1, request.limit)
        return {
            "status": "ok",
            "task_count": len(tasks),
            "tasks": tasks,
            "workloads": workloads,
            "kind": kind,
            "modality": modality,
            "candidates": candidates[:limit],
        }

    def _resolve_direct(self, requested_model: str, *, kind: str | None = None) -> dict | None:
        candidates = []
        for service in self.list_services():
@@ -355,6 +463,163 @@ class Registry:
        service = max(candidates, key=score)
        return {"service": service}

    def _score_role_candidate(self, role: dict, service: dict | None, tasks: list[str], workloads: list[str]) -> dict:
        task_tokens = _tokenize_tasks(tasks)
        text_parts = [
            role.get("role_id", ""),
            role.get("display_name", "") or "",
            role.get("description", "") or "",
            (role.get("prompt_policy") or {}).get("system_prompt", "") or "",
            " ".join((role.get("routing_policy") or {}).get("preferred_families", [])),
            " ".join((role.get("routing_policy") or {}).get("preferred_labels", [])),
        ]
        text_score = _overlap_score(task_tokens, _tokenize_text(" ".join(text_parts)))
        preferred_families = [family.lower() for family in (role.get("routing_policy") or {}).get("preferred_families", [])]
        family_score = 0.0
        if service is not None and preferred_families:
            asset_names = " ".join(asset.get("asset_id", "") for asset in service["assets"]).lower()
            if any(family in asset_names for family in preferred_families):
                family_score = 1.0
        runtime_score, runtime_reasons, runtime_signals = self._runtime_signals(service)
        benchmark_score, benchmark_reasons, benchmark_signals = self._benchmark_signals(service, tasks, workloads)
        # Clipped linear blend: text and runtime dominate, benchmarks refine.
        score = min(1.0, 0.30 * text_score + 0.15 * family_score + 0.30 * runtime_score + 0.25 * benchmark_score)
        reasons = []
        if text_score > 0:
            reasons.append("task text overlaps role description or policy")
        if family_score > 0:
            reasons.append("resolved service matches role preferred model family")
        reasons.extend(runtime_reasons)
        reasons.extend(benchmark_reasons)
        if service is None:
            reasons.append("no healthy service currently resolves for this role")
        return {
            "candidate_type": "role",
            "candidate_id": role["role_id"],
            "operation": role["operation"],
            "score": round(score, 4),
            "reasons": reasons,
            "signals": {
                "task_overlap": round(text_score, 4),
                "preferred_family_match": family_score,
                **runtime_signals,
                **benchmark_signals,
            },
            "role": role,
            "service": service,
        }

    def _score_service_candidate(self, service: dict, tasks: list[str], workloads: list[str]) -> dict:
        task_tokens = _tokenize_tasks(tasks)
        service_text = " ".join(
            [
                service.get("service_id", ""),
                service.get("host_id", ""),
                " ".join(asset.get("asset_id", "") for asset in service.get("assets", [])),
                " ".join(f"{key} {value}" for key, value in (service.get("runtime") or {}).items() if value),
            ]
        )
        text_score = _overlap_score(task_tokens, _tokenize_text(service_text))
        runtime_score, runtime_reasons, runtime_signals = self._runtime_signals(service)
        benchmark_score, benchmark_reasons, benchmark_signals = self._benchmark_signals(service, tasks, workloads)
        score = min(1.0, 0.20 * text_score + 0.45 * runtime_score + 0.35 * benchmark_score)
        reasons = []
        if text_score > 0:
            reasons.append("task text overlaps service or asset metadata")
        reasons.extend(runtime_reasons)
        reasons.extend(benchmark_reasons)
        return {
            "candidate_type": "service",
            "candidate_id": service["service_id"],
            "operation": service["kind"],
            "score": round(score, 4),
            "reasons": reasons,
            "signals": {
                "task_overlap": round(text_score, 4),
                **runtime_signals,
                **benchmark_signals,
            },
            "role": None,
            "service": service,
        }

    @staticmethod
    def _runtime_signals(service: dict | None) -> tuple[float, list[str], dict[str, object]]:
        if service is None:
            return 0.0, [], {"loaded": False, "p50_latency_ms": None, "tokens_per_sec": None}
        loaded = any(asset.get("loaded") for asset in service.get("assets", []))
        latency = service["observed"].get("p50_latency_ms")
        tokens_per_sec = service["observed"].get("tokens_per_sec")
        queue_depth = service["observed"].get("queue_depth")
        score = 0.0
        reasons: list[str] = []
        if loaded:
            score += 0.35
            reasons.append("service already has a loaded asset")
        if latency is not None:
            if latency <= 1500:
                score += 0.30
                reasons.append("low observed latency")
            elif latency <= 4000:
                score += 0.18
                reasons.append("moderate observed latency")
            else:
                score += 0.05
                reasons.append("high but usable latency")
        if tokens_per_sec is not None:
            if tokens_per_sec >= 20:
                score += 0.20
                reasons.append("good observed throughput")
            elif tokens_per_sec >= 8:
                score += 0.10
                reasons.append("usable observed throughput")
        if queue_depth is not None and queue_depth > 0:
            score -= min(0.15, 0.03 * queue_depth)
            reasons.append("current queue depth reduces suitability")
        return max(0.0, min(1.0, score)), reasons, {
            "loaded": loaded,
            "p50_latency_ms": latency,
            "tokens_per_sec": tokens_per_sec,
            "queue_depth": queue_depth,
        }

    def _benchmark_signals(self, service: dict | None, tasks: list[str], workloads: list[str]) -> tuple[float, list[str], dict[str, object]]:
        if service is None:
            return 0.0, [], {"benchmark_match_count": 0, "best_workload_overlap": 0.0, "benchmark_quality_score": None}
        samples = self.list_benchmark_samples(service_id=service["service_id"])
        if not samples:
            return 0.0, [], {"benchmark_match_count": 0, "best_workload_overlap": 0.0, "benchmark_quality_score": None}
        query_tokens = _tokenize_tasks(tasks + workloads)
        best_overlap = 0.0
        best_quality = 0.0
        matched_count = 0
        for sample in samples:
            workload_tokens = _tokenize_text(sample["workload"])
            overlap = _overlap_score(query_tokens, workload_tokens) if query_tokens else 0.0
            if workloads and sample["workload"] in workloads:
                # An exact workload name match always counts as full overlap.
                overlap = max(overlap, 1.0)
            quality = _benchmark_quality_score(sample["results"])
            if overlap > 0 or not query_tokens:
                matched_count += 1
                best_overlap = max(best_overlap, overlap)
                best_quality = max(best_quality, quality)
        score = 0.55 * best_overlap + 0.45 * best_quality if matched_count else 0.0
        reasons: list[str] = []
        if matched_count:
            reasons.append("recent benchmark sample matches requested workload or task shape")
        if best_quality >= 0.6:
            reasons.append("benchmark results indicate strong empirical fit")
        elif matched_count:
            reasons.append("benchmark results indicate limited but relevant empirical fit")
        return min(1.0, score), reasons, {
            "benchmark_match_count": matched_count,
            "best_workload_overlap": round(best_overlap, 4),
            "benchmark_quality_score": round(best_quality, 4) if matched_count else None,
        }
    def cluster_health(self, stale_after_s: float) -> dict:
        hosts = self.list_hosts()
        services = self.list_services()
@@ -397,9 +662,10 @@ class Registry:
            },
        }

    def _service_metadata(self, service: dict, *, requested_model: str | None = None) -> dict:
        lat = service["observed"].get("p50_latency_ms")
        loaded_count = 1 if any(asset.get("loaded") for asset in service["assets"]) else 0
        effective_requested_model = requested_model or service["service_id"]
        return {
            "route_type": "service",
            "service_id": service["service_id"],
@@ -412,6 +678,10 @@ class Registry:
            "assets": service["assets"],
            "runtime": service["runtime"],
            "observed": service["observed"],
            "effective_request_policy": self._effective_request_policy(
                requested_model=effective_requested_model,
                service=service,
            ),
            "offload_hint": self._offload_hint(
                operation=service["kind"],
                loaded_count=loaded_count,
@@ -419,6 +689,23 @@ class Registry:
            ),
        }
    @staticmethod
    def _effective_request_policy(
        *,
        requested_model: str,
        service: dict | None,
        role: dict | None = None,
    ) -> dict | None:
        if service is None or service.get("kind") != "chat":
            return None
        asset = select_target_asset(service, requested_model)
        return effective_chat_request_policy(
            requested_model=requested_model,
            service=service,
            role=role,
            asset=asset,
        )

    @staticmethod
    def _host_row_to_dict(row: sqlite3.Row) -> dict:
        return {
@@ -462,3 +749,53 @@ class Registry:
            "routing_policy": json.loads(row["routing_policy_json"]),
            "updated_at": row["updated_at"],
        }

    @staticmethod
    def _benchmark_row_to_dict(row: sqlite3.Row) -> dict:
        return {
            "benchmark_id": row["benchmark_id"],
            "service_id": row["service_id"],
            "asset_id": row["asset_id"],
            "workload": row["workload"],
            "observed_at": row["observed_at"],
            "results": json.loads(row["results_json"]),
        }


def _tokenize_text(value: str) -> set[str]:
    return {token for token in re.split(r"[^a-z0-9]+", value.lower()) if token}


def _tokenize_tasks(tasks: list[str]) -> set[str]:
    return _tokenize_text(" ".join(tasks))


def _overlap_score(task_tokens: set[str], candidate_tokens: set[str]) -> float:
    if not task_tokens or not candidate_tokens:
        return 0.0
    overlap = len(task_tokens & candidate_tokens) / max(1, len(task_tokens))
    return min(1.0, overlap)


def _benchmark_quality_score(results: dict) -> float:
    if not results:
        return 0.0
    quality = 0.0
    tokens_per_sec = results.get("tokens_per_sec")
    ttft_ms = results.get("ttft_ms")
    pass_rate = results.get("pass_rate")
    quality_score = results.get("quality_score")
    if isinstance(quality_score, (int, float)):
        quality = max(quality, max(0.0, min(1.0, float(quality_score))))
    if isinstance(pass_rate, (int, float)):
        quality = max(quality, max(0.0, min(1.0, float(pass_rate))))
    if isinstance(tokens_per_sec, (int, float)):
        quality += min(0.35, float(tokens_per_sec) / 100.0)
    if isinstance(ttft_ms, (int, float)):
        if float(ttft_ms) <= 1000:
            quality += 0.25
        elif float(ttft_ms) <= 2500:
            quality += 0.15
        else:
            quality += 0.05
    return min(1.0, quality)


@@ -0,0 +1,128 @@
from __future__ import annotations

import copy
import re
from typing import Any


def deep_merge_defaults(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
    # Fill in defaults without overriding anything the caller already set.
    merged = copy.deepcopy(payload)
    for key, value in defaults.items():
        if key not in merged:
            merged[key] = copy.deepcopy(value)
            continue
        if isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = deep_merge_defaults(merged[key], value)
    return merged


def deep_merge_override(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    # Later policies win: override values replace base values, recursing into dicts.
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = deep_merge_override(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def apply_system_prompt(messages: list[dict[str, Any]], prompt: str, position: str) -> list[dict[str, Any]]:
    if any(message.get("role") == "system" and message.get("content") == prompt for message in messages):
        return messages
    injected = {"role": "system", "content": prompt}
    if position == "append":
        return [*messages, injected]
    if position == "replace":
        updated = list(messages)
        for index, message in enumerate(updated):
            if message.get("role") == "system":
                updated[index] = injected
                return updated
        return [injected, *updated]
    return [injected, *messages]


def apply_request_policy(body: dict[str, Any], policy: dict[str, Any] | None) -> dict[str, Any]:
    if not policy:
        return dict(body)
    shaped = deep_merge_defaults(body, policy.get("body_defaults", {}) or {})
    system_prompt = policy.get("system_prompt")
    if not system_prompt:
        return shaped
    messages = shaped.get("messages")
    if not isinstance(messages, list):
        return shaped
    normalized_messages = [message for message in messages if isinstance(message, dict)]
    shaped["messages"] = apply_system_prompt(
        normalized_messages,
        system_prompt,
        str(policy.get("system_prompt_position") or "prepend"),
    )
    return shaped


def merge_request_policies(*policies: dict[str, Any] | None) -> dict[str, Any] | None:
    combined: dict[str, Any] = {}
    for policy in policies:
        if not policy:
            continue
        body_defaults = policy.get("body_defaults") or {}
        if body_defaults:
            combined["body_defaults"] = deep_merge_override(combined.get("body_defaults", {}), body_defaults)
        if policy.get("system_prompt"):
            combined["system_prompt"] = policy["system_prompt"]
            combined["system_prompt_position"] = policy.get("system_prompt_position", "prepend")
    return combined or None


def select_target_asset(service: dict[str, Any], requested_model: str) -> dict[str, Any] | None:
    # Prefer an exact asset match, then any loaded asset, then any named asset.
    assets = service.get("assets", [])
    for asset in assets:
        if asset.get("asset_id") == requested_model:
            return asset
    for asset in assets:
        if asset.get("loaded") and asset.get("asset_id"):
            return asset
    for asset in assets:
        if asset.get("asset_id"):
            return asset
    return None


def looks_like_qwen3(value: str) -> bool:
    normalized = re.sub(r"[^a-z0-9]+", "", value.lower())
    return normalized.startswith("qwen3")


def infer_chat_request_policy(requested_model: str, service: dict[str, Any], asset: dict[str, Any] | None) -> dict[str, Any] | None:
    identifiers = [requested_model, service.get("service_id", "")]
    if asset is not None:
        identifiers.append(asset.get("asset_id", ""))
    identifiers.extend(asset_item.get("asset_id", "") for asset_item in service.get("assets", []))
    if any(looks_like_qwen3(identifier) for identifier in identifiers if identifier):
        return {
            "body_defaults": {
                "chat_template_kwargs": {
                    "enable_thinking": False,
                }
            }
        }
    return None


def effective_chat_request_policy(
    *,
    requested_model: str,
    service: dict[str, Any],
    role: dict[str, Any] | None = None,
    asset: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
    # Later entries win on conflicts: inferred defaults < asset policy < role policy.
    return merge_request_policies(
        infer_chat_request_policy(requested_model, service, asset),
        (asset or {}).get("request_policy"),
        ((role or {}).get("prompt_policy") or {}).get("request_policy"),
    )


@@ -0,0 +1,104 @@
from geniehive_control.benchmark_runner import ChatBenchmarkCase, ChatBenchmarkWorkload, built_in_chat_workloads, run_chat_benchmark


def test_built_in_chat_workloads_exist() -> None:
    workloads = built_in_chat_workloads()
    assert "chat.short_reasoning" in workloads
    assert workloads["chat.short_reasoning"].cases
    assert workloads["chat.short_reasoning"].chat_template_kwargs == {"enable_thinking": False}


def test_run_chat_benchmark_generates_report() -> None:
    workload = ChatBenchmarkWorkload(
        workload="chat.short_reasoning",
        system_prompt="You are concise.",
        cases=[
            ChatBenchmarkCase(name="case1", prompt="Explain route selection briefly."),
            ChatBenchmarkCase(name="case2", prompt="Explain latency versus throughput briefly."),
        ],
    )

    def fake_request(url: str, headers: dict[str, str], payload: dict) -> dict:
        assert url.endswith("/v1/chat/completions")
        assert payload["model"] == "general_assistant"
        assert "chat_template_kwargs" not in payload
        return {
            "choices": [{"message": {"role": "assistant", "content": "Benchmark response."}}],
            "usage": {"prompt_tokens": 20, "completion_tokens": 10},
            "timings": {"prompt_ms": 120.0, "predicted_per_second": 25.0},
        }

    report = run_chat_benchmark(
        base_url="http://127.0.0.1:8800",
        api_key="change-me-client-key",
        model="general_assistant",
        workload=workload,
        request_fn=fake_request,
        observed_at=1775584000.0,
    )
    sample = report.samples[0]
    assert report.source == "geniehive-benchmark-runner"
    assert sample.workload == "chat.short_reasoning"
    assert sample.results["case_count"] == 2
    assert sample.results["pass_rate"] == 1.0
    assert sample.results["response_rate"] == 1.0
    assert sample.results["empty_visible_response_rate"] == 0.0
    assert sample.results["tokens_per_sec"] == 25.0
    assert sample.observed_at == 1775584000.0


def test_run_chat_benchmark_treats_reasoning_content_as_a_pass() -> None:
    workload = ChatBenchmarkWorkload(
        workload="chat.short_reasoning",
        system_prompt="You are concise.",
        cases=[ChatBenchmarkCase(name="case1", prompt="Explain route selection briefly.")],
    )

    def fake_request(url: str, headers: dict[str, str], payload: dict) -> dict:
        return {
            "choices": [{"message": {"role": "assistant", "content": "", "reasoning_content": "Reasoning only."}}],
            "usage": {"prompt_tokens": 20, "completion_tokens": 10},
            "timings": {"prompt_ms": 120.0, "predicted_per_second": 25.0},
        }

    report = run_chat_benchmark(
        base_url="http://127.0.0.1:8800",
        api_key="change-me-client-key",
        model="general_assistant",
        workload=workload,
        request_fn=fake_request,
        observed_at=1775584000.0,
    )
    assert report.samples[0].results["pass_rate"] == 1.0
    assert report.samples[0].results["empty_visible_response_rate"] == 0.0


def test_run_chat_benchmark_includes_chat_template_kwargs_when_configured() -> None:
    workload = ChatBenchmarkWorkload(
        workload="chat.short_reasoning",
        system_prompt="You are concise.",
        chat_template_kwargs={"enable_thinking": False},
        cases=[ChatBenchmarkCase(name="case1", prompt="Explain route selection briefly.")],
    )

    def fake_request(url: str, headers: dict[str, str], payload: dict) -> dict:
        assert payload["chat_template_kwargs"] == {"enable_thinking": False}
        return {
            "choices": [{"message": {"role": "assistant", "content": "Visible response."}}],
            "usage": {"prompt_tokens": 20, "completion_tokens": 10},
            "timings": {"prompt_ms": 120.0, "predicted_per_second": 25.0},
        }

    report = run_chat_benchmark(
        base_url="http://127.0.0.1:8800",
        api_key="change-me-client-key",
        model="general_assistant",
        workload=workload,
        request_fn=fake_request,
        observed_at=1775584000.0,
    )
    assert report.samples[0].results["pass_rate"] == 1.0

tests/test_benchmarks.py

@@ -0,0 +1,66 @@
from pathlib import Path

from geniehive_control.benchmarks import load_benchmark_report, report_to_samples


def test_load_benchmark_report_and_generate_sample_ids(tmp_path: Path) -> None:
    report_path = tmp_path / "bench.json"
    report_path.write_text(
        """
        {
            "report_id": "p40-short-reasoning",
            "observed_at": 1775583000.0,
            "source": "local-smoke",
            "samples": [
                {
                    "service_id": "p40-box/chat/gpu1-secondary",
                    "asset_id": "Qwen3.5-9B-Q5_K_M",
                    "workload": "chat.short_reasoning",
                    "results": {
                        "ttft_ms": 900,
                        "tokens_per_sec": 30,
                        "quality_score": 0.9
                    }
                }
            ]
        }
        """.strip(),
        encoding="utf-8",
    )
    report = load_benchmark_report(report_path)
    samples = report_to_samples(report)
    assert report.report_id == "p40-short-reasoning"
    assert len(samples) == 1
    assert samples[0].benchmark_id.startswith("bench_")
    assert samples[0].observed_at == 1775583000.0
    assert samples[0].results["source"] == "local-smoke"


def test_report_to_samples_preserves_explicit_sample_ids(tmp_path: Path) -> None:
    report_path = tmp_path / "bench.json"
    report_path.write_text(
        """
        {
            "observed_at": 1775583000.0,
            "samples": [
                {
                    "benchmark_id": "bench-explicit",
                    "service_id": "p40-box/chat/gpu0-primary",
                    "workload": "chat.concise_support",
                    "results": {
                        "tokens_per_sec": 24
                    }
                }
            ]
        }
        """.strip(),
        encoding="utf-8",
    )
    report = load_benchmark_report(report_path)
    samples = report_to_samples(report)
    assert samples[0].benchmark_id == "bench-explicit"
    assert samples[0].workload == "chat.concise_support"


@@ -158,6 +158,108 @@ def test_proxy_chat_completion_strips_reasoning_fields(tmp_path: Path) -> None:
        assert "reasoning" not in choice


def test_proxy_chat_completion_applies_inferred_qwen_request_defaults(tmp_path: Path) -> None:
    registry = _build_registry(tmp_path)

    class _InspectingPoster:
        async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse:
            assert json["chat_template_kwargs"] == {"enable_thinking": False}
            return _FakeResponse({"ok": True, "echo_model": json["model"]})

    upstream = UpstreamClient(client=_InspectingPoster())

    async def run() -> dict:
        return await proxy_chat_completion(
            {
                "model": "mentor",
                "messages": [{"role": "user", "content": "hello"}],
            },
            registry=registry,
            upstream=upstream,
        )

    result = asyncio.run(run())
    assert result["echo_model"] == "qwen3-8b-q4km"


def test_proxy_chat_completion_preserves_explicit_template_kwargs(tmp_path: Path) -> None:
    registry = _build_registry(tmp_path)

    class _InspectingPoster:
        async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse:
            assert json["chat_template_kwargs"] == {"enable_thinking": True, "foo": "bar"}
            return _FakeResponse({"ok": True, "echo_model": json["model"]})

    upstream = UpstreamClient(client=_InspectingPoster())

    async def run() -> dict:
        return await proxy_chat_completion(
            {
                "model": "mentor",
                "messages": [{"role": "user", "content": "hello"}],
                "chat_template_kwargs": {"enable_thinking": True, "foo": "bar"},
            },
            registry=registry,
            upstream=upstream,
        )

    result = asyncio.run(run())
    assert result["echo_model"] == "qwen3-8b-q4km"


def test_proxy_chat_completion_applies_asset_request_policy(tmp_path: Path) -> None:
    registry = Registry(tmp_path / "geniehive.sqlite3")
    registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="192.168.1.101",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/custom-model",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18091",
                    assets=[
                        {
                            "asset_id": "custom-model-v1",
                            "loaded": True,
                            "request_policy": {
                                "body_defaults": {
                                    "temperature": 0.2,
                                    "chat_template_kwargs": {"custom_flag": "yes"},
                                }
                            },
                        }
                    ],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 900},
                )
            ],
        )
    )

    class _InspectingPoster:
        async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse:
            assert json["temperature"] == 0.2
            assert json["chat_template_kwargs"] == {"custom_flag": "yes"}
            return _FakeResponse({"ok": True, "echo_model": json["model"]})

    upstream = UpstreamClient(client=_InspectingPoster())

    async def run() -> dict:
        return await proxy_chat_completion(
            {
                "model": "custom-model-v1",
                "messages": [{"role": "user", "content": "hello"}],
            },
            registry=registry,
            upstream=upstream,
        )

    result = asyncio.run(run())
    assert result["echo_model"] == "custom-model-v1"


def test_proxy_chat_completion_fails_for_unknown_model(tmp_path: Path) -> None:
    registry = _build_registry(tmp_path)
    upstream = UpstreamClient(client=_FakePoster())


@@ -1,7 +1,7 @@
from pathlib import Path

from geniehive_control.main import create_app
from geniehive_control.models import BenchmarkSample, HostHeartbeat, HostRegistration, RegisteredService, RoleProfile, RouteMatchRequest
from geniehive_control.registry import Registry
@@ -133,9 +133,11 @@ def test_registry_persists_roles_and_resolves_direct_and_role_routes(tmp_path: Path) -> None:
    mentor = next(item for item in models if item["id"] == "mentor")
    assert mentor["geniehive"]["route_type"] == "role"
    assert mentor["geniehive"]["offload_hint"]["suitability"] == "good_for_low_complexity"
    assert mentor["geniehive"]["effective_request_policy"]["body_defaults"]["chat_template_kwargs"]["enable_thinking"] is False

    asset = next(item for item in models if item["id"] == "qwen3-8b-q4km")
    assert asset["geniehive"]["route_type"] == "asset"
    assert asset["geniehive"]["offload_hint"]["recommended_for"] == "lower-complexity offload"
    assert asset["geniehive"]["effective_request_policy"]["body_defaults"]["chat_template_kwargs"]["enable_thinking"] is False


def test_control_app_exposes_expected_routes() -> None:
@@ -147,6 +149,220 @@ def test_control_app_exposes_expected_routes() -> None:
    assert "/v1/nodes/heartbeat" in paths
    assert "/v1/cluster/hosts" in paths
    assert "/v1/cluster/services" in paths
    assert "/v1/cluster/benchmarks" in paths
    assert "/v1/cluster/roles" in paths
    assert "/v1/cluster/health" in paths
    assert "/v1/cluster/routes/resolve" in paths
    assert "/v1/cluster/routes/match" in paths


def test_registry_can_rank_routes_for_task_statements(tmp_path: Path) -> None:
    db_path = tmp_path / "geniehive.sqlite3"
    registry = Registry(db_path)
    registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="192.168.1.101",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/qwen-reasoner",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18091",
                    assets=[{"asset_id": "qwen3.5-9b", "loaded": True}],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 1100, "tokens_per_sec": 28},
                ),
                RegisteredService(
                    service_id="atlas-01/chat/rocket-background",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18093",
                    assets=[{"asset_id": "rocket-3b", "loaded": True}],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 4200, "tokens_per_sec": 7},
                ),
            ],
        )
    )
    registry.upsert_roles(
        [
            RoleProfile(
                role_id="general_assistant",
                display_name="General Assistant",
                description="Fast general technical assistant for reasoning and question answering.",
                operation="chat",
                modality="text",
                routing_policy={"preferred_families": ["qwen3.5", "qwen"]},
            ),
            RoleProfile(
                role_id="background_summarizer",
                display_name="Background Summarizer",
                description="Slow fallback summarizer for lower-priority background work.",
                operation="chat",
                modality="text",
                routing_policy={"preferred_families": ["rocket"]},
            ),
        ]
    )
    result = registry.match_routes(
        RouteMatchRequest(
            task="fast technical reasoning for an interactive assistant",
            kind="chat",
            modality="text",
            limit=4,
        )
    )
    assert result["task_count"] == 1
    assert result["candidates"]
    top = result["candidates"][0]
    assert top["candidate_type"] == "role"
    assert top["candidate_id"] == "general_assistant"
    assert top["service"]["service_id"] == "atlas-01/chat/qwen-reasoner"
    assert top["signals"]["preferred_family_match"] == 1.0


def test_registry_match_can_include_direct_services(tmp_path: Path) -> None:
    db_path = tmp_path / "geniehive.sqlite3"
    registry = Registry(db_path)
    registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="192.168.1.101",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/qwen3-8b",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18091",
                    assets=[{"asset_id": "qwen3-8b-q4km", "loaded": True}],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 900, "tokens_per_sec": 40},
                )
            ],
        )
    )
    result = registry.match_routes(
        RouteMatchRequest(
            task="qwen model for quick chat",
            kind="chat",
            include_direct_services=True,
            limit=4,
        )
    )
    direct = next(candidate for candidate in result["candidates"] if candidate["candidate_type"] == "service")
    assert direct["candidate_id"] == "atlas-01/chat/qwen3-8b"
    assert direct["service"]["assets"][0]["asset_id"] == "qwen3-8b-q4km"


def test_registry_persists_benchmark_samples_and_uses_them_for_matching(tmp_path: Path) -> None:
    db_path = tmp_path / "geniehive.sqlite3"
    registry = Registry(db_path)
    registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="192.168.1.101",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/qwen-fast",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18091",
                    assets=[{"asset_id": "qwen3.5-9b", "loaded": True}],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 1500, "tokens_per_sec": 22},
                ),
                RegisteredService(
                    service_id="atlas-01/chat/rocket-slow",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18093",
                    assets=[{"asset_id": "rocket-3b", "loaded": True}],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 1200, "tokens_per_sec": 10},
                ),
            ],
        )
    )
    registry.upsert_benchmark_samples(
        [
            BenchmarkSample(
                benchmark_id="bench-qwen-1",
                service_id="atlas-01/chat/qwen-fast",
                asset_id="qwen3.5-9b",
                workload="chat.short_reasoning",
                observed_at=1000.0,
                results={"tokens_per_sec": 30, "ttft_ms": 900, "quality_score": 0.9},
            ),
            BenchmarkSample(
                benchmark_id="bench-rocket-1",
                service_id="atlas-01/chat/rocket-slow",
                asset_id="rocket-3b",
                workload="chat.short_reasoning",
                observed_at=1000.0,
                results={"tokens_per_sec": 9, "ttft_ms": 1900, "quality_score": 0.4},
            ),
        ]
    )
    samples = registry.list_benchmark_samples(service_id="atlas-01/chat/qwen-fast")
    assert len(samples) == 1
    assert samples[0]["benchmark_id"] == "bench-qwen-1"
    result = registry.match_routes(
        RouteMatchRequest(
            task="fast short reasoning for chat responses",
            workloads=["chat.short_reasoning"],
            kind="chat",
            include_direct_services=True,
            limit=4,
        )
    )
    top_service = next(candidate for candidate in result["candidates"] if candidate["candidate_type"] == "service")
    assert top_service["candidate_id"] == "atlas-01/chat/qwen-fast"
    assert top_service["signals"]["benchmark_match_count"] == 1
    assert top_service["signals"]["best_workload_overlap"] == 1.0


def test_registry_exposes_asset_request_policy_in_model_metadata(tmp_path: Path) -> None:
    db_path = tmp_path / "geniehive.sqlite3"
    registry = Registry(db_path)
    registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="192.168.1.101",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/custom-model",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18091",
                    assets=[
                        {
                            "asset_id": "custom-model-v1",
                            "loaded": True,
                            "request_policy": {
                                "body_defaults": {
                                    "temperature": 0.2,
                                    "chat_template_kwargs": {"custom_flag": "yes"},
                                }
                            },
                        }
                    ],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 900},
                )
            ],
        )
    )
    models = registry.list_client_models()
    asset = next(item for item in models if item["id"] == "custom-model-v1")
    assert asset["geniehive"]["effective_request_policy"]["body_defaults"]["temperature"] == 0.2
    assert asset["geniehive"]["effective_request_policy"]["body_defaults"]["chat_template_kwargs"]["custom_flag"] == "yes"