Add benchmarked route matching and request shaping
This commit is contained in:
parent b9270df3e8
commit e36650a017
README.md (+15 lines)

@@ -53,6 +53,21 @@ make smoke
 make health
 ```
 
+Benchmark workflow:
+
+```bash
+PYTHONPATH=src python scripts/run_benchmark_workload.py \
+  --base-url http://127.0.0.1:8800 \
+  --api-key change-me-client-key \
+  --model general_assistant \
+  --workload chat.short_reasoning \
+  --output /tmp/geniehive-bench.json
+
+PYTHONPATH=src python scripts/ingest_benchmark_report.py /tmp/geniehive-bench.json \
+  --base-url http://127.0.0.1:8800 \
+  --api-key change-me-client-key
+```
+
 Repository conventions:
 
 - local runtime state lives under `state/` and should not be committed
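A quick way to confirm the ingest landed is to read the samples back through the list endpoint added in this commit. A minimal sketch, assuming the local default base URL and client key from the workflow above:

```python
import httpx

# List stored benchmark samples, filtered to one workload.
response = httpx.get(
    "http://127.0.0.1:8800/v1/cluster/benchmarks",
    params={"workload": "chat.short_reasoning"},
    headers={"X-Api-Key": "change-me-client-key"},
    timeout=30.0,
)
response.raise_for_status()
for sample in response.json()["data"]:
    print(sample["benchmark_id"], sample["service_id"], sample["results"].get("tokens_per_sec"))
```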
docs/schemas.md (+142 lines)

@@ -45,6 +45,10 @@ service:
   assets:
     - asset_id: "qwen3-8b-q4km"
       loaded: true
+      request_policy:
+        body_defaults:
+          chat_template_kwargs:
+            enable_thinking: false
   state:
     health: "healthy"
     load_state: "loaded"

@@ -84,6 +88,9 @@ role:
   prompt_policy:
     system_prompt: "You guide without doing the user's work for them."
     user_template: "{{ user_input }}"
+    request_policy:
+      body_defaults:
+        temperature: 0.2
   routing_policy:
     preferred_families: ["Qwen3", "Mistral"]
     preferred_labels: ["instruction", "stable"]

@@ -92,6 +99,35 @@ role:
     fallback_roles: ["general_assistant"]
 ```
+
+## Request Shape Policy
+
+This is a general representation for model- or route-specific request shaping.
+
+```yaml
+request_shape_policy:
+  body_defaults:
+    chat_template_kwargs:
+      enable_thinking: false
+    temperature: 0.2
+  system_prompt: "Return only visible final answer text."
+  system_prompt_position: "prepend"
+```
+
+Use it for:
+
+- model-specific request flags such as `chat_template_kwargs.enable_thinking`
+- default OpenAI-compatible body fields that should be applied unless the caller has already set them
+- model-specific prompt instructions that should be prepended, appended, or replace an existing system message
+
+GenieHive currently supports this policy on:
+
+- `service.assets[].request_policy`
+- `role.prompt_policy.request_policy`
+
+The control plane may also infer built-in request policies from model family metadata. For example, Qwen3/Qwen3.5 chat routes default to `chat_template_kwargs.enable_thinking: false` unless the caller explicitly sets a different value.
+
+`GET /v1/models` exposes the merged result as `geniehive.effective_request_policy` on service, asset, and role-backed model entries so clients can discover what GenieHive will apply by default.
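To make the merge semantics concrete, here is a minimal sketch using `apply_request_policy` from the `request_policy` module added later in this commit; the request values are made up for illustration:

```python
from geniehive_control.request_policy import apply_request_policy

policy = {
    "body_defaults": {"temperature": 0.2, "chat_template_kwargs": {"enable_thinking": False}},
    "system_prompt": "Return only visible final answer text.",
    "system_prompt_position": "prepend",
}
body = {
    "model": "general_assistant",
    "temperature": 0.9,  # caller-set, so the 0.2 default must not override it
    "messages": [{"role": "user", "content": "Explain route selection briefly."}],
}
shaped = apply_request_policy(body, policy)
assert shaped["temperature"] == 0.9
assert shaped["chat_template_kwargs"] == {"enable_thinking": False}
assert shaped["messages"][0]["role"] == "system"  # prepended system prompt
```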
 ## Health Sample
 
 ```yaml

@@ -126,3 +162,109 @@ benchmark_sample:
   ttft_ms: 780
   tokens_per_sec: 44
 ```
+
+## Route Match Request
+
+```yaml
+route_match_request:
+  task: "fast technical reasoning for an interactive assistant"
+  tasks:
+    - "interactive debugging help"
+    - "concise technical explanations"
+  workload: "chat.short_reasoning"
+  workloads:
+    - "chat.short_reasoning"
+    - "chat.concise_support"
+  kind: "chat"
+  modality: "text"
+  include_direct_services: true
+  limit: 5
+```
+
+This request is meant to answer:
+
+- which role-backed route is the best current fit for this task or task suite
+- which direct services also look suitable right now
+
+V1 matching is metadata- and runtime-driven. It uses:
+
+- role text and routing policy overlap
+- service asset and runtime metadata overlap
+- loaded state
+- observed latency
+- observed throughput
+- current queue depth when available
+- recent benchmark sample workload overlap and empirical quality/performance hints
+
+If benchmark samples exist for a candidate service, workload hints such as `chat.short_reasoning` can boost routes with recent empirical fit.
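A minimal client sketch for this call; the endpoint, payload fields, and response shape come from the route and models added below, and the base URL and key are the local defaults:

```python
import httpx

payload = {
    "task": "fast technical reasoning for an interactive assistant",
    "workloads": ["chat.short_reasoning", "chat.concise_support"],
    "kind": "chat",
    "include_direct_services": True,
    "limit": 5,
}
response = httpx.post(
    "http://127.0.0.1:8800/v1/cluster/routes/match",
    json=payload,
    headers={"X-Api-Key": "change-me-client-key"},
    timeout=30.0,
)
response.raise_for_status()
for candidate in response.json()["candidates"]:
    print(candidate["candidate_type"], candidate["candidate_id"], candidate["score"])
```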
+## Route Match Candidate
+
+```yaml
+route_match_candidate:
+  candidate_type: "role"
+  candidate_id: "general_assistant"
+  operation: "chat"
+  score: 0.86
+  reasons:
+    - "task text overlaps role description or policy"
+    - "resolved service matches role preferred model family"
+    - "service already has a loaded asset"
+    - "low observed latency"
+    - "good observed throughput"
+  signals:
+    task_overlap: 0.33
+    preferred_family_match: 1.0
+    loaded: true
+    p50_latency_ms: 1100
+    tokens_per_sec: 28
+    queue_depth: 0
+    benchmark_match_count: 2
+    best_workload_overlap: 1.0
+    benchmark_quality_score: 0.9
+  role:
+    role_id: "general_assistant"
+  service:
+    service_id: "p40-box/chat/gpu1-secondary"
+```
+
+## Benchmark Ingest Request
+
+```yaml
+benchmark_ingest_request:
+  samples:
+    - benchmark_id: "bench-qwen-1"
+      service_id: "p40-box/chat/gpu1-secondary"
+      asset_id: "Qwen3.5-9B-Q5_K_M"
+      workload: "chat.short_reasoning"
+      observed_at: 1775582000.0
+      results:
+        ttft_ms: 900
+        tokens_per_sec: 30
+        quality_score: 0.9
+```
+
+## Benchmark Report File
+
+This is a file-oriented format meant for repeatable benchmark runs before ingestion into GenieHive.
+
+```yaml
+benchmark_report:
+  report_id: "p40-short-reasoning"
+  observed_at: 1775583000.0
+  source: "local-smoke"
+  samples:
+    - service_id: "p40-box/chat/gpu1-secondary"
+      asset_id: "Qwen3.5-9B-Q5_K_M"
+      workload: "chat.short_reasoning"
+      results:
+        ttft_ms: 900
+        tokens_per_sec: 30
+        quality_score: 0.9
+```
+
+Notes:
+
+- `observed_at` may be set once at the report level or per sample
+- `benchmark_id` is optional in the file format; GenieHive tooling can generate a stable ID during conversion
+- the helper script `scripts/ingest_benchmark_report.py` loads this format and posts the expanded samples to `POST /v1/cluster/benchmarks`
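The same conversion can also be done in process with the helpers this commit adds; the report path here is hypothetical:

```python
from geniehive_control.benchmarks import load_benchmark_report, report_to_samples

# Load the file-oriented report and expand it into ingestable samples.
report = load_benchmark_report("/tmp/geniehive-bench.json")
samples = report_to_samples(report)
for sample in samples:
    # benchmark_id is derived deterministically when the file omits it.
    print(sample.benchmark_id, sample.workload, sample.results.get("tokens_per_sec"))
```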
scripts/ingest_benchmark_report.py (new file, +46 lines)

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os

import httpx

from geniehive_control.benchmarks import load_benchmark_report, report_to_samples


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Load a benchmark report JSON file and ingest its samples into GenieHive.")
    parser.add_argument("input", help="Path to benchmark report JSON")
    parser.add_argument(
        "--base-url",
        default=os.environ.get("GENIEHIVE_CONTROL_BASE_URL", "http://127.0.0.1:8800"),
        help="GenieHive control base URL",
    )
    parser.add_argument(
        "--api-key",
        default=os.environ.get("GENIEHIVE_CLIENT_API_KEY", "change-me-client-key"),
        help="GenieHive client API key",
    )
    return parser


def main() -> int:
    args = build_parser().parse_args()
    report = load_benchmark_report(args.input)
    samples = report_to_samples(report)
    payload = {"samples": [sample.model_dump() for sample in samples]}
    response = httpx.post(
        args.base_url.rstrip("/") + "/v1/cluster/benchmarks",
        json=payload,
        headers={"X-Api-Key": args.api_key},
        timeout=30.0,
    )
    response.raise_for_status()
    print(json.dumps(response.json(), indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
scripts/run_benchmark_workload.py (new file, +44 lines)

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
from pathlib import Path

from geniehive_control.benchmark_runner import built_in_chat_workloads, run_chat_benchmark


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run a built-in chat benchmark workload against a GenieHive route or model.")
    parser.add_argument("--base-url", default="http://127.0.0.1:8800", help="GenieHive control base URL")
    parser.add_argument("--api-key", default="change-me-client-key", help="GenieHive client API key")
    parser.add_argument("--model", required=True, help="Role, service, or direct asset/model id to benchmark")
    parser.add_argument(
        "--workload",
        required=True,
        choices=sorted(built_in_chat_workloads().keys()),
        help="Built-in benchmark workload to run",
    )
    parser.add_argument("--output", help="Optional path to write the benchmark report JSON")
    return parser


def main() -> int:
    args = build_parser().parse_args()
    workload = built_in_chat_workloads()[args.workload]
    report = run_chat_benchmark(
        base_url=args.base_url,
        api_key=args.api_key,
        model=args.model,
        workload=workload,
    )
    rendered = json.dumps(report.model_dump(), indent=2, sort_keys=True)
    if args.output:
        Path(args.output).write_text(rendered + "\n", encoding="utf-8")
    else:
        print(rendered)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
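The `--workload` choices are exactly the keys of the in-process catalog, so they can be listed without a running server:

```python
from geniehive_control.benchmark_runner import built_in_chat_workloads

# The two workloads defined in this commit.
print(sorted(built_in_chat_workloads().keys()))
# ['chat.concise_support', 'chat.short_reasoning']
```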
src/geniehive_control/benchmark_runner.py (new file, +172 lines)

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Any, Callable

import httpx

from .benchmarks import BenchmarkReport, BenchmarkReportSample


@dataclass(slots=True)
class ChatBenchmarkCase:
    name: str
    prompt: str
    max_completion_tokens: int = 120


@dataclass(slots=True)
class ChatBenchmarkWorkload:
    workload: str
    system_prompt: str
    cases: list[ChatBenchmarkCase]
    chat_template_kwargs: dict[str, Any] | None = None


def built_in_chat_workloads() -> dict[str, ChatBenchmarkWorkload]:
    return {
        "chat.short_reasoning": ChatBenchmarkWorkload(
            workload="chat.short_reasoning",
            system_prompt=(
                "You are a concise and careful reasoning assistant. "
                "Return only a visible final answer. "
                "Do not spend tokens on hidden reasoning or an internal monologue."
            ),
            chat_template_kwargs={"enable_thinking": False},
            cases=[
                ChatBenchmarkCase(
                    name="short_reasoning_1",
                    prompt="In two short paragraphs, explain why a loaded healthy route should be preferred over a cold route. Return only the final answer text.",
                ),
                ChatBenchmarkCase(
                    name="short_reasoning_2",
                    prompt="Give a compact tradeoff summary for a service with lower latency but worse throughput than another. Return only the final answer text.",
                ),
            ],
        ),
        "chat.concise_support": ChatBenchmarkWorkload(
            workload="chat.concise_support",
            system_prompt="You are a concise support assistant. Return only a visible final answer.",
            cases=[
                ChatBenchmarkCase(
                    name="concise_support_1",
                    prompt="Reply with a short troubleshooting checklist for a local API endpoint returning 404. Return only the final answer text.",
                ),
                ChatBenchmarkCase(
                    name="concise_support_2",
                    prompt="Reply with a short checklist for checking a tmux-managed service that exited at startup. Return only the final answer text.",
                ),
            ],
        ),
    }


def run_chat_benchmark(
    *,
    base_url: str,
    api_key: str,
    model: str,
    workload: ChatBenchmarkWorkload,
    request_fn: Callable[[str, dict[str, str], dict[str, Any]], dict[str, Any]] | None = None,
    observed_at: float | None = None,
) -> BenchmarkReport:
    request = request_fn or _default_chat_request
    latencies_ms: list[float] = []
    ttfts_ms: list[float] = []
    tokens_per_sec_values: list[float] = []
    completion_tokens: list[int] = []
    prompt_tokens: list[int] = []
    passed = 0
    responses_received = 0
    empty_visible_responses = 0

    for case in workload.cases:
        start = time.perf_counter()
        request_payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": workload.system_prompt},
                {"role": "user", "content": case.prompt},
            ],
            "max_tokens": case.max_completion_tokens,
        }
        if workload.chat_template_kwargs:
            request_payload["chat_template_kwargs"] = dict(workload.chat_template_kwargs)

        payload = request(
            base_url.rstrip("/") + "/v1/chat/completions",
            {
                "X-Api-Key": api_key,
                "Content-Type": "application/json",
            },
            request_payload,
        )
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        latencies_ms.append(elapsed_ms)

        usage = payload.get("usage", {})
        timings = payload.get("timings", {})
        prompt_tokens.append(int(usage.get("prompt_tokens", 0) or 0))
        completion_tokens.append(int(usage.get("completion_tokens", 0) or 0))
        responses_received += 1
        if isinstance(timings.get("prompt_ms"), (int, float)):
            ttfts_ms.append(float(timings["prompt_ms"]))
        if isinstance(timings.get("predicted_per_second"), (int, float)):
            tokens_per_sec_values.append(float(timings["predicted_per_second"]))
        elif isinstance(timings.get("predicted_ms"), (int, float)) and completion_tokens[-1] > 0 and float(timings["predicted_ms"]) > 0:
            tokens_per_sec_values.append((completion_tokens[-1] * 1000.0) / float(timings["predicted_ms"]))
        if _has_nonempty_content(payload):
            passed += 1
        else:
            empty_visible_responses += 1

    sample = BenchmarkReportSample(
        service_id=model,
        workload=workload.workload,
        observed_at=observed_at or time.time(),
        results={
            "case_count": len(workload.cases),
            "pass_rate": passed / max(1, len(workload.cases)),
            "response_rate": responses_received / max(1, len(workload.cases)),
            "empty_visible_response_rate": empty_visible_responses / max(1, len(workload.cases)),
            "p50_latency_ms": _median(latencies_ms),
            "ttft_ms": _median(ttfts_ms) if ttfts_ms else None,
            "tokens_per_sec": _median(tokens_per_sec_values) if tokens_per_sec_values else None,
            "prompt_tokens": int(sum(prompt_tokens) / max(1, len(prompt_tokens))),
            "completion_tokens": int(sum(completion_tokens) / max(1, len(completion_tokens))),
        },
    )
    return BenchmarkReport(
        report_id=f"{model}-{workload.workload}",
        observed_at=sample.observed_at,
        source="geniehive-benchmark-runner",
        samples=[sample],
    )


def _default_chat_request(url: str, headers: dict[str, str], payload: dict[str, Any]) -> dict[str, Any]:
    with httpx.Client(timeout=120.0) as client:
        response = client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        return response.json()


def _median(values: list[float]) -> float | None:
    if not values:
        return None
    ordered = sorted(values)
    middle = len(ordered) // 2
    if len(ordered) % 2:
        return ordered[middle]
    return (ordered[middle - 1] + ordered[middle]) / 2.0


def _has_nonempty_content(payload: dict[str, Any]) -> bool:
    choices = payload.get("choices", [])
    if not choices:
        return False
    message = choices[0].get("message", {})
    if str(message.get("content", "")).strip():
        return True
    return bool(str(message.get("reasoning_content", "")).strip())
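Because `run_chat_benchmark` accepts a `request_fn`, a workload can be exercised without any live upstream, which is how the tests below drive it. A minimal sketch with a stubbed transport (the workload name and stub values are illustrative):

```python
from geniehive_control.benchmark_runner import ChatBenchmarkCase, ChatBenchmarkWorkload, run_chat_benchmark

workload = ChatBenchmarkWorkload(
    workload="chat.smoke",  # hypothetical workload name for illustration
    system_prompt="You are concise.",
    cases=[ChatBenchmarkCase(name="case1", prompt="Explain route selection briefly.")],
)

def fake_request(url: str, headers: dict[str, str], payload: dict) -> dict:
    # Stand-in for the real HTTP round trip to /v1/chat/completions.
    return {
        "choices": [{"message": {"role": "assistant", "content": "Stub answer."}}],
        "usage": {"prompt_tokens": 20, "completion_tokens": 10},
        "timings": {"prompt_ms": 120.0, "predicted_per_second": 25.0},
    }

report = run_chat_benchmark(
    base_url="http://127.0.0.1:8800",
    api_key="change-me-client-key",
    model="general_assistant",
    workload=workload,
    request_fn=fake_request,
)
print(report.samples[0].results["pass_rate"])  # 1.0
```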
src/geniehive_control/benchmarks.py (new file, +69 lines)

from __future__ import annotations

import hashlib
import json
from pathlib import Path

from pydantic import BaseModel, Field

from .models import BenchmarkSample


class BenchmarkReportSample(BaseModel):
    service_id: str
    asset_id: str | None = None
    workload: str
    observed_at: float | None = None
    benchmark_id: str | None = None
    results: dict[str, object] = Field(default_factory=dict)


class BenchmarkReport(BaseModel):
    report_id: str | None = None
    observed_at: float | None = None
    source: str | None = None
    samples: list[BenchmarkReportSample] = Field(default_factory=list)


def load_benchmark_report(path: str | Path) -> BenchmarkReport:
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    return BenchmarkReport.model_validate(raw)


def report_to_samples(report: BenchmarkReport) -> list[BenchmarkSample]:
    samples: list[BenchmarkSample] = []
    for index, sample in enumerate(report.samples):
        # Per-sample observed_at wins; otherwise fall back to the report-level default.
        observed_at = sample.observed_at if sample.observed_at is not None else report.observed_at
        if observed_at is None:
            raise ValueError("Benchmark report sample is missing observed_at and report has no default observed_at.")
        benchmark_id = sample.benchmark_id or _make_benchmark_id(report.report_id, sample, index, observed_at)
        payload = dict(sample.results)
        if report.source and "source" not in payload:
            payload["source"] = report.source
        samples.append(
            BenchmarkSample(
                benchmark_id=benchmark_id,
                service_id=sample.service_id,
                asset_id=sample.asset_id,
                workload=sample.workload,
                observed_at=observed_at,
                results=payload,
            )
        )
    return samples


def _make_benchmark_id(report_id: str | None, sample: BenchmarkReportSample, index: int, observed_at: float) -> str:
    # Stable ID: SHA-1 over the identifying fields, truncated to 16 hex chars.
    digest = hashlib.sha1(
        "|".join(
            [
                report_id or "report",
                sample.service_id,
                sample.asset_id or "",
                sample.workload,
                str(observed_at),
                str(index),
            ]
        ).encode("utf-8")
    ).hexdigest()[:16]
    return f"bench_{digest}"
src/geniehive_control/chat.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from typing import Any
 
+from .request_policy import apply_request_policy, effective_chat_request_policy, select_target_asset
 from .registry import Registry
 from .routing import choose_upstream_model_id
 from .upstream import UpstreamClient

@@ -26,7 +27,6 @@ def _strip_reasoning_fields(payload: Any) -> Any:
         cleaned[key] = _strip_reasoning_fields(value)
     return cleaned
 
-
 async def proxy_chat_completion(
     body: dict[str, Any],
     *,

@@ -45,7 +45,16 @@ async def proxy_chat_completion(
     if service is None:
         raise ProxyError(f"No healthy chat target available for '{requested_model}'.", status_code=503)
 
-    upstream_body = dict(body)
+    asset = select_target_asset(service, requested_model)
+    role = resolved.get("role")
+    combined_policy = effective_chat_request_policy(
+        requested_model=requested_model,
+        service=service,
+        role=role,
+        asset=asset,
+    )
+
+    upstream_body = apply_request_policy(dict(body), combined_policy)
     upstream_body["model"] = choose_upstream_model_id(requested_model, service)
     response = await upstream.chat_completions(service["endpoint"], upstream_body)
     return _strip_reasoning_fields(response)
Control-plane FastAPI application (create_app):

@@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse
 from .auth import require_client_auth, require_node_auth
 from .chat import ProxyError, proxy_chat_completion, proxy_embeddings
 from .config import ControlConfig, load_config
-from .models import HostHeartbeat, HostRegistration
+from .models import BenchmarkIngestRequest, HostHeartbeat, HostRegistration, RouteMatchRequest, RouteMatchResponse
 from .roles import load_role_catalog
 from .registry import Registry
 from .upstream import UpstreamClient, UpstreamError

@@ -105,6 +105,23 @@ def create_app(
     async def list_services(request: Request, _=Depends(require_client_auth)) -> dict:
         return {"object": "list", "data": request.app.state.registry.list_services()}
 
+    @app.get("/v1/cluster/benchmarks")
+    async def list_benchmarks(
+        request: Request,
+        service_id: str | None = None,
+        workload: str | None = None,
+        _=Depends(require_client_auth),
+    ) -> dict:
+        return {
+            "object": "list",
+            "data": request.app.state.registry.list_benchmark_samples(service_id=service_id, workload=workload),
+        }
+
+    @app.post("/v1/cluster/benchmarks")
+    async def ingest_benchmarks(payload: BenchmarkIngestRequest, request: Request, _=Depends(require_client_auth)) -> dict:
+        samples = request.app.state.registry.upsert_benchmark_samples(payload.samples)
+        return {"status": "ok", "count": len(payload.samples), "data": samples}
+
     @app.get("/v1/cluster/roles")
     async def list_roles(request: Request, _=Depends(require_client_auth)) -> dict:
         return {"object": "list", "data": request.app.state.registry.list_roles()}

@@ -121,6 +138,11 @@ def create_app(
             return JSONResponse(status_code=404, content={"error": "no_route", "model": model, "kind": kind})
         return {"status": "ok", "resolution": resolved}
 
+    @app.post("/v1/cluster/routes/match")
+    async def match_routes(payload: RouteMatchRequest, request: Request, _=Depends(require_client_auth)) -> dict:
+        response = request.app.state.registry.match_routes(payload)
+        return RouteMatchResponse.model_validate(response).model_dump()
+
     return app
src/geniehive_control/models.py

@@ -5,9 +5,16 @@ from typing import Any, Literal
 from pydantic import BaseModel, Field
 
 
+class RequestShapePolicy(BaseModel):
+    body_defaults: dict[str, Any] = Field(default_factory=dict)
+    system_prompt: str | None = None
+    system_prompt_position: Literal["prepend", "append", "replace"] = "prepend"
+
+
 class ServiceAsset(BaseModel):
     asset_id: str
     loaded: bool = False
+    request_policy: RequestShapePolicy = Field(default_factory=RequestShapePolicy)
 
 
 class ServiceRuntime(BaseModel):

@@ -66,6 +73,7 @@ class HostHeartbeat(BaseModel):
 class PromptPolicy(BaseModel):
     system_prompt: str | None = None
     user_template: str | None = None
+    request_policy: RequestShapePolicy = Field(default_factory=RequestShapePolicy)
 
 
 class RoutingPolicy(BaseModel):

@@ -88,3 +96,48 @@ class RoleProfile(BaseModel):
 
 class RoleCatalog(BaseModel):
     roles: list[RoleProfile] = Field(default_factory=list)
+
+
+class BenchmarkSample(BaseModel):
+    benchmark_id: str
+    service_id: str
+    asset_id: str | None = None
+    workload: str
+    observed_at: float
+    results: dict[str, Any] = Field(default_factory=dict)
+
+
+class BenchmarkIngestRequest(BaseModel):
+    samples: list[BenchmarkSample] = Field(default_factory=list)
+
+
+class RouteMatchRequest(BaseModel):
+    task: str | None = None
+    tasks: list[str] = Field(default_factory=list)
+    workload: str | None = None
+    workloads: list[str] = Field(default_factory=list)
+    kind: Literal["chat", "embeddings", "transcription"] | None = None
+    modality: str | None = None
+    include_direct_services: bool = True
+    limit: int = 10
+
+
+class RouteMatchCandidate(BaseModel):
+    candidate_type: Literal["role", "service"]
+    candidate_id: str
+    operation: Literal["chat", "embeddings", "transcription"]
+    score: float
+    reasons: list[str] = Field(default_factory=list)
+    signals: dict[str, Any] = Field(default_factory=dict)
+    role: dict[str, Any] | None = None
+    service: dict[str, Any] | None = None
+
+
+class RouteMatchResponse(BaseModel):
+    status: str = "ok"
+    task_count: int
+    tasks: list[str]
+    workloads: list[str] = Field(default_factory=list)
+    kind: Literal["chat", "embeddings", "transcription"] | None = None
+    modality: str | None = None
+    candidates: list[RouteMatchCandidate] = Field(default_factory=list)
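A small sketch of how these models validate a matching request; the field values are illustrative:

```python
from geniehive_control.models import RouteMatchRequest

request = RouteMatchRequest(
    task="fast technical reasoning for an interactive assistant",
    workloads=["chat.short_reasoning"],
    kind="chat",
    limit=5,
)
# Unset optional fields fall back to their declared defaults.
assert request.include_direct_services is True
print(request.model_dump(exclude_none=True))
```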
src/geniehive_control/registry.py

@@ -1,11 +1,13 @@
 from __future__ import annotations
 
 import json
+import re
 import sqlite3
 import time
 from pathlib import Path
 
-from .models import HostHeartbeat, HostRegistration, RegisteredService, RoleProfile
+from .models import BenchmarkSample, HostHeartbeat, HostRegistration, RegisteredService, RoleProfile, RouteMatchRequest
+from .request_policy import effective_chat_request_policy, select_target_asset
 
 
 def _json_dumps(value: object) -> str:

@@ -63,6 +65,15 @@ class Registry:
             routing_policy_json TEXT NOT NULL,
             updated_at REAL NOT NULL
             );
+
+            CREATE TABLE IF NOT EXISTS benchmark_samples (
+            benchmark_id TEXT PRIMARY KEY,
+            service_id TEXT NOT NULL,
+            asset_id TEXT,
+            workload TEXT NOT NULL,
+            observed_at REAL NOT NULL,
+            results_json TEXT NOT NULL
+            );
             """
         )

@@ -217,6 +228,50 @@ class Registry:
             rows = conn.execute("SELECT * FROM services ORDER BY host_id, service_id").fetchall()
         return [self._service_row_to_dict(row) for row in rows]
 
+    def upsert_benchmark_samples(self, samples: list[BenchmarkSample]) -> list[dict]:
+        with self._connect() as conn:
+            for sample in samples:
+                conn.execute(
+                    """
+                    INSERT INTO benchmark_samples (
+                        benchmark_id, service_id, asset_id, workload, observed_at, results_json
+                    )
+                    VALUES (?, ?, ?, ?, ?, ?)
+                    ON CONFLICT(benchmark_id) DO UPDATE SET
+                        service_id=excluded.service_id,
+                        asset_id=excluded.asset_id,
+                        workload=excluded.workload,
+                        observed_at=excluded.observed_at,
+                        results_json=excluded.results_json
+                    """,
+                    (
+                        sample.benchmark_id,
+                        sample.service_id,
+                        sample.asset_id,
+                        sample.workload,
+                        sample.observed_at,
+                        _json_dumps(sample.results),
+                    ),
+                )
+        return self.list_benchmark_samples()
+
+    def list_benchmark_samples(self, *, service_id: str | None = None, workload: str | None = None) -> list[dict]:
+        query = "SELECT * FROM benchmark_samples"
+        clauses = []
+        params: list[object] = []
+        if service_id:
+            clauses.append("service_id = ?")
+            params.append(service_id)
+        if workload:
+            clauses.append("workload = ?")
+            params.append(workload)
+        if clauses:
+            query += " WHERE " + " AND ".join(clauses)
+        query += " ORDER BY observed_at DESC, benchmark_id"
+        with self._connect() as conn:
+            rows = conn.execute(query, params).fetchall()
+        return [self._benchmark_row_to_dict(row) for row in rows]
+
     def list_client_models(self) -> list[dict]:
         services = self.list_services()
         roles = self.list_roles()

@@ -243,7 +298,7 @@ class Registry:
                     "id": asset_id,
                     "object": "model",
                     "owned_by": service["host_id"],
-                    "geniehive": self._service_metadata(service) | {"route_type": "asset", "asset_id": asset_id},
+                    "geniehive": self._service_metadata(service, requested_model=asset_id) | {"route_type": "asset", "asset_id": asset_id},
                 }
             )

@@ -282,6 +337,11 @@ class Registry:
                         best_latency_ms=best_latency_ms,
                     ),
                     "routing_policy": role["routing_policy"],
+                    "effective_request_policy": self._effective_request_policy(
+                        requested_model=role["role_id"],
+                        role=role,
+                        service=self.resolve_route(role["role_id"], kind=role["operation"]).get("service") if matching_services else None,
+                    ),
                 },
             }
         )

@@ -331,6 +391,54 @@ class Registry:
         service = max(candidates, key=score)
         return {"match_type": "role", "role": role, "service": service}
 
+    def match_routes(self, request: RouteMatchRequest) -> dict:
+        tasks = [task.strip() for task in ([request.task] if request.task else []) + request.tasks if task and task.strip()]
+        workloads = [value.strip() for value in ([request.workload] if request.workload else []) + request.workloads if value and value.strip()]
+        kind = request.kind
+        modality = request.modality
+        services = [
+            service
+            for service in self.list_services()
+            if (kind is None or service["kind"] == kind)
+            and service["state"].get("accept_requests", True)
+            and service["state"].get("health") == "healthy"
+        ]
+        roles = [
+            role
+            for role in self.list_roles()
+            if (kind is None or role["operation"] == kind)
+            and (modality is None or role["modality"] == modality)
+        ]
+
+        candidates: list[dict] = []
+        for role in roles:
+            resolved = self.resolve_route(role["role_id"], kind=role["operation"])
+            service = resolved["service"] if resolved is not None else None
+            candidate = self._score_role_candidate(role, service, tasks, workloads)
+            candidates.append(candidate)
+
+        if request.include_direct_services:
+            for service in services:
+                candidates.append(self._score_service_candidate(service, tasks, workloads))
+
+        candidates.sort(
+            key=lambda item: (
+                -item["score"],
+                item["candidate_type"] != "role",
+                item["candidate_id"],
+            )
+        )
+        limit = max(1, request.limit)
+        return {
+            "status": "ok",
+            "task_count": len(tasks),
+            "tasks": tasks,
+            "workloads": workloads,
+            "kind": kind,
+            "modality": modality,
+            "candidates": candidates[:limit],
+        }
+
     def _resolve_direct(self, requested_model: str, *, kind: str | None = None) -> dict | None:
         candidates = []
         for service in self.list_services():

@@ -355,6 +463,163 @@ class Registry:
         service = max(candidates, key=score)
         return {"service": service}
 
+    def _score_role_candidate(self, role: dict, service: dict | None, tasks: list[str], workloads: list[str]) -> dict:
+        task_tokens = _tokenize_tasks(tasks)
+        text_parts = [
+            role.get("role_id", ""),
+            role.get("display_name", "") or "",
+            role.get("description", "") or "",
+            (role.get("prompt_policy") or {}).get("system_prompt", "") or "",
+            " ".join((role.get("routing_policy") or {}).get("preferred_families", [])),
+            " ".join((role.get("routing_policy") or {}).get("preferred_labels", [])),
+        ]
+        text_score = _overlap_score(task_tokens, _tokenize_text(" ".join(text_parts)))
+        preferred_families = [family.lower() for family in (role.get("routing_policy") or {}).get("preferred_families", [])]
+        family_score = 0.0
+        if service is not None and preferred_families:
+            asset_names = " ".join(asset.get("asset_id", "") for asset in service["assets"]).lower()
+            if any(family in asset_names for family in preferred_families):
+                family_score = 1.0
+        runtime_score, runtime_reasons, runtime_signals = self._runtime_signals(service)
+        benchmark_score, benchmark_reasons, benchmark_signals = self._benchmark_signals(service, tasks, workloads)
+
+        score = min(1.0, 0.30 * text_score + 0.15 * family_score + 0.30 * runtime_score + 0.25 * benchmark_score)
+        reasons = []
+        if text_score > 0:
+            reasons.append("task text overlaps role description or policy")
+        if family_score > 0:
+            reasons.append("resolved service matches role preferred model family")
+        reasons.extend(runtime_reasons)
+        reasons.extend(benchmark_reasons)
+        if service is None:
+            reasons.append("no healthy service currently resolves for this role")
+        return {
+            "candidate_type": "role",
+            "candidate_id": role["role_id"],
+            "operation": role["operation"],
+            "score": round(score, 4),
+            "reasons": reasons,
+            "signals": {
+                "task_overlap": round(text_score, 4),
+                "preferred_family_match": family_score,
+                **runtime_signals,
+                **benchmark_signals,
+            },
+            "role": role,
+            "service": service,
+        }
+
+    def _score_service_candidate(self, service: dict, tasks: list[str], workloads: list[str]) -> dict:
+        task_tokens = _tokenize_tasks(tasks)
+        service_text = " ".join(
+            [
+                service.get("service_id", ""),
+                service.get("host_id", ""),
+                " ".join(asset.get("asset_id", "") for asset in service.get("assets", [])),
+                " ".join(f"{key} {value}" for key, value in (service.get("runtime") or {}).items() if value),
+            ]
+        )
+        text_score = _overlap_score(task_tokens, _tokenize_text(service_text))
+        runtime_score, runtime_reasons, runtime_signals = self._runtime_signals(service)
+        benchmark_score, benchmark_reasons, benchmark_signals = self._benchmark_signals(service, tasks, workloads)
+        score = min(1.0, 0.20 * text_score + 0.45 * runtime_score + 0.35 * benchmark_score)
+        reasons = []
+        if text_score > 0:
+            reasons.append("task text overlaps service or asset metadata")
+        reasons.extend(runtime_reasons)
+        reasons.extend(benchmark_reasons)
+        return {
+            "candidate_type": "service",
+            "candidate_id": service["service_id"],
+            "operation": service["kind"],
+            "score": round(score, 4),
+            "reasons": reasons,
+            "signals": {
+                "task_overlap": round(text_score, 4),
+                **runtime_signals,
+                **benchmark_signals,
+            },
+            "role": None,
+            "service": service,
+        }
+
+    @staticmethod
+    def _runtime_signals(service: dict | None) -> tuple[float, list[str], dict[str, object]]:
+        if service is None:
+            return 0.0, [], {"loaded": False, "p50_latency_ms": None, "tokens_per_sec": None}
+        loaded = any(asset.get("loaded") for asset in service.get("assets", []))
+        latency = service["observed"].get("p50_latency_ms")
+        tokens_per_sec = service["observed"].get("tokens_per_sec")
+        queue_depth = service["observed"].get("queue_depth")
+
+        score = 0.0
+        reasons: list[str] = []
+        if loaded:
+            score += 0.35
+            reasons.append("service already has a loaded asset")
+        if latency is not None:
+            if latency <= 1500:
+                score += 0.30
+                reasons.append("low observed latency")
+            elif latency <= 4000:
+                score += 0.18
+                reasons.append("moderate observed latency")
+            else:
+                score += 0.05
+                reasons.append("high but usable latency")
+        if tokens_per_sec is not None:
+            if tokens_per_sec >= 20:
+                score += 0.20
+                reasons.append("good observed throughput")
+            elif tokens_per_sec >= 8:
+                score += 0.10
+                reasons.append("usable observed throughput")
+        if queue_depth is not None and queue_depth > 0:
+            score -= min(0.15, 0.03 * queue_depth)
+            reasons.append("current queue depth reduces suitability")
+        return max(0.0, min(1.0, score)), reasons, {
+            "loaded": loaded,
+            "p50_latency_ms": latency,
+            "tokens_per_sec": tokens_per_sec,
+            "queue_depth": queue_depth,
+        }
+
+    def _benchmark_signals(self, service: dict | None, tasks: list[str], workloads: list[str]) -> tuple[float, list[str], dict[str, object]]:
+        if service is None:
+            return 0.0, [], {"benchmark_match_count": 0, "best_workload_overlap": 0.0, "benchmark_quality_score": None}
+        samples = self.list_benchmark_samples(service_id=service["service_id"])
+        if not samples:
+            return 0.0, [], {"benchmark_match_count": 0, "best_workload_overlap": 0.0, "benchmark_quality_score": None}
+
+        query_tokens = _tokenize_tasks(tasks + workloads)
+        best_overlap = 0.0
+        best_quality = 0.0
+        matched_count = 0
+        for sample in samples:
+            workload_tokens = _tokenize_text(sample["workload"])
+            overlap = _overlap_score(query_tokens, workload_tokens) if query_tokens else 0.0
+            if workloads and sample["workload"] in workloads:
+                overlap = max(overlap, 1.0)
+            quality = _benchmark_quality_score(sample["results"])
+            if overlap > 0 or not query_tokens:
+                matched_count += 1
+                best_overlap = max(best_overlap, overlap)
+                best_quality = max(best_quality, quality)
+
+        score = 0.55 * best_overlap + 0.45 * best_quality if matched_count else 0.0
+        reasons: list[str] = []
+        if matched_count:
+            reasons.append("recent benchmark sample matches requested workload or task shape")
+        if best_quality >= 0.6:
+            reasons.append("benchmark results indicate strong empirical fit")
+        elif matched_count:
+            reasons.append("benchmark results indicate limited but relevant empirical fit")
+        return min(1.0, score), reasons, {
+            "benchmark_match_count": matched_count,
+            "best_workload_overlap": round(best_overlap, 4),
+            "benchmark_quality_score": round(best_quality, 4) if matched_count else None,
+        }
+
     def cluster_health(self, stale_after_s: float) -> dict:
         hosts = self.list_hosts()
         services = self.list_services()

@@ -397,9 +662,10 @@ class Registry:
             },
         }
 
-    def _service_metadata(self, service: dict) -> dict:
+    def _service_metadata(self, service: dict, *, requested_model: str | None = None) -> dict:
         lat = service["observed"].get("p50_latency_ms")
         loaded_count = 1 if any(asset.get("loaded") for asset in service["assets"]) else 0
+        effective_requested_model = requested_model or service["service_id"]
         return {
             "route_type": "service",
             "service_id": service["service_id"],

@@ -412,6 +678,10 @@ class Registry:
             "assets": service["assets"],
             "runtime": service["runtime"],
             "observed": service["observed"],
+            "effective_request_policy": self._effective_request_policy(
+                requested_model=effective_requested_model,
+                service=service,
+            ),
             "offload_hint": self._offload_hint(
                 operation=service["kind"],
                 loaded_count=loaded_count,

@@ -419,6 +689,23 @@ class Registry:
             ),
         }
 
+    @staticmethod
+    def _effective_request_policy(
+        *,
+        requested_model: str,
+        service: dict | None,
+        role: dict | None = None,
+    ) -> dict | None:
+        if service is None or service.get("kind") != "chat":
+            return None
+        asset = select_target_asset(service, requested_model)
+        return effective_chat_request_policy(
+            requested_model=requested_model,
+            service=service,
+            role=role,
+            asset=asset,
+        )
+
     @staticmethod
     def _host_row_to_dict(row: sqlite3.Row) -> dict:
         return {

@@ -462,3 +749,53 @@ class Registry:
             "routing_policy": json.loads(row["routing_policy_json"]),
             "updated_at": row["updated_at"],
         }
 
+    @staticmethod
+    def _benchmark_row_to_dict(row: sqlite3.Row) -> dict:
+        return {
+            "benchmark_id": row["benchmark_id"],
+            "service_id": row["service_id"],
+            "asset_id": row["asset_id"],
+            "workload": row["workload"],
+            "observed_at": row["observed_at"],
+            "results": json.loads(row["results_json"]),
+        }
+
+
+def _tokenize_text(value: str) -> set[str]:
+    return {token for token in re.split(r"[^a-z0-9]+", value.lower()) if token}
+
+
+def _tokenize_tasks(tasks: list[str]) -> set[str]:
+    return _tokenize_text(" ".join(tasks))
+
+
+def _overlap_score(task_tokens: set[str], candidate_tokens: set[str]) -> float:
+    if not task_tokens or not candidate_tokens:
+        return 0.0
+    overlap = len(task_tokens & candidate_tokens) / max(1, len(task_tokens))
+    return min(1.0, overlap)
+
+
+def _benchmark_quality_score(results: dict) -> float:
+    if not results:
+        return 0.0
+    quality = 0.0
+    tokens_per_sec = results.get("tokens_per_sec")
+    ttft_ms = results.get("ttft_ms")
+    pass_rate = results.get("pass_rate")
+    quality_score = results.get("quality_score")
+    if isinstance(quality_score, (int, float)):
+        quality = max(quality, max(0.0, min(1.0, float(quality_score))))
+    if isinstance(pass_rate, (int, float)):
+        quality = max(quality, max(0.0, min(1.0, float(pass_rate))))
+    if isinstance(tokens_per_sec, (int, float)):
+        quality += min(0.35, float(tokens_per_sec) / 100.0)
+    if isinstance(ttft_ms, (int, float)):
+        if float(ttft_ms) <= 1000:
+            quality += 0.25
+        elif float(ttft_ms) <= 2500:
+            quality += 0.15
+        else:
+            quality += 0.05
+    return min(1.0, quality)
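As a worked example of `_benchmark_quality_score`, take the sample results used throughout the docs (`ttft_ms: 900`, `tokens_per_sec: 30`, `quality_score: 0.9`): the explicit quality score sets the base to 0.9, throughput adds min(0.35, 30/100) = 0.30, and a TTFT at or under 1000 ms adds 0.25, so the raw total of 1.45 clamps to 1.0. A minimal check (the helper is module-private, imported here only for illustration):

```python
from geniehive_control.registry import _benchmark_quality_score

results = {"ttft_ms": 900, "tokens_per_sec": 30, "quality_score": 0.9}
# 0.9 (quality_score) + 0.30 (throughput) + 0.25 (ttft <= 1000 ms) = 1.45, clamped to 1.0
assert _benchmark_quality_score(results) == 1.0
```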
src/geniehive_control/request_policy.py (new file, +128 lines)

from __future__ import annotations

import copy
import re
from typing import Any


def deep_merge_defaults(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]:
    # Defaults never override caller-provided values; nested dicts merge recursively.
    merged = copy.deepcopy(payload)
    for key, value in defaults.items():
        if key not in merged:
            merged[key] = copy.deepcopy(value)
            continue
        if isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = deep_merge_defaults(merged[key], value)
    return merged


def deep_merge_override(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    # Override values win; nested dicts merge recursively.
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = deep_merge_override(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def apply_system_prompt(messages: list[dict[str, Any]], prompt: str, position: str) -> list[dict[str, Any]]:
    if any(message.get("role") == "system" and message.get("content") == prompt for message in messages):
        return messages

    injected = {"role": "system", "content": prompt}
    if position == "append":
        return [*messages, injected]
    if position == "replace":
        updated = list(messages)
        for index, message in enumerate(updated):
            if message.get("role") == "system":
                updated[index] = injected
                return updated
        return [injected, *updated]
    return [injected, *messages]


def apply_request_policy(body: dict[str, Any], policy: dict[str, Any] | None) -> dict[str, Any]:
    if not policy:
        return dict(body)

    shaped = deep_merge_defaults(body, policy.get("body_defaults", {}) or {})
    system_prompt = policy.get("system_prompt")
    if not system_prompt:
        return shaped

    messages = shaped.get("messages")
    if not isinstance(messages, list):
        return shaped

    normalized_messages = [message for message in messages if isinstance(message, dict)]
    shaped["messages"] = apply_system_prompt(
        normalized_messages,
        system_prompt,
        str(policy.get("system_prompt_position") or "prepend"),
    )
    return shaped


def merge_request_policies(*policies: dict[str, Any] | None) -> dict[str, Any] | None:
    combined: dict[str, Any] = {}
    for policy in policies:
        if not policy:
            continue
        body_defaults = policy.get("body_defaults") or {}
        if body_defaults:
            combined["body_defaults"] = deep_merge_override(combined.get("body_defaults", {}), body_defaults)
        if policy.get("system_prompt"):
            combined["system_prompt"] = policy["system_prompt"]
            combined["system_prompt_position"] = policy.get("system_prompt_position", "prepend")
    return combined or None


def select_target_asset(service: dict[str, Any], requested_model: str) -> dict[str, Any] | None:
    assets = service.get("assets", [])
    for asset in assets:
        if asset.get("asset_id") == requested_model:
            return asset
    for asset in assets:
        if asset.get("loaded") and asset.get("asset_id"):
            return asset
    for asset in assets:
        if asset.get("asset_id"):
            return asset
    return None


def looks_like_qwen3(value: str) -> bool:
    normalized = re.sub(r"[^a-z0-9]+", "", value.lower())
    return normalized.startswith("qwen3")


def infer_chat_request_policy(requested_model: str, service: dict[str, Any], asset: dict[str, Any] | None) -> dict[str, Any] | None:
    identifiers = [requested_model, service.get("service_id", "")]
    if asset is not None:
        identifiers.append(asset.get("asset_id", ""))
    identifiers.extend(asset_item.get("asset_id", "") for asset_item in service.get("assets", []))
    if any(looks_like_qwen3(identifier) for identifier in identifiers if identifier):
        return {
            "body_defaults": {
                "chat_template_kwargs": {
                    "enable_thinking": False,
                }
            }
        }
    return None


def effective_chat_request_policy(
    *,
    requested_model: str,
    service: dict[str, Any],
    role: dict[str, Any] | None = None,
    asset: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
    return merge_request_policies(
        infer_chat_request_policy(requested_model, service, asset),
        (asset or {}).get("request_policy"),
        ((role or {}).get("prompt_policy") or {}).get("request_policy"),
    )
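Precedence in `effective_chat_request_policy` runs from inferred family defaults (weakest) through the asset policy to the role policy (strongest), because each later policy deep-overrides the combined result. A small sketch; the service and role dicts are trimmed to the fields the functions actually read:

```python
from geniehive_control.request_policy import effective_chat_request_policy

service = {
    "service_id": "p40-box/chat/gpu1-secondary",
    "assets": [{"asset_id": "Qwen3.5-9B-Q5_K_M", "loaded": True, "request_policy": {}}],
}
role = {"prompt_policy": {"request_policy": {"body_defaults": {"temperature": 0.2}}}}

policy = effective_chat_request_policy(
    requested_model="general_assistant",
    service=service,
    role=role,
    asset=service["assets"][0],
)
# The Qwen3 family inference contributes enable_thinking=False;
# the role policy layers temperature=0.2 on top.
assert policy == {"body_defaults": {"chat_template_kwargs": {"enable_thinking": False}, "temperature": 0.2}}
```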
Benchmark runner tests (new file, +104 lines)

from geniehive_control.benchmark_runner import ChatBenchmarkCase, ChatBenchmarkWorkload, built_in_chat_workloads, run_chat_benchmark


def test_built_in_chat_workloads_exist() -> None:
    workloads = built_in_chat_workloads()

    assert "chat.short_reasoning" in workloads
    assert workloads["chat.short_reasoning"].cases
    assert workloads["chat.short_reasoning"].chat_template_kwargs == {"enable_thinking": False}


def test_run_chat_benchmark_generates_report() -> None:
    workload = ChatBenchmarkWorkload(
        workload="chat.short_reasoning",
        system_prompt="You are concise.",
        cases=[
            ChatBenchmarkCase(name="case1", prompt="Explain route selection briefly."),
            ChatBenchmarkCase(name="case2", prompt="Explain latency versus throughput briefly."),
        ],
    )

    def fake_request(url: str, headers: dict[str, str], payload: dict) -> dict:
        assert url.endswith("/v1/chat/completions")
        assert payload["model"] == "general_assistant"
        assert "chat_template_kwargs" not in payload
        return {
            "choices": [{"message": {"role": "assistant", "content": "Benchmark response."}}],
            "usage": {"prompt_tokens": 20, "completion_tokens": 10},
            "timings": {"prompt_ms": 120.0, "predicted_per_second": 25.0},
        }

    report = run_chat_benchmark(
        base_url="http://127.0.0.1:8800",
        api_key="change-me-client-key",
        model="general_assistant",
        workload=workload,
        request_fn=fake_request,
        observed_at=1775584000.0,
    )

    sample = report.samples[0]
    assert report.source == "geniehive-benchmark-runner"
    assert sample.workload == "chat.short_reasoning"
    assert sample.results["case_count"] == 2
    assert sample.results["pass_rate"] == 1.0
    assert sample.results["response_rate"] == 1.0
    assert sample.results["empty_visible_response_rate"] == 0.0
    assert sample.results["tokens_per_sec"] == 25.0
    assert sample.observed_at == 1775584000.0


def test_run_chat_benchmark_treats_reasoning_content_as_a_pass() -> None:
    workload = ChatBenchmarkWorkload(
        workload="chat.short_reasoning",
        system_prompt="You are concise.",
        cases=[ChatBenchmarkCase(name="case1", prompt="Explain route selection briefly.")],
    )

    def fake_request(url: str, headers: dict[str, str], payload: dict) -> dict:
        return {
            "choices": [{"message": {"role": "assistant", "content": "", "reasoning_content": "Reasoning only."}}],
            "usage": {"prompt_tokens": 20, "completion_tokens": 10},
            "timings": {"prompt_ms": 120.0, "predicted_per_second": 25.0},
        }

    report = run_chat_benchmark(
        base_url="http://127.0.0.1:8800",
        api_key="change-me-client-key",
        model="general_assistant",
        workload=workload,
        request_fn=fake_request,
        observed_at=1775584000.0,
    )

    assert report.samples[0].results["pass_rate"] == 1.0
    assert report.samples[0].results["empty_visible_response_rate"] == 0.0


def test_run_chat_benchmark_includes_chat_template_kwargs_when_configured() -> None:
    workload = ChatBenchmarkWorkload(
        workload="chat.short_reasoning",
        system_prompt="You are concise.",
        chat_template_kwargs={"enable_thinking": False},
        cases=[ChatBenchmarkCase(name="case1", prompt="Explain route selection briefly.")],
    )

    def fake_request(url: str, headers: dict[str, str], payload: dict) -> dict:
        assert payload["chat_template_kwargs"] == {"enable_thinking": False}
        return {
            "choices": [{"message": {"role": "assistant", "content": "Visible response."}}],
            "usage": {"prompt_tokens": 20, "completion_tokens": 10},
            "timings": {"prompt_ms": 120.0, "predicted_per_second": 25.0},
        }

    report = run_chat_benchmark(
        base_url="http://127.0.0.1:8800",
        api_key="change-me-client-key",
        model="general_assistant",
        workload=workload,
        request_fn=fake_request,
        observed_at=1775584000.0,
    )

    assert report.samples[0].results["pass_rate"] == 1.0
|
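The tests above drive `run_chat_benchmark` through an injected `fake_request`. A production callable with the same `(url, headers, payload) -> dict` signature could be a thin JSON POST; a hedged sketch using `urllib` follows (the real `scripts/run_benchmark_workload.py` may use a different HTTP client):

```python
import json
import urllib.request


def http_request_fn(url: str, headers: dict[str, str], payload: dict) -> dict:
    # Illustrative request_fn matching the signature the tests inject:
    # POST the payload as JSON and decode the JSON response body.
    body = json.dumps(payload).encode("utf-8")
    request = urllib.request.Request(url, data=body, method="POST")
    request.add_header("Content-Type", "application/json")
    for name, value in headers.items():
        request.add_header(name, value)
    with urllib.request.urlopen(request, timeout=120) as response:
        return json.loads(response.read().decode("utf-8"))
```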
@@ -0,0 +1,66 @@
from pathlib import Path

from geniehive_control.benchmarks import load_benchmark_report, report_to_samples


def test_load_benchmark_report_and_generate_sample_ids(tmp_path: Path) -> None:
    report_path = tmp_path / "bench.json"
    report_path.write_text(
        """
{
  "report_id": "p40-short-reasoning",
  "observed_at": 1775583000.0,
  "source": "local-smoke",
  "samples": [
    {
      "service_id": "p40-box/chat/gpu1-secondary",
      "asset_id": "Qwen3.5-9B-Q5_K_M",
      "workload": "chat.short_reasoning",
      "results": {
        "ttft_ms": 900,
        "tokens_per_sec": 30,
        "quality_score": 0.9
      }
    }
  ]
}
""".strip(),
        encoding="utf-8",
    )

    report = load_benchmark_report(report_path)
    samples = report_to_samples(report)

    assert report.report_id == "p40-short-reasoning"
    assert len(samples) == 1
    assert samples[0].benchmark_id.startswith("bench_")
    assert samples[0].observed_at == 1775583000.0
    assert samples[0].results["source"] == "local-smoke"


def test_report_to_samples_preserves_explicit_sample_ids(tmp_path: Path) -> None:
    report_path = tmp_path / "bench.json"
    report_path.write_text(
        """
{
  "observed_at": 1775583000.0,
  "samples": [
    {
      "benchmark_id": "bench-explicit",
      "service_id": "p40-box/chat/gpu0-primary",
      "workload": "chat.concise_support",
      "results": {
        "tokens_per_sec": 24
      }
    }
  ]
}
""".strip(),
        encoding="utf-8",
    )

    report = load_benchmark_report(report_path)
    samples = report_to_samples(report)

    assert samples[0].benchmark_id == "bench-explicit"
    assert samples[0].workload == "chat.concise_support"
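These tests establish that a sample without an explicit `benchmark_id` gets a generated one with a `bench_` prefix. One plausible derivation would hash identifying fields; this is speculative, since the actual scheme inside `geniehive_control.benchmarks` is not shown in this diff:

```python
import hashlib


def derive_benchmark_id(service_id: str, workload: str, observed_at: float) -> str:
    # Hypothetical id derivation consistent with the bench_ prefix the test
    # asserts; the real implementation may hash different fields entirely.
    digest = hashlib.sha256(f"{service_id}|{workload}|{observed_at}".encode("utf-8")).hexdigest()
    return f"bench_{digest[:12]}"
```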
@@ -158,6 +158,108 @@ def test_proxy_chat_completion_strips_reasoning_fields(tmp_path: Path) -> None:
    assert "reasoning" not in choice


def test_proxy_chat_completion_applies_inferred_qwen_request_defaults(tmp_path: Path) -> None:
    registry = _build_registry(tmp_path)

    class _InspectingPoster:
        async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse:
            assert json["chat_template_kwargs"] == {"enable_thinking": False}
            return _FakeResponse({"ok": True, "echo_model": json["model"]})

    upstream = UpstreamClient(client=_InspectingPoster())

    async def run() -> dict:
        return await proxy_chat_completion(
            {
                "model": "mentor",
                "messages": [{"role": "user", "content": "hello"}],
            },
            registry=registry,
            upstream=upstream,
        )

    result = asyncio.run(run())
    assert result["echo_model"] == "qwen3-8b-q4km"


def test_proxy_chat_completion_preserves_explicit_template_kwargs(tmp_path: Path) -> None:
    registry = _build_registry(tmp_path)

    class _InspectingPoster:
        async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse:
            assert json["chat_template_kwargs"] == {"enable_thinking": True, "foo": "bar"}
            return _FakeResponse({"ok": True, "echo_model": json["model"]})

    upstream = UpstreamClient(client=_InspectingPoster())

    async def run() -> dict:
        return await proxy_chat_completion(
            {
                "model": "mentor",
                "messages": [{"role": "user", "content": "hello"}],
                "chat_template_kwargs": {"enable_thinking": True, "foo": "bar"},
            },
            registry=registry,
            upstream=upstream,
        )

    result = asyncio.run(run())
    assert result["echo_model"] == "qwen3-8b-q4km"


def test_proxy_chat_completion_applies_asset_request_policy(tmp_path: Path) -> None:
    registry = Registry(tmp_path / "geniehive.sqlite3")
    registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="192.168.1.101",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/custom-model",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18091",
                    assets=[
                        {
                            "asset_id": "custom-model-v1",
                            "loaded": True,
                            "request_policy": {
                                "body_defaults": {
                                    "temperature": 0.2,
                                    "chat_template_kwargs": {"custom_flag": "yes"},
                                }
                            },
                        }
                    ],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 900},
                )
            ],
        )
    )

    class _InspectingPoster:
        async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse:
            assert json["temperature"] == 0.2
            assert json["chat_template_kwargs"] == {"custom_flag": "yes"}
            return _FakeResponse({"ok": True, "echo_model": json["model"]})

    upstream = UpstreamClient(client=_InspectingPoster())

    async def run() -> dict:
        return await proxy_chat_completion(
            {
                "model": "custom-model-v1",
                "messages": [{"role": "user", "content": "hello"}],
            },
            registry=registry,
            upstream=upstream,
        )

    result = asyncio.run(run())
    assert result["echo_model"] == "custom-model-v1"


def test_proxy_chat_completion_fails_for_unknown_model(tmp_path: Path) -> None:
    registry = _build_registry(tmp_path)
    upstream = UpstreamClient(client=_FakePoster())
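Taken together, these proxy tests pin down the shaping contract: a default is applied only when the caller did not set the field, and an explicit `chat_template_kwargs` passes through to the upstream untouched rather than being merged with the inferred value. A hedged sketch of that behavior (not the proxy's actual code):

```python
from typing import Any


def apply_body_defaults(payload: dict[str, Any], body_defaults: dict[str, Any]) -> dict[str, Any]:
    # Fill in a default only for fields the caller omitted entirely;
    # caller-supplied values, including nested dicts, win wholesale.
    shaped = dict(payload)
    for key, default in body_defaults.items():
        shaped.setdefault(key, default)
    return shaped
```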
@@ -1,7 +1,7 @@
from pathlib import Path

from geniehive_control.main import create_app
from geniehive_control.models import HostHeartbeat, HostRegistration, RegisteredService, RoleProfile
from geniehive_control.models import BenchmarkSample, HostHeartbeat, HostRegistration, RegisteredService, RoleProfile, RouteMatchRequest
from geniehive_control.registry import Registry
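For reference, the fields exercised by the new tests suggest `BenchmarkSample` carries at least the following shape. This is a hypothetical dataclass sketch inferred from usage, not the actual `geniehive_control.models` definition (which may be a pydantic model with more fields):

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class BenchmarkSampleSketch:
    # Required identity fields seen in every construction below.
    benchmark_id: str
    service_id: str
    workload: str
    observed_at: float
    # asset_id is omitted in some ingested reports, so it is optional here.
    asset_id: str | None = None
    # Free-form metrics such as tokens_per_sec, ttft_ms, quality_score.
    results: dict[str, Any] = field(default_factory=dict)
```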
@@ -133,9 +133,11 @@ def test_registry_persists_roles_and_resolves_direct_and_role_routes(tmp_path: P
    mentor = next(item for item in models if item["id"] == "mentor")
    assert mentor["geniehive"]["route_type"] == "role"
    assert mentor["geniehive"]["offload_hint"]["suitability"] == "good_for_low_complexity"
    assert mentor["geniehive"]["effective_request_policy"]["body_defaults"]["chat_template_kwargs"]["enable_thinking"] is False

    asset = next(item for item in models if item["id"] == "qwen3-8b-q4km")
    assert asset["geniehive"]["route_type"] == "asset"
    assert asset["geniehive"]["offload_hint"]["recommended_for"] == "lower-complexity offload"
    assert asset["geniehive"]["effective_request_policy"]["body_defaults"]["chat_template_kwargs"]["enable_thinking"] is False


def test_control_app_exposes_expected_routes() -> None:
@@ -147,6 +149,220 @@ def test_control_app_exposes_expected_routes() -> None:
    assert "/v1/nodes/heartbeat" in paths
    assert "/v1/cluster/hosts" in paths
    assert "/v1/cluster/services" in paths
    assert "/v1/cluster/benchmarks" in paths
    assert "/v1/cluster/roles" in paths
    assert "/v1/cluster/health" in paths
    assert "/v1/cluster/routes/resolve" in paths
    assert "/v1/cluster/routes/match" in paths
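A client could exercise the new `/v1/cluster/routes/match` route over HTTP with a payload mirroring `RouteMatchRequest` as constructed in the tests below. In this sketch, POST as the method and the `Authorization` header scheme are assumptions not confirmed by the diff:

```python
import json
import urllib.request

# Hypothetical call against the route-match endpoint asserted above.
payload = {
    "task": "fast technical reasoning for an interactive assistant",
    "kind": "chat",
    "modality": "text",
    "limit": 4,
}
request = urllib.request.Request(
    "http://127.0.0.1:8800/v1/cluster/routes/match",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json", "Authorization": "Bearer change-me-client-key"},
    method="POST",
)
with urllib.request.urlopen(request, timeout=30) as response:
    result = json.loads(response.read().decode("utf-8"))
    # Each candidate carries candidate_type, candidate_id, service, and signals.
    print(result["candidates"][0]["candidate_id"])
```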
def test_registry_can_rank_routes_for_task_statements(tmp_path: Path) -> None:
    db_path = tmp_path / "geniehive.sqlite3"
    registry = Registry(db_path)

    registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="192.168.1.101",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/qwen-reasoner",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18091",
                    assets=[{"asset_id": "qwen3.5-9b", "loaded": True}],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 1100, "tokens_per_sec": 28},
                ),
                RegisteredService(
                    service_id="atlas-01/chat/rocket-background",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18093",
                    assets=[{"asset_id": "rocket-3b", "loaded": True}],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 4200, "tokens_per_sec": 7},
                ),
            ],
        )
    )
    registry.upsert_roles(
        [
            RoleProfile(
                role_id="general_assistant",
                display_name="General Assistant",
                description="Fast general technical assistant for reasoning and question answering.",
                operation="chat",
                modality="text",
                routing_policy={"preferred_families": ["qwen3.5", "qwen"]},
            ),
            RoleProfile(
                role_id="background_summarizer",
                display_name="Background Summarizer",
                description="Slow fallback summarizer for lower-priority background work.",
                operation="chat",
                modality="text",
                routing_policy={"preferred_families": ["rocket"]},
            ),
        ]
    )

    result = registry.match_routes(
        RouteMatchRequest(
            task="fast technical reasoning for an interactive assistant",
            kind="chat",
            modality="text",
            limit=4,
        )
    )

    assert result["task_count"] == 1
    assert result["candidates"]
    top = result["candidates"][0]
    assert top["candidate_type"] == "role"
    assert top["candidate_id"] == "general_assistant"
    assert top["service"]["service_id"] == "atlas-01/chat/qwen-reasoner"
    assert top["signals"]["preferred_family_match"] == 1.0
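The `preferred_family_match` signal asserted above could plausibly come from normalizing asset ids against a role's `preferred_families`. A speculative sketch, since the registry's actual scoring code is not part of this hunk:

```python
import re


def preferred_family_match(asset_ids: list[str], preferred_families: list[str]) -> float:
    # Hypothetical signal: 1.0 when any asset id starts with a preferred
    # family after stripping separators, as with "qwen3.5-9b" vs "qwen3.5".
    def normalize(value: str) -> str:
        return re.sub(r"[^a-z0-9.]+", "", value.lower())

    families = [normalize(family) for family in preferred_families]
    for asset_id in asset_ids:
        normalized = normalize(asset_id)
        if any(normalized.startswith(family) for family in families):
            return 1.0
    return 0.0
```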
def test_registry_match_can_include_direct_services(tmp_path: Path) -> None:
    db_path = tmp_path / "geniehive.sqlite3"
    registry = Registry(db_path)
    registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="192.168.1.101",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/qwen3-8b",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18091",
                    assets=[{"asset_id": "qwen3-8b-q4km", "loaded": True}],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 900, "tokens_per_sec": 40},
                )
            ],
        )
    )

    result = registry.match_routes(
        RouteMatchRequest(
            task="qwen model for quick chat",
            kind="chat",
            include_direct_services=True,
            limit=4,
        )
    )

    direct = next(candidate for candidate in result["candidates"] if candidate["candidate_type"] == "service")
    assert direct["candidate_id"] == "atlas-01/chat/qwen3-8b"
    assert direct["service"]["assets"][0]["asset_id"] == "qwen3-8b-q4km"


def test_registry_persists_benchmark_samples_and_uses_them_for_matching(tmp_path: Path) -> None:
    db_path = tmp_path / "geniehive.sqlite3"
    registry = Registry(db_path)
    registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="192.168.1.101",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/qwen-fast",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18091",
                    assets=[{"asset_id": "qwen3.5-9b", "loaded": True}],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 1500, "tokens_per_sec": 22},
                ),
                RegisteredService(
                    service_id="atlas-01/chat/rocket-slow",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18093",
                    assets=[{"asset_id": "rocket-3b", "loaded": True}],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 1200, "tokens_per_sec": 10},
                ),
            ],
        )
    )
    registry.upsert_benchmark_samples(
        [
            BenchmarkSample(
                benchmark_id="bench-qwen-1",
                service_id="atlas-01/chat/qwen-fast",
                asset_id="qwen3.5-9b",
                workload="chat.short_reasoning",
                observed_at=1000.0,
                results={"tokens_per_sec": 30, "ttft_ms": 900, "quality_score": 0.9},
            ),
            BenchmarkSample(
                benchmark_id="bench-rocket-1",
                service_id="atlas-01/chat/rocket-slow",
                asset_id="rocket-3b",
                workload="chat.short_reasoning",
                observed_at=1000.0,
                results={"tokens_per_sec": 9, "ttft_ms": 1900, "quality_score": 0.4},
            ),
        ]
    )

    samples = registry.list_benchmark_samples(service_id="atlas-01/chat/qwen-fast")
    assert len(samples) == 1
    assert samples[0]["benchmark_id"] == "bench-qwen-1"

    result = registry.match_routes(
        RouteMatchRequest(
            task="fast short reasoning for chat responses",
            workloads=["chat.short_reasoning"],
            kind="chat",
            include_direct_services=True,
            limit=4,
        )
    )

    top_service = next(candidate for candidate in result["candidates"] if candidate["candidate_type"] == "service")
    assert top_service["candidate_id"] == "atlas-01/chat/qwen-fast"
    assert top_service["signals"]["benchmark_match_count"] == 1
    assert top_service["signals"]["best_workload_overlap"] == 1.0
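The `best_workload_overlap` signal reaches 1.0 here because the single requested workload exactly matches a stored benchmark sample. A minimal set-overlap computation consistent with that assertion (the registry's real formula may weight recency or quality as well):

```python
def best_workload_overlap(requested: list[str], sample_workloads: list[str]) -> float:
    # Fraction of requested workloads covered by a candidate's benchmark
    # samples; 1.0 when every requested workload has a matching sample.
    if not requested:
        return 0.0
    overlap = len(set(requested) & set(sample_workloads))
    return overlap / len(requested)
```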
def test_registry_exposes_asset_request_policy_in_model_metadata(tmp_path: Path) -> None:
    db_path = tmp_path / "geniehive.sqlite3"
    registry = Registry(db_path)
    registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="192.168.1.101",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/custom-model",
                    host_id="atlas-01",
                    kind="chat",
                    endpoint="http://192.168.1.101:18091",
                    assets=[
                        {
                            "asset_id": "custom-model-v1",
                            "loaded": True,
                            "request_policy": {
                                "body_defaults": {
                                    "temperature": 0.2,
                                    "chat_template_kwargs": {"custom_flag": "yes"},
                                }
                            },
                        }
                    ],
                    state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
                    observed={"p50_latency_ms": 900},
                )
            ],
        )
    )

    models = registry.list_client_models()
    asset = next(item for item in models if item["id"] == "custom-model-v1")
    assert asset["geniehive"]["effective_request_policy"]["body_defaults"]["temperature"] == 0.2
    assert asset["geniehive"]["effective_request_policy"]["body_defaults"]["chat_template_kwargs"]["custom_flag"] == "yes"
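Since the merged policy is attached to each model entry, a client can read the defaults it will receive before sending a request. A hedged helper sketch over the entry shape asserted above (`discover_body_defaults` is a hypothetical name, not part of the codebase):

```python
def discover_body_defaults(model_entry: dict) -> dict:
    # Read the merged policy the registry attaches to a model entry,
    # falling back to an empty dict when no policy applies.
    policy = model_entry.get("geniehive", {}).get("effective_request_policy") or {}
    return policy.get("body_defaults", {})
```

A client could fold these defaults into its own payload before calling the proxy, or simply rely on the proxy applying them server-side.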