diff --git a/README.md b/README.md index 5ed21dc..472263a 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,21 @@ make smoke make health ``` +Benchmark workflow: + +```bash +PYTHONPATH=src python scripts/run_benchmark_workload.py \ + --base-url http://127.0.0.1:8800 \ + --api-key change-me-client-key \ + --model general_assistant \ + --workload chat.short_reasoning \ + --output /tmp/geniehive-bench.json + +PYTHONPATH=src python scripts/ingest_benchmark_report.py /tmp/geniehive-bench.json \ + --base-url http://127.0.0.1:8800 \ + --api-key change-me-client-key +``` + Repository conventions: - local runtime state lives under `state/` and should not be committed diff --git a/docs/schemas.md b/docs/schemas.md index ab4ef61..5e1b201 100644 --- a/docs/schemas.md +++ b/docs/schemas.md @@ -45,6 +45,10 @@ service: assets: - asset_id: "qwen3-8b-q4km" loaded: true + request_policy: + body_defaults: + chat_template_kwargs: + enable_thinking: false state: health: "healthy" load_state: "loaded" @@ -84,6 +88,9 @@ role: prompt_policy: system_prompt: "You guide without doing the user's work for them." user_template: "{{ user_input }}" + request_policy: + body_defaults: + temperature: 0.2 routing_policy: preferred_families: ["Qwen3", "Mistral"] preferred_labels: ["instruction", "stable"] @@ -92,6 +99,35 @@ role: fallback_roles: ["general_assistant"] ``` +## Request Shape Policy + +This is a general representation for model- or route-specific request shaping. + +```yaml +request_shape_policy: + body_defaults: + chat_template_kwargs: + enable_thinking: false + temperature: 0.2 + system_prompt: "Return only visible final answer text." + system_prompt_position: "prepend" +``` + +Use it for: + +- model-specific request flags such as `chat_template_kwargs.enable_thinking` +- default OpenAI-compatible body fields that should be applied unless the caller already set them +- model-specific prompt instructions that should be prepended, appended, or replace an existing system message + +GenieHive currently supports this policy on: + +- `service.assets[].request_policy` +- `role.prompt_policy.request_policy` + +The control plane may also infer built-in request policies from model family metadata. For example, Qwen3/Qwen3.5 chat routes default to `chat_template_kwargs.enable_thinking: false` unless the caller explicitly sets a different value. + +`GET /v1/models` exposes the merged result as `geniehive.effective_request_policy` on service, asset, and role-backed model entries so clients can discover what GenieHive will apply by default. + ## Health Sample ```yaml @@ -126,3 +162,109 @@ benchmark_sample: ttft_ms: 780 tokens_per_sec: 44 ``` + +## Route Match Request + +```yaml +route_match_request: + task: "fast technical reasoning for an interactive assistant" + tasks: + - "interactive debugging help" + - "concise technical explanations" + workload: "chat.short_reasoning" + workloads: + - "chat.short_reasoning" + - "chat.concise_support" + kind: "chat" + modality: "text" + include_direct_services: true + limit: 5 +``` + +This request is meant to answer: + +- which role-backed route is the best current fit for this task or task suite +- which direct services also look suitable right now + +V1 matching is metadata- and runtime-driven. It uses: + +- role text and routing policy overlap +- service asset and runtime metadata overlap +- loaded state +- observed latency +- observed throughput +- current queue depth when available +- recent benchmark sample workload overlap and empirical quality/performance hints + +If benchmark samples exist for a candidate service, workload hints such as `chat.short_reasoning` can boost routes with recent empirical fit. + +## Route Match Candidate + +```yaml +route_match_candidate: + candidate_type: "role" + candidate_id: "general_assistant" + operation: "chat" + score: 0.86 + reasons: + - "task text overlaps role description or policy" + - "resolved service matches role preferred model family" + - "service already has a loaded asset" + - "low observed latency" + - "good observed throughput" + signals: + task_overlap: 0.33 + preferred_family_match: 1.0 + loaded: true + p50_latency_ms: 1100 + tokens_per_sec: 28 + queue_depth: 0 + benchmark_match_count: 2 + best_workload_overlap: 1.0 + benchmark_quality_score: 0.9 + role: + role_id: "general_assistant" + service: + service_id: "p40-box/chat/gpu1-secondary" +``` + +## Benchmark Ingest Request + +```yaml +benchmark_ingest_request: + samples: + - benchmark_id: "bench-qwen-1" + service_id: "p40-box/chat/gpu1-secondary" + asset_id: "Qwen3.5-9B-Q5_K_M" + workload: "chat.short_reasoning" + observed_at: 1775582000.0 + results: + ttft_ms: 900 + tokens_per_sec: 30 + quality_score: 0.9 +``` + +## Benchmark Report File + +This is a file-oriented format meant for repeatable benchmark runs before ingestion into GenieHive. + +```yaml +benchmark_report: + report_id: "p40-short-reasoning" + observed_at: 1775583000.0 + source: "local-smoke" + samples: + - service_id: "p40-box/chat/gpu1-secondary" + asset_id: "Qwen3.5-9B-Q5_K_M" + workload: "chat.short_reasoning" + results: + ttft_ms: 900 + tokens_per_sec: 30 + quality_score: 0.9 +``` + +Notes: + +- `observed_at` may be set once at the report level or per sample +- `benchmark_id` is optional in the file format; GenieHive tooling can generate a stable ID during conversion +- the helper script `scripts/ingest_benchmark_report.py` loads this format and posts the expanded samples to `POST /v1/cluster/benchmarks` diff --git a/scripts/ingest_benchmark_report.py b/scripts/ingest_benchmark_report.py new file mode 100644 index 0000000..27f1d93 --- /dev/null +++ b/scripts/ingest_benchmark_report.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import os + +import httpx + +from geniehive_control.benchmarks import load_benchmark_report, report_to_samples + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Load a benchmark report JSON file and ingest its samples into GenieHive.") + parser.add_argument("input", help="Path to benchmark report JSON") + parser.add_argument( + "--base-url", + default=os.environ.get("GENIEHIVE_CONTROL_BASE_URL", "http://127.0.0.1:8800"), + help="GenieHive control base URL", + ) + parser.add_argument( + "--api-key", + default=os.environ.get("GENIEHIVE_CLIENT_API_KEY", "change-me-client-key"), + help="GenieHive client API key", + ) + return parser + + +def main() -> int: + args = build_parser().parse_args() + report = load_benchmark_report(args.input) + samples = report_to_samples(report) + payload = {"samples": [sample.model_dump() for sample in samples]} + response = httpx.post( + args.base_url.rstrip("/") + "/v1/cluster/benchmarks", + json=payload, + headers={"X-Api-Key": args.api_key}, + timeout=30.0, + ) + response.raise_for_status() + print(json.dumps(response.json(), indent=2, sort_keys=True)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/run_benchmark_workload.py b/scripts/run_benchmark_workload.py new file mode 100644 index 0000000..c1ae3fc --- /dev/null +++ b/scripts/run_benchmark_workload.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from geniehive_control.benchmark_runner import built_in_chat_workloads, run_chat_benchmark + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run a built-in chat benchmark workload against a GenieHive route or model.") + parser.add_argument("--base-url", default="http://127.0.0.1:8800", help="GenieHive control base URL") + parser.add_argument("--api-key", default="change-me-client-key", help="GenieHive client API key") + parser.add_argument("--model", required=True, help="Role, service, or direct asset/model id to benchmark") + parser.add_argument( + "--workload", + required=True, + choices=sorted(built_in_chat_workloads().keys()), + help="Built-in benchmark workload to run", + ) + parser.add_argument("--output", help="Optional path to write the benchmark report JSON") + return parser + + +def main() -> int: + args = build_parser().parse_args() + workload = built_in_chat_workloads()[args.workload] + report = run_chat_benchmark( + base_url=args.base_url, + api_key=args.api_key, + model=args.model, + workload=workload, + ) + rendered = json.dumps(report.model_dump(), indent=2, sort_keys=True) + if args.output: + Path(args.output).write_text(rendered + "\n", encoding="utf-8") + else: + print(rendered) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/geniehive_control/benchmark_runner.py b/src/geniehive_control/benchmark_runner.py new file mode 100644 index 0000000..907086d --- /dev/null +++ b/src/geniehive_control/benchmark_runner.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Any, Callable + +import httpx + +from .benchmarks import BenchmarkReport, BenchmarkReportSample + + +@dataclass(slots=True) +class ChatBenchmarkCase: + name: str + prompt: str + max_completion_tokens: int = 120 + + +@dataclass(slots=True) +class ChatBenchmarkWorkload: + workload: str + system_prompt: str + cases: list[ChatBenchmarkCase] + chat_template_kwargs: dict[str, Any] | None = None + + +def built_in_chat_workloads() -> dict[str, ChatBenchmarkWorkload]: + return { + "chat.short_reasoning": ChatBenchmarkWorkload( + workload="chat.short_reasoning", + system_prompt=( + "You are a concise and careful reasoning assistant. " + "Return only a visible final answer. " + "Do not spend tokens on hidden reasoning or an internal monologue." + ), + chat_template_kwargs={"enable_thinking": False}, + cases=[ + ChatBenchmarkCase( + name="short_reasoning_1", + prompt="In two short paragraphs, explain why a loaded healthy route should be preferred over a cold route. Return only the final answer text.", + ), + ChatBenchmarkCase( + name="short_reasoning_2", + prompt="Give a compact tradeoff summary for a service with lower latency but worse throughput than another. Return only the final answer text.", + ), + ], + ), + "chat.concise_support": ChatBenchmarkWorkload( + workload="chat.concise_support", + system_prompt="You are a concise support assistant. Return only a visible final answer.", + cases=[ + ChatBenchmarkCase( + name="concise_support_1", + prompt="Reply with a short troubleshooting checklist for a local API endpoint returning 404. Return only the final answer text.", + ), + ChatBenchmarkCase( + name="concise_support_2", + prompt="Reply with a short checklist for checking a tmux-managed service that exited at startup. Return only the final answer text.", + ), + ], + ), + } + + +def run_chat_benchmark( + *, + base_url: str, + api_key: str, + model: str, + workload: ChatBenchmarkWorkload, + request_fn: Callable[[str, dict[str, str], dict[str, Any]], dict[str, Any]] | None = None, + observed_at: float | None = None, +) -> BenchmarkReport: + request = request_fn or _default_chat_request + latencies_ms: list[float] = [] + ttfts_ms: list[float] = [] + tokens_per_sec_values: list[float] = [] + completion_tokens: list[int] = [] + prompt_tokens: list[int] = [] + passed = 0 + responses_received = 0 + empty_visible_responses = 0 + + for case in workload.cases: + start = time.perf_counter() + request_payload = { + "model": model, + "messages": [ + {"role": "system", "content": workload.system_prompt}, + {"role": "user", "content": case.prompt}, + ], + "max_tokens": case.max_completion_tokens, + } + if workload.chat_template_kwargs: + request_payload["chat_template_kwargs"] = dict(workload.chat_template_kwargs) + + payload = request( + base_url.rstrip("/") + "/v1/chat/completions", + { + "X-Api-Key": api_key, + "Content-Type": "application/json", + }, + request_payload, + ) + elapsed_ms = (time.perf_counter() - start) * 1000.0 + latencies_ms.append(elapsed_ms) + + usage = payload.get("usage", {}) + timings = payload.get("timings", {}) + prompt_tokens.append(int(usage.get("prompt_tokens", 0) or 0)) + completion_tokens.append(int(usage.get("completion_tokens", 0) or 0)) + responses_received += 1 + if isinstance(timings.get("prompt_ms"), (int, float)): + ttfts_ms.append(float(timings["prompt_ms"])) + if isinstance(timings.get("predicted_per_second"), (int, float)): + tokens_per_sec_values.append(float(timings["predicted_per_second"])) + elif isinstance(timings.get("predicted_ms"), (int, float)) and completion_tokens[-1] > 0 and float(timings["predicted_ms"]) > 0: + tokens_per_sec_values.append((completion_tokens[-1] * 1000.0) / float(timings["predicted_ms"])) + if _has_nonempty_content(payload): + passed += 1 + else: + empty_visible_responses += 1 + + sample = BenchmarkReportSample( + service_id=model, + workload=workload.workload, + observed_at=observed_at or time.time(), + results={ + "case_count": len(workload.cases), + "pass_rate": passed / max(1, len(workload.cases)), + "response_rate": responses_received / max(1, len(workload.cases)), + "empty_visible_response_rate": empty_visible_responses / max(1, len(workload.cases)), + "p50_latency_ms": _median(latencies_ms), + "ttft_ms": _median(ttfts_ms) if ttfts_ms else None, + "tokens_per_sec": _median(tokens_per_sec_values) if tokens_per_sec_values else None, + "prompt_tokens": int(sum(prompt_tokens) / max(1, len(prompt_tokens))), + "completion_tokens": int(sum(completion_tokens) / max(1, len(completion_tokens))), + }, + ) + return BenchmarkReport( + report_id=f"{model}-{workload.workload}", + observed_at=sample.observed_at, + source="geniehive-benchmark-runner", + samples=[sample], + ) + + +def _default_chat_request(url: str, headers: dict[str, str], payload: dict[str, Any]) -> dict[str, Any]: + with httpx.Client(timeout=120.0) as client: + response = client.post(url, headers=headers, json=payload) + response.raise_for_status() + return response.json() + + +def _median(values: list[float]) -> float | None: + if not values: + return None + ordered = sorted(values) + middle = len(ordered) // 2 + if len(ordered) % 2: + return ordered[middle] + return (ordered[middle - 1] + ordered[middle]) / 2.0 + + +def _has_nonempty_content(payload: dict[str, Any]) -> bool: + choices = payload.get("choices", []) + if not choices: + return False + message = choices[0].get("message", {}) + if str(message.get("content", "")).strip(): + return True + return bool(str(message.get("reasoning_content", "")).strip()) diff --git a/src/geniehive_control/benchmarks.py b/src/geniehive_control/benchmarks.py new file mode 100644 index 0000000..2cbddf8 --- /dev/null +++ b/src/geniehive_control/benchmarks.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import hashlib +import json +from pathlib import Path + +from pydantic import BaseModel, Field + +from .models import BenchmarkSample + + +class BenchmarkReportSample(BaseModel): + service_id: str + asset_id: str | None = None + workload: str + observed_at: float | None = None + benchmark_id: str | None = None + results: dict[str, object] = Field(default_factory=dict) + + +class BenchmarkReport(BaseModel): + report_id: str | None = None + observed_at: float | None = None + source: str | None = None + samples: list[BenchmarkReportSample] = Field(default_factory=list) + + +def load_benchmark_report(path: str | Path) -> BenchmarkReport: + raw = json.loads(Path(path).read_text(encoding="utf-8")) + return BenchmarkReport.model_validate(raw) + + +def report_to_samples(report: BenchmarkReport) -> list[BenchmarkSample]: + samples: list[BenchmarkSample] = [] + for index, sample in enumerate(report.samples): + observed_at = sample.observed_at if sample.observed_at is not None else report.observed_at + if observed_at is None: + raise ValueError("Benchmark report sample is missing observed_at and report has no default observed_at.") + benchmark_id = sample.benchmark_id or _make_benchmark_id(report.report_id, sample, index, observed_at) + payload = dict(sample.results) + if report.source and "source" not in payload: + payload["source"] = report.source + samples.append( + BenchmarkSample( + benchmark_id=benchmark_id, + service_id=sample.service_id, + asset_id=sample.asset_id, + workload=sample.workload, + observed_at=observed_at, + results=payload, + ) + ) + return samples + + +def _make_benchmark_id(report_id: str | None, sample: BenchmarkReportSample, index: int, observed_at: float) -> str: + digest = hashlib.sha1( + "|".join( + [ + report_id or "report", + sample.service_id, + sample.asset_id or "", + sample.workload, + str(observed_at), + str(index), + ] + ).encode("utf-8") + ).hexdigest()[:16] + return f"bench_{digest}" diff --git a/src/geniehive_control/chat.py b/src/geniehive_control/chat.py index f40b6de..86ebffe 100644 --- a/src/geniehive_control/chat.py +++ b/src/geniehive_control/chat.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Any +from .request_policy import apply_request_policy, effective_chat_request_policy, select_target_asset from .registry import Registry from .routing import choose_upstream_model_id from .upstream import UpstreamClient @@ -26,7 +27,6 @@ def _strip_reasoning_fields(payload: Any) -> Any: cleaned[key] = _strip_reasoning_fields(value) return cleaned - async def proxy_chat_completion( body: dict[str, Any], *, @@ -45,7 +45,16 @@ async def proxy_chat_completion( if service is None: raise ProxyError(f"No healthy chat target available for '{requested_model}'.", status_code=503) - upstream_body = dict(body) + asset = select_target_asset(service, requested_model) + role = resolved.get("role") + combined_policy = effective_chat_request_policy( + requested_model=requested_model, + service=service, + role=role, + asset=asset, + ) + + upstream_body = apply_request_policy(dict(body), combined_policy) upstream_body["model"] = choose_upstream_model_id(requested_model, service) response = await upstream.chat_completions(service["endpoint"], upstream_body) return _strip_reasoning_fields(response) diff --git a/src/geniehive_control/main.py b/src/geniehive_control/main.py index 49060ab..4a31192 100644 --- a/src/geniehive_control/main.py +++ b/src/geniehive_control/main.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from .auth import require_client_auth, require_node_auth from .chat import ProxyError, proxy_chat_completion, proxy_embeddings from .config import ControlConfig, load_config -from .models import HostHeartbeat, HostRegistration +from .models import BenchmarkIngestRequest, HostHeartbeat, HostRegistration, RouteMatchRequest, RouteMatchResponse from .roles import load_role_catalog from .registry import Registry from .upstream import UpstreamClient, UpstreamError @@ -105,6 +105,23 @@ def create_app( async def list_services(request: Request, _=Depends(require_client_auth)) -> dict: return {"object": "list", "data": request.app.state.registry.list_services()} + @app.get("/v1/cluster/benchmarks") + async def list_benchmarks( + request: Request, + service_id: str | None = None, + workload: str | None = None, + _=Depends(require_client_auth), + ) -> dict: + return { + "object": "list", + "data": request.app.state.registry.list_benchmark_samples(service_id=service_id, workload=workload), + } + + @app.post("/v1/cluster/benchmarks") + async def ingest_benchmarks(payload: BenchmarkIngestRequest, request: Request, _=Depends(require_client_auth)) -> dict: + samples = request.app.state.registry.upsert_benchmark_samples(payload.samples) + return {"status": "ok", "count": len(payload.samples), "data": samples} + @app.get("/v1/cluster/roles") async def list_roles(request: Request, _=Depends(require_client_auth)) -> dict: return {"object": "list", "data": request.app.state.registry.list_roles()} @@ -121,6 +138,11 @@ def create_app( return JSONResponse(status_code=404, content={"error": "no_route", "model": model, "kind": kind}) return {"status": "ok", "resolution": resolved} + @app.post("/v1/cluster/routes/match") + async def match_routes(payload: RouteMatchRequest, request: Request, _=Depends(require_client_auth)) -> dict: + response = request.app.state.registry.match_routes(payload) + return RouteMatchResponse.model_validate(response).model_dump() + return app diff --git a/src/geniehive_control/models.py b/src/geniehive_control/models.py index 660e8b1..bd511a5 100644 --- a/src/geniehive_control/models.py +++ b/src/geniehive_control/models.py @@ -5,9 +5,16 @@ from typing import Any, Literal from pydantic import BaseModel, Field +class RequestShapePolicy(BaseModel): + body_defaults: dict[str, Any] = Field(default_factory=dict) + system_prompt: str | None = None + system_prompt_position: Literal["prepend", "append", "replace"] = "prepend" + + class ServiceAsset(BaseModel): asset_id: str loaded: bool = False + request_policy: RequestShapePolicy = Field(default_factory=RequestShapePolicy) class ServiceRuntime(BaseModel): @@ -66,6 +73,7 @@ class HostHeartbeat(BaseModel): class PromptPolicy(BaseModel): system_prompt: str | None = None user_template: str | None = None + request_policy: RequestShapePolicy = Field(default_factory=RequestShapePolicy) class RoutingPolicy(BaseModel): @@ -88,3 +96,48 @@ class RoleProfile(BaseModel): class RoleCatalog(BaseModel): roles: list[RoleProfile] = Field(default_factory=list) + + +class BenchmarkSample(BaseModel): + benchmark_id: str + service_id: str + asset_id: str | None = None + workload: str + observed_at: float + results: dict[str, Any] = Field(default_factory=dict) + + +class BenchmarkIngestRequest(BaseModel): + samples: list[BenchmarkSample] = Field(default_factory=list) + + +class RouteMatchRequest(BaseModel): + task: str | None = None + tasks: list[str] = Field(default_factory=list) + workload: str | None = None + workloads: list[str] = Field(default_factory=list) + kind: Literal["chat", "embeddings", "transcription"] | None = None + modality: str | None = None + include_direct_services: bool = True + limit: int = 10 + + +class RouteMatchCandidate(BaseModel): + candidate_type: Literal["role", "service"] + candidate_id: str + operation: Literal["chat", "embeddings", "transcription"] + score: float + reasons: list[str] = Field(default_factory=list) + signals: dict[str, Any] = Field(default_factory=dict) + role: dict[str, Any] | None = None + service: dict[str, Any] | None = None + + +class RouteMatchResponse(BaseModel): + status: str = "ok" + task_count: int + tasks: list[str] + workloads: list[str] = Field(default_factory=list) + kind: Literal["chat", "embeddings", "transcription"] | None = None + modality: str | None = None + candidates: list[RouteMatchCandidate] = Field(default_factory=list) diff --git a/src/geniehive_control/registry.py b/src/geniehive_control/registry.py index 4d17477..a6ca72f 100644 --- a/src/geniehive_control/registry.py +++ b/src/geniehive_control/registry.py @@ -1,11 +1,13 @@ from __future__ import annotations import json +import re import sqlite3 import time from pathlib import Path -from .models import HostHeartbeat, HostRegistration, RegisteredService, RoleProfile +from .models import BenchmarkSample, HostHeartbeat, HostRegistration, RegisteredService, RoleProfile, RouteMatchRequest +from .request_policy import effective_chat_request_policy, select_target_asset def _json_dumps(value: object) -> str: @@ -63,6 +65,15 @@ class Registry: routing_policy_json TEXT NOT NULL, updated_at REAL NOT NULL ); + + CREATE TABLE IF NOT EXISTS benchmark_samples ( + benchmark_id TEXT PRIMARY KEY, + service_id TEXT NOT NULL, + asset_id TEXT, + workload TEXT NOT NULL, + observed_at REAL NOT NULL, + results_json TEXT NOT NULL + ); """ ) @@ -217,6 +228,50 @@ class Registry: rows = conn.execute("SELECT * FROM services ORDER BY host_id, service_id").fetchall() return [self._service_row_to_dict(row) for row in rows] + def upsert_benchmark_samples(self, samples: list[BenchmarkSample]) -> list[dict]: + with self._connect() as conn: + for sample in samples: + conn.execute( + """ + INSERT INTO benchmark_samples ( + benchmark_id, service_id, asset_id, workload, observed_at, results_json + ) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(benchmark_id) DO UPDATE SET + service_id=excluded.service_id, + asset_id=excluded.asset_id, + workload=excluded.workload, + observed_at=excluded.observed_at, + results_json=excluded.results_json + """, + ( + sample.benchmark_id, + sample.service_id, + sample.asset_id, + sample.workload, + sample.observed_at, + _json_dumps(sample.results), + ), + ) + return self.list_benchmark_samples() + + def list_benchmark_samples(self, *, service_id: str | None = None, workload: str | None = None) -> list[dict]: + query = "SELECT * FROM benchmark_samples" + clauses = [] + params: list[object] = [] + if service_id: + clauses.append("service_id = ?") + params.append(service_id) + if workload: + clauses.append("workload = ?") + params.append(workload) + if clauses: + query += " WHERE " + " AND ".join(clauses) + query += " ORDER BY observed_at DESC, benchmark_id" + with self._connect() as conn: + rows = conn.execute(query, params).fetchall() + return [self._benchmark_row_to_dict(row) for row in rows] + def list_client_models(self) -> list[dict]: services = self.list_services() roles = self.list_roles() @@ -243,7 +298,7 @@ class Registry: "id": asset_id, "object": "model", "owned_by": service["host_id"], - "geniehive": self._service_metadata(service) | {"route_type": "asset", "asset_id": asset_id}, + "geniehive": self._service_metadata(service, requested_model=asset_id) | {"route_type": "asset", "asset_id": asset_id}, } ) @@ -282,6 +337,11 @@ class Registry: best_latency_ms=best_latency_ms, ), "routing_policy": role["routing_policy"], + "effective_request_policy": self._effective_request_policy( + requested_model=role["role_id"], + role=role, + service=self.resolve_route(role["role_id"], kind=role["operation"]).get("service") if matching_services else None, + ), }, } ) @@ -331,6 +391,54 @@ class Registry: service = max(candidates, key=score) return {"match_type": "role", "role": role, "service": service} + def match_routes(self, request: RouteMatchRequest) -> dict: + tasks = [task.strip() for task in ([request.task] if request.task else []) + request.tasks if task and task.strip()] + workloads = [value.strip() for value in ([request.workload] if request.workload else []) + request.workloads if value and value.strip()] + kind = request.kind + modality = request.modality + services = [ + service + for service in self.list_services() + if (kind is None or service["kind"] == kind) + and service["state"].get("accept_requests", True) + and service["state"].get("health") == "healthy" + ] + roles = [ + role + for role in self.list_roles() + if (kind is None or role["operation"] == kind) + and (modality is None or role["modality"] == modality) + ] + + candidates: list[dict] = [] + for role in roles: + resolved = self.resolve_route(role["role_id"], kind=role["operation"]) + service = resolved["service"] if resolved is not None else None + candidate = self._score_role_candidate(role, service, tasks, workloads) + candidates.append(candidate) + + if request.include_direct_services: + for service in services: + candidates.append(self._score_service_candidate(service, tasks, workloads)) + + candidates.sort( + key=lambda item: ( + -item["score"], + item["candidate_type"] != "role", + item["candidate_id"], + ) + ) + limit = max(1, request.limit) + return { + "status": "ok", + "task_count": len(tasks), + "tasks": tasks, + "workloads": workloads, + "kind": kind, + "modality": modality, + "candidates": candidates[:limit], + } + def _resolve_direct(self, requested_model: str, *, kind: str | None = None) -> dict | None: candidates = [] for service in self.list_services(): @@ -355,6 +463,163 @@ class Registry: service = max(candidates, key=score) return {"service": service} + def _score_role_candidate(self, role: dict, service: dict | None, tasks: list[str], workloads: list[str]) -> dict: + task_tokens = _tokenize_tasks(tasks) + text_parts = [ + role.get("role_id", ""), + role.get("display_name", "") or "", + role.get("description", "") or "", + (role.get("prompt_policy") or {}).get("system_prompt", "") or "", + " ".join((role.get("routing_policy") or {}).get("preferred_families", [])), + " ".join((role.get("routing_policy") or {}).get("preferred_labels", [])), + ] + text_score = _overlap_score(task_tokens, _tokenize_text(" ".join(text_parts))) + preferred_families = [family.lower() for family in (role.get("routing_policy") or {}).get("preferred_families", [])] + family_score = 0.0 + if service is not None and preferred_families: + asset_names = " ".join(asset.get("asset_id", "") for asset in service["assets"]).lower() + if any(family in asset_names for family in preferred_families): + family_score = 1.0 + runtime_score, runtime_reasons, runtime_signals = self._runtime_signals(service) + benchmark_score, benchmark_reasons, benchmark_signals = self._benchmark_signals(service, tasks, workloads) + + score = min(1.0, 0.30 * text_score + 0.15 * family_score + 0.30 * runtime_score + 0.25 * benchmark_score) + reasons = [] + if text_score > 0: + reasons.append("task text overlaps role description or policy") + if family_score > 0: + reasons.append("resolved service matches role preferred model family") + reasons.extend(runtime_reasons) + reasons.extend(benchmark_reasons) + if service is None: + reasons.append("no healthy service currently resolves for this role") + return { + "candidate_type": "role", + "candidate_id": role["role_id"], + "operation": role["operation"], + "score": round(score, 4), + "reasons": reasons, + "signals": { + "task_overlap": round(text_score, 4), + "preferred_family_match": family_score, + **runtime_signals, + **benchmark_signals, + }, + "role": role, + "service": service, + } + + def _score_service_candidate(self, service: dict, tasks: list[str], workloads: list[str]) -> dict: + task_tokens = _tokenize_tasks(tasks) + service_text = " ".join( + [ + service.get("service_id", ""), + service.get("host_id", ""), + " ".join(asset.get("asset_id", "") for asset in service.get("assets", [])), + " ".join(f"{key} {value}" for key, value in (service.get("runtime") or {}).items() if value), + ] + ) + text_score = _overlap_score(task_tokens, _tokenize_text(service_text)) + runtime_score, runtime_reasons, runtime_signals = self._runtime_signals(service) + benchmark_score, benchmark_reasons, benchmark_signals = self._benchmark_signals(service, tasks, workloads) + score = min(1.0, 0.20 * text_score + 0.45 * runtime_score + 0.35 * benchmark_score) + reasons = [] + if text_score > 0: + reasons.append("task text overlaps service or asset metadata") + reasons.extend(runtime_reasons) + reasons.extend(benchmark_reasons) + return { + "candidate_type": "service", + "candidate_id": service["service_id"], + "operation": service["kind"], + "score": round(score, 4), + "reasons": reasons, + "signals": { + "task_overlap": round(text_score, 4), + **runtime_signals, + **benchmark_signals, + }, + "role": None, + "service": service, + } + + @staticmethod + def _runtime_signals(service: dict | None) -> tuple[float, list[str], dict[str, object]]: + if service is None: + return 0.0, [], {"loaded": False, "p50_latency_ms": None, "tokens_per_sec": None} + loaded = any(asset.get("loaded") for asset in service.get("assets", [])) + latency = service["observed"].get("p50_latency_ms") + tokens_per_sec = service["observed"].get("tokens_per_sec") + queue_depth = service["observed"].get("queue_depth") + + score = 0.0 + reasons: list[str] = [] + if loaded: + score += 0.35 + reasons.append("service already has a loaded asset") + if latency is not None: + if latency <= 1500: + score += 0.30 + reasons.append("low observed latency") + elif latency <= 4000: + score += 0.18 + reasons.append("moderate observed latency") + else: + score += 0.05 + reasons.append("high but usable latency") + if tokens_per_sec is not None: + if tokens_per_sec >= 20: + score += 0.20 + reasons.append("good observed throughput") + elif tokens_per_sec >= 8: + score += 0.10 + reasons.append("usable observed throughput") + if queue_depth is not None and queue_depth > 0: + score -= min(0.15, 0.03 * queue_depth) + reasons.append("current queue depth reduces suitability") + return max(0.0, min(1.0, score)), reasons, { + "loaded": loaded, + "p50_latency_ms": latency, + "tokens_per_sec": tokens_per_sec, + "queue_depth": queue_depth, + } + + def _benchmark_signals(self, service: dict | None, tasks: list[str], workloads: list[str]) -> tuple[float, list[str], dict[str, object]]: + if service is None: + return 0.0, [], {"benchmark_match_count": 0, "best_workload_overlap": 0.0, "benchmark_quality_score": None} + samples = self.list_benchmark_samples(service_id=service["service_id"]) + if not samples: + return 0.0, [], {"benchmark_match_count": 0, "best_workload_overlap": 0.0, "benchmark_quality_score": None} + + query_tokens = _tokenize_tasks(tasks + workloads) + best_overlap = 0.0 + best_quality = 0.0 + matched_count = 0 + for sample in samples: + workload_tokens = _tokenize_text(sample["workload"]) + overlap = _overlap_score(query_tokens, workload_tokens) if query_tokens else 0.0 + if workloads and sample["workload"] in workloads: + overlap = max(overlap, 1.0) + quality = _benchmark_quality_score(sample["results"]) + if overlap > 0 or not query_tokens: + matched_count += 1 + best_overlap = max(best_overlap, overlap) + best_quality = max(best_quality, quality) + + score = 0.55 * best_overlap + 0.45 * best_quality if matched_count else 0.0 + reasons: list[str] = [] + if matched_count: + reasons.append("recent benchmark sample matches requested workload or task shape") + if best_quality >= 0.6: + reasons.append("benchmark results indicate strong empirical fit") + elif matched_count: + reasons.append("benchmark results indicate limited but relevant empirical fit") + return min(1.0, score), reasons, { + "benchmark_match_count": matched_count, + "best_workload_overlap": round(best_overlap, 4), + "benchmark_quality_score": round(best_quality, 4) if matched_count else None, + } + def cluster_health(self, stale_after_s: float) -> dict: hosts = self.list_hosts() services = self.list_services() @@ -397,9 +662,10 @@ class Registry: }, } - def _service_metadata(self, service: dict) -> dict: + def _service_metadata(self, service: dict, *, requested_model: str | None = None) -> dict: lat = service["observed"].get("p50_latency_ms") loaded_count = 1 if any(asset.get("loaded") for asset in service["assets"]) else 0 + effective_requested_model = requested_model or service["service_id"] return { "route_type": "service", "service_id": service["service_id"], @@ -412,6 +678,10 @@ class Registry: "assets": service["assets"], "runtime": service["runtime"], "observed": service["observed"], + "effective_request_policy": self._effective_request_policy( + requested_model=effective_requested_model, + service=service, + ), "offload_hint": self._offload_hint( operation=service["kind"], loaded_count=loaded_count, @@ -419,6 +689,23 @@ class Registry: ), } + @staticmethod + def _effective_request_policy( + *, + requested_model: str, + service: dict | None, + role: dict | None = None, + ) -> dict | None: + if service is None or service.get("kind") != "chat": + return None + asset = select_target_asset(service, requested_model) + return effective_chat_request_policy( + requested_model=requested_model, + service=service, + role=role, + asset=asset, + ) + @staticmethod def _host_row_to_dict(row: sqlite3.Row) -> dict: return { @@ -462,3 +749,53 @@ class Registry: "routing_policy": json.loads(row["routing_policy_json"]), "updated_at": row["updated_at"], } + + @staticmethod + def _benchmark_row_to_dict(row: sqlite3.Row) -> dict: + return { + "benchmark_id": row["benchmark_id"], + "service_id": row["service_id"], + "asset_id": row["asset_id"], + "workload": row["workload"], + "observed_at": row["observed_at"], + "results": json.loads(row["results_json"]), + } + + +def _tokenize_text(value: str) -> set[str]: + return {token for token in re.split(r"[^a-z0-9]+", value.lower()) if token} + + +def _tokenize_tasks(tasks: list[str]) -> set[str]: + return _tokenize_text(" ".join(tasks)) + + +def _overlap_score(task_tokens: set[str], candidate_tokens: set[str]) -> float: + if not task_tokens or not candidate_tokens: + return 0.0 + overlap = len(task_tokens & candidate_tokens) / max(1, len(task_tokens)) + return min(1.0, overlap) + + +def _benchmark_quality_score(results: dict) -> float: + if not results: + return 0.0 + quality = 0.0 + tokens_per_sec = results.get("tokens_per_sec") + ttft_ms = results.get("ttft_ms") + pass_rate = results.get("pass_rate") + quality_score = results.get("quality_score") + if isinstance(quality_score, (int, float)): + quality = max(quality, max(0.0, min(1.0, float(quality_score)))) + if isinstance(pass_rate, (int, float)): + quality = max(quality, max(0.0, min(1.0, float(pass_rate)))) + if isinstance(tokens_per_sec, (int, float)): + quality += min(0.35, float(tokens_per_sec) / 100.0) + if isinstance(ttft_ms, (int, float)): + if float(ttft_ms) <= 1000: + quality += 0.25 + elif float(ttft_ms) <= 2500: + quality += 0.15 + else: + quality += 0.05 + return min(1.0, quality) diff --git a/src/geniehive_control/request_policy.py b/src/geniehive_control/request_policy.py new file mode 100644 index 0000000..e6c275a --- /dev/null +++ b/src/geniehive_control/request_policy.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import copy +import re +from typing import Any + + +def deep_merge_defaults(payload: dict[str, Any], defaults: dict[str, Any]) -> dict[str, Any]: + merged = copy.deepcopy(payload) + for key, value in defaults.items(): + if key not in merged: + merged[key] = copy.deepcopy(value) + continue + if isinstance(merged[key], dict) and isinstance(value, dict): + merged[key] = deep_merge_defaults(merged[key], value) + return merged + + +def deep_merge_override(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: + merged = copy.deepcopy(base) + for key, value in override.items(): + if isinstance(merged.get(key), dict) and isinstance(value, dict): + merged[key] = deep_merge_override(merged[key], value) + else: + merged[key] = copy.deepcopy(value) + return merged + + +def apply_system_prompt(messages: list[dict[str, Any]], prompt: str, position: str) -> list[dict[str, Any]]: + if any(message.get("role") == "system" and message.get("content") == prompt for message in messages): + return messages + + injected = {"role": "system", "content": prompt} + if position == "append": + return [*messages, injected] + if position == "replace": + updated = list(messages) + for index, message in enumerate(updated): + if message.get("role") == "system": + updated[index] = injected + return updated + return [injected, *updated] + return [injected, *messages] + + +def apply_request_policy(body: dict[str, Any], policy: dict[str, Any] | None) -> dict[str, Any]: + if not policy: + return dict(body) + + shaped = deep_merge_defaults(body, policy.get("body_defaults", {}) or {}) + system_prompt = policy.get("system_prompt") + if not system_prompt: + return shaped + + messages = shaped.get("messages") + if not isinstance(messages, list): + return shaped + + normalized_messages = [message for message in messages if isinstance(message, dict)] + shaped["messages"] = apply_system_prompt( + normalized_messages, + system_prompt, + str(policy.get("system_prompt_position") or "prepend"), + ) + return shaped + + +def merge_request_policies(*policies: dict[str, Any] | None) -> dict[str, Any] | None: + combined: dict[str, Any] = {} + for policy in policies: + if not policy: + continue + body_defaults = policy.get("body_defaults") or {} + if body_defaults: + combined["body_defaults"] = deep_merge_override(combined.get("body_defaults", {}), body_defaults) + if policy.get("system_prompt"): + combined["system_prompt"] = policy["system_prompt"] + combined["system_prompt_position"] = policy.get("system_prompt_position", "prepend") + return combined or None + + +def select_target_asset(service: dict[str, Any], requested_model: str) -> dict[str, Any] | None: + assets = service.get("assets", []) + for asset in assets: + if asset.get("asset_id") == requested_model: + return asset + for asset in assets: + if asset.get("loaded") and asset.get("asset_id"): + return asset + for asset in assets: + if asset.get("asset_id"): + return asset + return None + + +def looks_like_qwen3(value: str) -> bool: + normalized = re.sub(r"[^a-z0-9]+", "", value.lower()) + return normalized.startswith("qwen3") + + +def infer_chat_request_policy(requested_model: str, service: dict[str, Any], asset: dict[str, Any] | None) -> dict[str, Any] | None: + identifiers = [requested_model, service.get("service_id", "")] + if asset is not None: + identifiers.append(asset.get("asset_id", "")) + identifiers.extend(asset_item.get("asset_id", "") for asset_item in service.get("assets", [])) + if any(looks_like_qwen3(identifier) for identifier in identifiers if identifier): + return { + "body_defaults": { + "chat_template_kwargs": { + "enable_thinking": False, + } + } + } + return None + + +def effective_chat_request_policy( + *, + requested_model: str, + service: dict[str, Any], + role: dict[str, Any] | None = None, + asset: dict[str, Any] | None = None, +) -> dict[str, Any] | None: + return merge_request_policies( + infer_chat_request_policy(requested_model, service, asset), + (asset or {}).get("request_policy"), + ((role or {}).get("prompt_policy") or {}).get("request_policy"), + ) diff --git a/tests/test_benchmark_runner.py b/tests/test_benchmark_runner.py new file mode 100644 index 0000000..a071bb0 --- /dev/null +++ b/tests/test_benchmark_runner.py @@ -0,0 +1,104 @@ +from geniehive_control.benchmark_runner import ChatBenchmarkCase, ChatBenchmarkWorkload, built_in_chat_workloads, run_chat_benchmark + + +def test_built_in_chat_workloads_exist() -> None: + workloads = built_in_chat_workloads() + + assert "chat.short_reasoning" in workloads + assert workloads["chat.short_reasoning"].cases + assert workloads["chat.short_reasoning"].chat_template_kwargs == {"enable_thinking": False} + + +def test_run_chat_benchmark_generates_report() -> None: + workload = ChatBenchmarkWorkload( + workload="chat.short_reasoning", + system_prompt="You are concise.", + cases=[ + ChatBenchmarkCase(name="case1", prompt="Explain route selection briefly."), + ChatBenchmarkCase(name="case2", prompt="Explain latency versus throughput briefly."), + ], + ) + + def fake_request(url: str, headers: dict[str, str], payload: dict) -> dict: + assert url.endswith("/v1/chat/completions") + assert payload["model"] == "general_assistant" + assert "chat_template_kwargs" not in payload + return { + "choices": [{"message": {"role": "assistant", "content": "Benchmark response."}}], + "usage": {"prompt_tokens": 20, "completion_tokens": 10}, + "timings": {"prompt_ms": 120.0, "predicted_per_second": 25.0}, + } + + report = run_chat_benchmark( + base_url="http://127.0.0.1:8800", + api_key="change-me-client-key", + model="general_assistant", + workload=workload, + request_fn=fake_request, + observed_at=1775584000.0, + ) + + sample = report.samples[0] + assert report.source == "geniehive-benchmark-runner" + assert sample.workload == "chat.short_reasoning" + assert sample.results["case_count"] == 2 + assert sample.results["pass_rate"] == 1.0 + assert sample.results["response_rate"] == 1.0 + assert sample.results["empty_visible_response_rate"] == 0.0 + assert sample.results["tokens_per_sec"] == 25.0 + assert sample.observed_at == 1775584000.0 + + +def test_run_chat_benchmark_treats_reasoning_content_as_a_pass() -> None: + workload = ChatBenchmarkWorkload( + workload="chat.short_reasoning", + system_prompt="You are concise.", + cases=[ChatBenchmarkCase(name="case1", prompt="Explain route selection briefly.")], + ) + + def fake_request(url: str, headers: dict[str, str], payload: dict) -> dict: + return { + "choices": [{"message": {"role": "assistant", "content": "", "reasoning_content": "Reasoning only."}}], + "usage": {"prompt_tokens": 20, "completion_tokens": 10}, + "timings": {"prompt_ms": 120.0, "predicted_per_second": 25.0}, + } + + report = run_chat_benchmark( + base_url="http://127.0.0.1:8800", + api_key="change-me-client-key", + model="general_assistant", + workload=workload, + request_fn=fake_request, + observed_at=1775584000.0, + ) + + assert report.samples[0].results["pass_rate"] == 1.0 + assert report.samples[0].results["empty_visible_response_rate"] == 0.0 + + +def test_run_chat_benchmark_includes_chat_template_kwargs_when_configured() -> None: + workload = ChatBenchmarkWorkload( + workload="chat.short_reasoning", + system_prompt="You are concise.", + chat_template_kwargs={"enable_thinking": False}, + cases=[ChatBenchmarkCase(name="case1", prompt="Explain route selection briefly.")], + ) + + def fake_request(url: str, headers: dict[str, str], payload: dict) -> dict: + assert payload["chat_template_kwargs"] == {"enable_thinking": False} + return { + "choices": [{"message": {"role": "assistant", "content": "Visible response."}}], + "usage": {"prompt_tokens": 20, "completion_tokens": 10}, + "timings": {"prompt_ms": 120.0, "predicted_per_second": 25.0}, + } + + report = run_chat_benchmark( + base_url="http://127.0.0.1:8800", + api_key="change-me-client-key", + model="general_assistant", + workload=workload, + request_fn=fake_request, + observed_at=1775584000.0, + ) + + assert report.samples[0].results["pass_rate"] == 1.0 diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py new file mode 100644 index 0000000..9676ab2 --- /dev/null +++ b/tests/test_benchmarks.py @@ -0,0 +1,66 @@ +from pathlib import Path + +from geniehive_control.benchmarks import load_benchmark_report, report_to_samples + + +def test_load_benchmark_report_and_generate_sample_ids(tmp_path: Path) -> None: + report_path = tmp_path / "bench.json" + report_path.write_text( + """ +{ + "report_id": "p40-short-reasoning", + "observed_at": 1775583000.0, + "source": "local-smoke", + "samples": [ + { + "service_id": "p40-box/chat/gpu1-secondary", + "asset_id": "Qwen3.5-9B-Q5_K_M", + "workload": "chat.short_reasoning", + "results": { + "ttft_ms": 900, + "tokens_per_sec": 30, + "quality_score": 0.9 + } + } + ] +} +""".strip(), + encoding="utf-8", + ) + + report = load_benchmark_report(report_path) + samples = report_to_samples(report) + + assert report.report_id == "p40-short-reasoning" + assert len(samples) == 1 + assert samples[0].benchmark_id.startswith("bench_") + assert samples[0].observed_at == 1775583000.0 + assert samples[0].results["source"] == "local-smoke" + + +def test_report_to_samples_preserves_explicit_sample_ids(tmp_path: Path) -> None: + report_path = tmp_path / "bench.json" + report_path.write_text( + """ +{ + "observed_at": 1775583000.0, + "samples": [ + { + "benchmark_id": "bench-explicit", + "service_id": "p40-box/chat/gpu0-primary", + "workload": "chat.concise_support", + "results": { + "tokens_per_sec": 24 + } + } + ] +} +""".strip(), + encoding="utf-8", + ) + + report = load_benchmark_report(report_path) + samples = report_to_samples(report) + + assert samples[0].benchmark_id == "bench-explicit" + assert samples[0].workload == "chat.concise_support" diff --git a/tests/test_control_chat.py b/tests/test_control_chat.py index b7f962b..d4a6225 100644 --- a/tests/test_control_chat.py +++ b/tests/test_control_chat.py @@ -158,6 +158,108 @@ def test_proxy_chat_completion_strips_reasoning_fields(tmp_path: Path) -> None: assert "reasoning" not in choice +def test_proxy_chat_completion_applies_inferred_qwen_request_defaults(tmp_path: Path) -> None: + registry = _build_registry(tmp_path) + + class _InspectingPoster: + async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse: + assert json["chat_template_kwargs"] == {"enable_thinking": False} + return _FakeResponse({"ok": True, "echo_model": json["model"]}) + + upstream = UpstreamClient(client=_InspectingPoster()) + + async def run() -> dict: + return await proxy_chat_completion( + { + "model": "mentor", + "messages": [{"role": "user", "content": "hello"}], + }, + registry=registry, + upstream=upstream, + ) + + result = asyncio.run(run()) + assert result["echo_model"] == "qwen3-8b-q4km" + + +def test_proxy_chat_completion_preserves_explicit_template_kwargs(tmp_path: Path) -> None: + registry = _build_registry(tmp_path) + + class _InspectingPoster: + async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse: + assert json["chat_template_kwargs"] == {"enable_thinking": True, "foo": "bar"} + return _FakeResponse({"ok": True, "echo_model": json["model"]}) + + upstream = UpstreamClient(client=_InspectingPoster()) + + async def run() -> dict: + return await proxy_chat_completion( + { + "model": "mentor", + "messages": [{"role": "user", "content": "hello"}], + "chat_template_kwargs": {"enable_thinking": True, "foo": "bar"}, + }, + registry=registry, + upstream=upstream, + ) + + result = asyncio.run(run()) + assert result["echo_model"] == "qwen3-8b-q4km" + + +def test_proxy_chat_completion_applies_asset_request_policy(tmp_path: Path) -> None: + registry = Registry(tmp_path / "geniehive.sqlite3") + registry.register_host( + HostRegistration( + host_id="atlas-01", + address="192.168.1.101", + services=[ + RegisteredService( + service_id="atlas-01/chat/custom-model", + host_id="atlas-01", + kind="chat", + endpoint="http://192.168.1.101:18091", + assets=[ + { + "asset_id": "custom-model-v1", + "loaded": True, + "request_policy": { + "body_defaults": { + "temperature": 0.2, + "chat_template_kwargs": {"custom_flag": "yes"}, + } + }, + } + ], + state={"health": "healthy", "load_state": "loaded", "accept_requests": True}, + observed={"p50_latency_ms": 900}, + ) + ], + ) + ) + + class _InspectingPoster: + async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse: + assert json["temperature"] == 0.2 + assert json["chat_template_kwargs"] == {"custom_flag": "yes"} + return _FakeResponse({"ok": True, "echo_model": json["model"]}) + + upstream = UpstreamClient(client=_InspectingPoster()) + + async def run() -> dict: + return await proxy_chat_completion( + { + "model": "custom-model-v1", + "messages": [{"role": "user", "content": "hello"}], + }, + registry=registry, + upstream=upstream, + ) + + result = asyncio.run(run()) + assert result["echo_model"] == "custom-model-v1" + + def test_proxy_chat_completion_fails_for_unknown_model(tmp_path: Path) -> None: registry = _build_registry(tmp_path) upstream = UpstreamClient(client=_FakePoster()) diff --git a/tests/test_control_registry.py b/tests/test_control_registry.py index d74185a..4ea3048 100644 --- a/tests/test_control_registry.py +++ b/tests/test_control_registry.py @@ -1,7 +1,7 @@ from pathlib import Path from geniehive_control.main import create_app -from geniehive_control.models import HostHeartbeat, HostRegistration, RegisteredService, RoleProfile +from geniehive_control.models import BenchmarkSample, HostHeartbeat, HostRegistration, RegisteredService, RoleProfile, RouteMatchRequest from geniehive_control.registry import Registry @@ -133,9 +133,11 @@ def test_registry_persists_roles_and_resolves_direct_and_role_routes(tmp_path: P mentor = next(item for item in models if item["id"] == "mentor") assert mentor["geniehive"]["route_type"] == "role" assert mentor["geniehive"]["offload_hint"]["suitability"] == "good_for_low_complexity" + assert mentor["geniehive"]["effective_request_policy"]["body_defaults"]["chat_template_kwargs"]["enable_thinking"] is False asset = next(item for item in models if item["id"] == "qwen3-8b-q4km") assert asset["geniehive"]["route_type"] == "asset" assert asset["geniehive"]["offload_hint"]["recommended_for"] == "lower-complexity offload" + assert asset["geniehive"]["effective_request_policy"]["body_defaults"]["chat_template_kwargs"]["enable_thinking"] is False def test_control_app_exposes_expected_routes() -> None: @@ -147,6 +149,220 @@ def test_control_app_exposes_expected_routes() -> None: assert "/v1/nodes/heartbeat" in paths assert "/v1/cluster/hosts" in paths assert "/v1/cluster/services" in paths + assert "/v1/cluster/benchmarks" in paths assert "/v1/cluster/roles" in paths assert "/v1/cluster/health" in paths assert "/v1/cluster/routes/resolve" in paths + assert "/v1/cluster/routes/match" in paths + + +def test_registry_can_rank_routes_for_task_statements(tmp_path: Path) -> None: + db_path = tmp_path / "geniehive.sqlite3" + registry = Registry(db_path) + + registry.register_host( + HostRegistration( + host_id="atlas-01", + address="192.168.1.101", + services=[ + RegisteredService( + service_id="atlas-01/chat/qwen-reasoner", + host_id="atlas-01", + kind="chat", + endpoint="http://192.168.1.101:18091", + assets=[{"asset_id": "qwen3.5-9b", "loaded": True}], + state={"health": "healthy", "load_state": "loaded", "accept_requests": True}, + observed={"p50_latency_ms": 1100, "tokens_per_sec": 28}, + ), + RegisteredService( + service_id="atlas-01/chat/rocket-background", + host_id="atlas-01", + kind="chat", + endpoint="http://192.168.1.101:18093", + assets=[{"asset_id": "rocket-3b", "loaded": True}], + state={"health": "healthy", "load_state": "loaded", "accept_requests": True}, + observed={"p50_latency_ms": 4200, "tokens_per_sec": 7}, + ), + ], + ) + ) + registry.upsert_roles( + [ + RoleProfile( + role_id="general_assistant", + display_name="General Assistant", + description="Fast general technical assistant for reasoning and question answering.", + operation="chat", + modality="text", + routing_policy={"preferred_families": ["qwen3.5", "qwen"]}, + ), + RoleProfile( + role_id="background_summarizer", + display_name="Background Summarizer", + description="Slow fallback summarizer for lower-priority background work.", + operation="chat", + modality="text", + routing_policy={"preferred_families": ["rocket"]}, + ), + ] + ) + + result = registry.match_routes( + RouteMatchRequest( + task="fast technical reasoning for an interactive assistant", + kind="chat", + modality="text", + limit=4, + ) + ) + + assert result["task_count"] == 1 + assert result["candidates"] + top = result["candidates"][0] + assert top["candidate_type"] == "role" + assert top["candidate_id"] == "general_assistant" + assert top["service"]["service_id"] == "atlas-01/chat/qwen-reasoner" + assert top["signals"]["preferred_family_match"] == 1.0 + + +def test_registry_match_can_include_direct_services(tmp_path: Path) -> None: + db_path = tmp_path / "geniehive.sqlite3" + registry = Registry(db_path) + registry.register_host( + HostRegistration( + host_id="atlas-01", + address="192.168.1.101", + services=[ + RegisteredService( + service_id="atlas-01/chat/qwen3-8b", + host_id="atlas-01", + kind="chat", + endpoint="http://192.168.1.101:18091", + assets=[{"asset_id": "qwen3-8b-q4km", "loaded": True}], + state={"health": "healthy", "load_state": "loaded", "accept_requests": True}, + observed={"p50_latency_ms": 900, "tokens_per_sec": 40}, + ) + ], + ) + ) + + result = registry.match_routes( + RouteMatchRequest( + task="qwen model for quick chat", + kind="chat", + include_direct_services=True, + limit=4, + ) + ) + + direct = next(candidate for candidate in result["candidates"] if candidate["candidate_type"] == "service") + assert direct["candidate_id"] == "atlas-01/chat/qwen3-8b" + assert direct["service"]["assets"][0]["asset_id"] == "qwen3-8b-q4km" + + +def test_registry_persists_benchmark_samples_and_uses_them_for_matching(tmp_path: Path) -> None: + db_path = tmp_path / "geniehive.sqlite3" + registry = Registry(db_path) + registry.register_host( + HostRegistration( + host_id="atlas-01", + address="192.168.1.101", + services=[ + RegisteredService( + service_id="atlas-01/chat/qwen-fast", + host_id="atlas-01", + kind="chat", + endpoint="http://192.168.1.101:18091", + assets=[{"asset_id": "qwen3.5-9b", "loaded": True}], + state={"health": "healthy", "load_state": "loaded", "accept_requests": True}, + observed={"p50_latency_ms": 1500, "tokens_per_sec": 22}, + ), + RegisteredService( + service_id="atlas-01/chat/rocket-slow", + host_id="atlas-01", + kind="chat", + endpoint="http://192.168.1.101:18093", + assets=[{"asset_id": "rocket-3b", "loaded": True}], + state={"health": "healthy", "load_state": "loaded", "accept_requests": True}, + observed={"p50_latency_ms": 1200, "tokens_per_sec": 10}, + ), + ], + ) + ) + registry.upsert_benchmark_samples( + [ + BenchmarkSample( + benchmark_id="bench-qwen-1", + service_id="atlas-01/chat/qwen-fast", + asset_id="qwen3.5-9b", + workload="chat.short_reasoning", + observed_at=1000.0, + results={"tokens_per_sec": 30, "ttft_ms": 900, "quality_score": 0.9}, + ), + BenchmarkSample( + benchmark_id="bench-rocket-1", + service_id="atlas-01/chat/rocket-slow", + asset_id="rocket-3b", + workload="chat.short_reasoning", + observed_at=1000.0, + results={"tokens_per_sec": 9, "ttft_ms": 1900, "quality_score": 0.4}, + ), + ] + ) + + samples = registry.list_benchmark_samples(service_id="atlas-01/chat/qwen-fast") + assert len(samples) == 1 + assert samples[0]["benchmark_id"] == "bench-qwen-1" + + result = registry.match_routes( + RouteMatchRequest( + task="fast short reasoning for chat responses", + workloads=["chat.short_reasoning"], + kind="chat", + include_direct_services=True, + limit=4, + ) + ) + + top_service = next(candidate for candidate in result["candidates"] if candidate["candidate_type"] == "service") + assert top_service["candidate_id"] == "atlas-01/chat/qwen-fast" + assert top_service["signals"]["benchmark_match_count"] == 1 + assert top_service["signals"]["best_workload_overlap"] == 1.0 + + +def test_registry_exposes_asset_request_policy_in_model_metadata(tmp_path: Path) -> None: + db_path = tmp_path / "geniehive.sqlite3" + registry = Registry(db_path) + registry.register_host( + HostRegistration( + host_id="atlas-01", + address="192.168.1.101", + services=[ + RegisteredService( + service_id="atlas-01/chat/custom-model", + host_id="atlas-01", + kind="chat", + endpoint="http://192.168.1.101:18091", + assets=[ + { + "asset_id": "custom-model-v1", + "loaded": True, + "request_policy": { + "body_defaults": { + "temperature": 0.2, + "chat_template_kwargs": {"custom_flag": "yes"}, + } + }, + } + ], + state={"health": "healthy", "load_state": "loaded", "accept_requests": True}, + observed={"p50_latency_ms": 900}, + ) + ], + ) + ) + + models = registry.list_client_models() + asset = next(item for item in models if item["id"] == "custom-model-v1") + assert asset["geniehive"]["effective_request_policy"]["body_defaults"]["temperature"] == 0.2 + assert asset["geniehive"]["effective_request_policy"]["body_defaults"]["chat_template_kwargs"]["custom_flag"] == "yes"