225 lines
7.2 KiB
Python
225 lines
7.2 KiB
Python
import asyncio
|
|
from pathlib import Path
|
|
|
|
from geniehive_control.chat import ProxyError, proxy_chat_completion, proxy_embeddings
|
|
from geniehive_control.models import HostRegistration, RegisteredService, RoleProfile
|
|
from geniehive_control.registry import Registry
|
|
from geniehive_control.upstream import UpstreamClient
|
|
|
|
|
|
class _FakeResponse:
|
|
def __init__(self, payload: dict, status_code: int = 200) -> None:
|
|
self._payload = payload
|
|
self.status_code = status_code
|
|
self.text = str(payload)
|
|
|
|
def json(self) -> dict:
|
|
return self._payload
|
|
|
|
|
|
class _FakePoster:
    """Async HTTP-client double that records each POST and echoes the model.

    Every ``post`` call is appended to ``calls`` as a dict with ``url``,
    ``json`` and ``headers`` keys. Each call is answered with a 200
    ``_FakeResponse`` whose payload is ``{"ok": True, "echo_model": <model>}``.
    """

    def __init__(self) -> None:
        # Recorded requests, in the order post() was invoked.
        self.calls: list[dict] = []

    async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse:
        """Record the request, then reply echoing ``json["model"]``."""
        # A missing (or empty) headers argument is stored as an empty dict so
        # assertions can index into it without a None check.
        entry = {"url": url, "json": json, "headers": headers or {}}
        self.calls.append(entry)
        echoed = json["model"]
        return _FakeResponse({"ok": True, "echo_model": echoed})
|
|
|
|
|
|
def _build_registry(tmp_path: Path) -> Registry:
    """Create a registry in ``tmp_path`` seeded with one host and two roles.

    Host ``atlas-01`` (192.168.1.101) registers two services:

    * a chat service on port 18091 with asset ``qwen3-8b-q4km`` loaded;
    * an embeddings service on port 18092 with asset ``bge-small-en`` loaded.

    Both services are reported healthy, loaded and accepting requests. Two
    roles are then upserted:

    * ``mentor`` (chat, text), with ``qwen3`` as a preferred family;
    * ``embedder`` (embeddings, text), with ``require_loaded`` set.
    """
    registry = Registry(tmp_path / "geniehive.sqlite3")
    host_id = "atlas-01"

    chat_service = RegisteredService(
        service_id="atlas-01/chat/qwen3-8b",
        host_id=host_id,
        kind="chat",
        endpoint="http://192.168.1.101:18091",
        assets=[{"asset_id": "qwen3-8b-q4km", "loaded": True}],
        state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
        observed={"p50_latency_ms": 900},
    )
    embeddings_service = RegisteredService(
        service_id="atlas-01/embeddings/bge-small",
        host_id=host_id,
        kind="embeddings",
        endpoint="http://192.168.1.101:18092",
        assets=[{"asset_id": "bge-small-en", "loaded": True}],
        state={"health": "healthy", "load_state": "loaded", "accept_requests": True},
        observed={"p50_latency_ms": 120},
    )
    host = HostRegistration(
        host_id=host_id,
        address="192.168.1.101",
        services=[chat_service, embeddings_service],
    )
    registry.register_host(host)

    mentor_role = RoleProfile(
        role_id="mentor",
        display_name="Mentor",
        operation="chat",
        modality="text",
        routing_policy={"preferred_families": ["qwen3"]},
    )
    embedder_role = RoleProfile(
        role_id="embedder",
        display_name="Embedder",
        operation="embeddings",
        modality="text",
        routing_policy={"require_loaded": True},
    )
    registry.upsert_roles([mentor_role, embedder_role])

    return registry
|
|
|
|
|
|
def test_proxy_chat_completion_rewrites_role_to_loaded_asset(tmp_path: Path) -> None:
    """A chat request addressed to the ``mentor`` role reaches the chat service.

    The upstream request must target the chat endpoint, and its ``model`` must
    be the concrete loaded asset id rather than the role name.
    """
    registry = _build_registry(tmp_path)
    poster = _FakePoster()
    upstream = UpstreamClient(client=poster)
    request = {
        "model": "mentor",
        "messages": [{"role": "user", "content": "hello"}],
    }

    async def _invoke() -> dict:
        return await proxy_chat_completion(request, registry=registry, upstream=upstream)

    result = asyncio.run(_invoke())

    assert result["ok"] is True
    assert result["echo_model"] == "qwen3-8b-q4km"
    forwarded = poster.calls[0]
    assert forwarded["url"] == "http://192.168.1.101:18091/v1/chat/completions"
    assert forwarded["json"]["model"] == "qwen3-8b-q4km"
|
|
|
|
|
|
def test_proxy_chat_completion_preserves_direct_asset_match(tmp_path: Path) -> None:
    """A request that already names a registered asset keeps that model id."""
    registry = _build_registry(tmp_path)
    poster = _FakePoster()
    upstream = UpstreamClient(client=poster)
    request = {
        "model": "qwen3-8b-q4km",
        "messages": [{"role": "user", "content": "hello"}],
    }

    async def _invoke() -> dict:
        return await proxy_chat_completion(request, registry=registry, upstream=upstream)

    # The fake upstream echoes the forwarded model back in ``echo_model``.
    echoed = asyncio.run(_invoke())["echo_model"]
    assert echoed == "qwen3-8b-q4km"
|
|
|
|
|
|
def test_proxy_chat_completion_strips_reasoning_fields(tmp_path: Path) -> None:
    """Reasoning data returned by the upstream is removed before reaching the caller.

    The upstream reply carries ``reasoning_content`` on the message and
    ``reasoning`` on the choice. Both must be absent from the proxied result,
    while the normal ``content`` is kept.
    """
    registry = _build_registry(tmp_path)

    class _ReasoningPoster:
        """Upstream double whose single choice includes reasoning fields."""

        async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse:
            echoed_model = json["model"]
            message = {
                "role": "assistant",
                "content": "GPU1 route is live.",
                "reasoning_content": "hidden chain of thought",
            }
            choice = {"index": 0, "message": message, "reasoning": {"tokens": 42}}
            return _FakeResponse({"object": "chat.completion", "model": echoed_model, "choices": [choice]})

    upstream = UpstreamClient(client=_ReasoningPoster())
    request = {
        "model": "mentor",
        "messages": [{"role": "user", "content": "hello"}],
    }

    async def _invoke() -> dict:
        return await proxy_chat_completion(request, registry=registry, upstream=upstream)

    first_choice = asyncio.run(_invoke())["choices"][0]
    message = first_choice["message"]

    assert message["content"] == "GPU1 route is live."
    assert "reasoning_content" not in message
    assert "reasoning" not in first_choice
|
|
|
|
|
|
def test_proxy_chat_completion_fails_for_unknown_model(tmp_path: Path) -> None:
    """A chat request for a model matching no role or asset raises a 404 ProxyError."""
    registry = _build_registry(tmp_path)
    upstream = UpstreamClient(client=_FakePoster())

    async def run() -> None:
        await proxy_chat_completion(
            {
                "model": "unknown-model",
                "messages": [{"role": "user", "content": "hello"}],
            },
            registry=registry,
            upstream=upstream,
        )

    try:
        asyncio.run(run())
    except ProxyError as exc:
        assert exc.status_code == 404
    else:
        # Name the exception type that is actually caught above (ProxyError),
        # consistent with the embeddings counterpart of this test.
        raise AssertionError("expected ProxyError")
|
|
|
|
|
|
def test_proxy_embeddings_rewrites_role_to_loaded_asset(tmp_path: Path) -> None:
    """An embeddings request addressed to the ``embedder`` role reaches the embeddings service.

    The upstream request must target the embeddings endpoint, and its
    ``model`` must be the loaded asset id, not the role name.
    """
    registry = _build_registry(tmp_path)
    poster = _FakePoster()
    upstream = UpstreamClient(client=poster)
    request = {"model": "embedder", "input": "hello"}

    async def _invoke() -> dict:
        return await proxy_embeddings(request, registry=registry, upstream=upstream)

    result = asyncio.run(_invoke())

    assert result["ok"] is True
    assert result["echo_model"] == "bge-small-en"
    forwarded = poster.calls[0]
    assert forwarded["url"] == "http://192.168.1.101:18092/v1/embeddings"
    assert forwarded["json"]["model"] == "bge-small-en"
|
|
|
|
|
|
def test_proxy_embeddings_fails_for_unknown_model(tmp_path: Path) -> None:
    """An embeddings request for an unregistered model raises a 404 ProxyError."""
    registry = _build_registry(tmp_path)
    upstream = UpstreamClient(client=_FakePoster())
    request = {"model": "unknown-embedder", "input": "hello"}

    async def _invoke() -> None:
        await proxy_embeddings(request, registry=registry, upstream=upstream)

    # Capture the error outside the except clause; the bound name is cleared
    # when the clause ends.
    caught: ProxyError | None = None
    try:
        asyncio.run(_invoke())
    except ProxyError as exc:
        caught = exc

    if caught is None:
        raise AssertionError("expected ProxyError")
    assert caught.status_code == 404
|