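# Tests for the RoleMesh gateway app: client/node API-key auth, model listing,
# readiness, the node registry, and routing to proxied or discovered upstreams.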
from __future__ import annotations

import asyncio
import random
from pathlib import Path

import httpx
import yaml


def _write_gateway_config(path: Path) -> Path:
    """Write the gateway model config shared by every test in this module to *path*.

    The config declares:

    * ``writer`` -- the default model, a ``proxy`` model pointing at
      ``http://127.0.0.1:8012`` with a default ``temperature`` of 0.6;
    * ``reviewer`` and ``random-reviewer`` -- ``discovered`` models for the
      ``reviewer`` role, using the ``round_robin`` and ``random`` strategies.

    Clients authenticate with ``client-secret`` and nodes with ``node-secret``.

    Returns:
        *path* itself, so the call can be used inline.
    """
    auth = {
        "client_api_keys": ["client-secret"],
        "node_api_keys": ["node-secret"],
    }
    models = {
        "writer": {
            "type": "proxy",
            "openai_model_name": "writer",
            "proxy_url": "http://127.0.0.1:8012",
            "defaults": {"temperature": 0.6},
        },
        "reviewer": {
            "type": "discovered",
            "openai_model_name": "reviewer",
            "role": "reviewer",
            "strategy": "round_robin",
        },
        "random-reviewer": {
            "type": "discovered",
            "openai_model_name": "random-reviewer",
            "role": "reviewer",
            "strategy": "random",
        },
    }
    config = {
        "version": 1,
        "default_model": "writer",
        "auth": auth,
        "models": models,
    }
    path.write_text(yaml.safe_dump(config))
    return path


def _create_gateway_app(tmp_path: Path):
    """Create a gateway app whose config and node registry live under *tmp_path*.

    Writes ``models.yaml`` with ``_write_gateway_config`` and points the registry
    at ``registry.json`` in the same directory, so each test gets its own files.
    ``registry_stale_after_s=30.0`` presumably marks a node stale once it has not
    been seen for 30 seconds -- confirm against ``create_app``.
    """
    from rolemesh_gateway.main import create_app

    config_path = _write_gateway_config(tmp_path / "models.yaml")
    registry_path = tmp_path / "registry.json"
    return create_app(
        config_path,
        registry_path=registry_path,
        registry_stale_after_s=30.0,
    )


async def _request(app, method: str, path: str, **kwargs) -> httpx.Response:
    """Issue a single in-process HTTP request against the ASGI *app*.

    A new ``httpx.AsyncClient`` over ``httpx.ASGITransport`` is opened for the
    call and closed afterwards. Extra keyword arguments (``headers``, ``json``,
    ...) are forwarded unchanged to ``AsyncClient.request``.
    """
    asgi_transport = httpx.ASGITransport(app=app)
    client = httpx.AsyncClient(transport=asgi_transport, base_url="http://testserver")
    async with client:
        response = await client.request(method, path, **kwargs)
    return response


def test_models_requires_client_auth(tmp_path):
    """``GET /v1/models`` returns 401 without a client key and lists models with one.

    With no reviewer nodes registered, only the proxied ``writer`` model is
    listed, and both discovered aliases are reported as unavailable.
    """
    app = _create_gateway_app(tmp_path)

    async def fake_get_models(base_url):
        # Every upstream probe succeeds with an empty model list.
        return {"object": "list", "data": []}

    app.state.upstream.get_models = fake_get_models
    try:
        unauthorized = asyncio.run(_request(app, "GET", "/v1/models"))
        assert unauthorized.status_code == 401

        authorized = asyncio.run(
            _request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
        )
        assert authorized.status_code == 200
        body = authorized.json()
        assert {item["id"] for item in body["data"]} == {"writer"}
        assert {item["alias"] for item in body["rolemesh"]["unavailable_models"]} == {
            "reviewer",
            "random-reviewer",
        }
    finally:
        # Close the upstream even when an assertion above fails.
        asyncio.run(app.state.upstream.close())


def test_chat_completions_applies_defaults_for_proxy_model(tmp_path):
    """A chat request for ``writer`` is forwarded to its proxy URL with default params.

    The request omits ``temperature``; the payload the gateway forwards must
    carry the configured default of 0.6.
    """
    app = _create_gateway_app(tmp_path)
    calls = {}

    async def fake_chat(base_url, payload):
        # Record what the gateway forwarded instead of contacting a real upstream.
        calls["base_url"] = base_url
        calls["payload"] = payload
        return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}

    app.state.upstream.chat_completions = fake_chat
    try:
        response = asyncio.run(
            _request(
                app,
                "POST",
                "/v1/chat/completions",
                headers={"x-api-key": "client-secret"},
                json={
                    "model": "writer",
                    "messages": [{"role": "user", "content": "hello"}],
                },
            )
        )

        assert response.status_code == 200
        assert calls["base_url"] == "http://127.0.0.1:8012"
        assert calls["payload"]["temperature"] == 0.6
        assert response.json()["choices"][0]["message"]["content"] == "ok"
    finally:
        # Close the upstream even when an assertion above fails.
        asyncio.run(app.state.upstream.close())


def test_discovered_model_without_registered_node_returns_503(tmp_path):
    """Chat for a discovered model with no registered node yields 503 ``no_upstream``."""
    app = _create_gateway_app(tmp_path)
    try:
        response = asyncio.run(
            _request(
                app,
                "POST",
                "/v1/chat/completions",
                headers={"x-api-key": "client-secret"},
                json={
                    "model": "reviewer",
                    "messages": [{"role": "user", "content": "hello"}],
                },
            )
        )

        assert response.status_code == 503
        assert response.json()["error"]["code"] == "no_upstream"
    finally:
        # Close the upstream even when an assertion above fails.
        asyncio.run(app.state.upstream.close())


def test_node_registration_and_heartbeat_update_registry_state(tmp_path):
    """Registering a node and sending a heartbeat updates the node's reported status.

    The heartbeat's ``status``, ``timestamp`` and ``metrics`` must all be
    reflected in the ``node`` object returned by ``/v1/nodes/heartbeat``.
    """
    app = _create_gateway_app(tmp_path)
    node_headers = {"x-rolemesh-node-key": "node-secret"}
    try:
        register = asyncio.run(
            _request(
                app,
                "POST",
                "/v1/nodes/register",
                headers=node_headers,
                json={
                    "node_id": "node-a",
                    "base_url": "http://127.0.0.1:9001",
                    "roles": ["reviewer"],
                    "meta": {"gpu": "test"},
                },
            )
        )
        assert register.status_code == 200

        heartbeat = asyncio.run(
            _request(
                app,
                "POST",
                "/v1/nodes/heartbeat",
                headers=node_headers,
                json={
                    "node_id": "node-a",
                    "timestamp": 123.0,
                    "status": {"healthy": True},
                    "metrics": [{"device": {"id": "gpu:0"}}],
                },
            )
        )
        assert heartbeat.status_code == 200

        node = heartbeat.json()["node"]
        assert node["status"]["healthy"] is True
        assert node["status"]["timestamp"] == 123.0
        assert node["status"]["metrics"] == [{"device": {"id": "gpu:0"}}]
    finally:
        # Close the upstream even when an assertion above fails.
        asyncio.run(app.state.upstream.close())


def test_models_filters_unhealthy_aliases_and_reports_them(tmp_path):
    """An unreachable proxy upstream removes ``writer`` from ``/v1/models``.

    With the ``writer`` proxy failing and no reviewer nodes registered, nothing
    is listed and all three aliases are reported as unavailable.
    """
    from rolemesh_gateway.upstream import UpstreamError

    app = _create_gateway_app(tmp_path)

    async def fake_get_models(base_url):
        # The proxied ``writer`` upstream is unreachable; any other probe succeeds.
        if base_url == "http://127.0.0.1:8012":
            raise UpstreamError("Upstream unreachable: boom")
        return {"object": "list", "data": []}

    app.state.upstream.get_models = fake_get_models
    try:
        response = asyncio.run(
            _request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
        )

        assert response.status_code == 200
        body = response.json()
        assert body["data"] == []
        assert {item["alias"] for item in body["rolemesh"]["unavailable_models"]} == {
            "writer",
            "reviewer",
            "random-reviewer",
        }
    finally:
        # Close the upstream even when an assertion above fails.
        asyncio.run(app.state.upstream.close())


def test_ready_returns_503_when_default_model_is_unavailable(tmp_path):
    """``/ready`` returns 503 ``not_ready`` when the default ``writer`` model is down."""
    from rolemesh_gateway.upstream import UpstreamError

    app = _create_gateway_app(tmp_path)

    async def fake_get_models(base_url):
        # Every upstream probe fails, so the default model cannot be reached.
        raise UpstreamError("Upstream unreachable: boom")

    app.state.upstream.get_models = fake_get_models
    try:
        response = asyncio.run(_request(app, "GET", "/ready"))

        assert response.status_code == 503
        body = response.json()
        assert body["status"] == "not_ready"
        assert body["default_model"] == "writer"
    finally:
        # Close the upstream even when an assertion above fails.
        asyncio.run(app.state.upstream.close())


def test_ready_returns_200_when_default_model_is_available(tmp_path):
    """``/ready`` returns 200 when ``writer`` is reachable, even if reviewers are not.

    Only ``writer`` should be listed in ``available_models``.
    """
    from rolemesh_gateway.upstream import UpstreamError

    app = _create_gateway_app(tmp_path)

    async def fake_get_models(base_url):
        # Only the ``writer`` proxy answers; any other probe fails.
        if base_url == "http://127.0.0.1:8012":
            return {"object": "list", "data": []}
        raise UpstreamError("no reviewer nodes")

    app.state.upstream.get_models = fake_get_models
    try:
        response = asyncio.run(_request(app, "GET", "/ready"))

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "ready"
        assert body["default_model"] == "writer"
        assert body["available_models"] == ["writer"]
    finally:
        # Close the upstream even when an assertion above fails.
        asyncio.run(app.state.upstream.close())


def test_stale_discovered_nodes_are_not_advertised_as_available(tmp_path):
    """A registered but stale reviewer node does not make the discovered models available.

    The stale node is still reported in ``registered_nodes`` (flagged ``stale``).
    Both discovered aliases are unavailable with ``no_fresh_registered_nodes``.
    ``/ready`` stays 200 because the default ``writer`` model is fine.
    """
    app = _create_gateway_app(tmp_path)
    try:
        register = asyncio.run(
            _request(
                app,
                "POST",
                "/v1/nodes/register",
                headers={"x-rolemesh-node-key": "node-secret"},
                json={
                    "node_id": "node-a",
                    "base_url": "http://127.0.0.1:9001",
                    "roles": ["reviewer"],
                },
            )
        )
        assert register.status_code == 200

        # Back-date the node well past the 30 s stale window configured by
        # _create_gateway_app. Assumes list_nodes returns the live registry
        # records rather than copies -- confirm if this test starts failing.
        node = app.state.registry.list_nodes(include_stale=True)[0]
        node.last_seen = node.last_seen - 120

        async def fake_get_models(base_url):
            return {"object": "list", "data": []}

        app.state.upstream.get_models = fake_get_models

        models = asyncio.run(
            _request(app, "GET", "/v1/models", headers={"x-api-key": "client-secret"})
        )
        ready = asyncio.run(_request(app, "GET", "/ready"))

        # Check the status before parsing, so a failure reports the HTTP error
        # rather than a confusing KeyError.
        assert models.status_code == 200
        models_body = models.json()
        assert {item["id"] for item in models_body["data"]} == {"writer"}
        assert models_body["rolemesh"]["registered_nodes"][0]["stale"] is True

        unavailable = {
            item["alias"]: item for item in models_body["rolemesh"]["unavailable_models"]
        }
        assert unavailable["reviewer"]["error"] == "no_fresh_registered_nodes"
        assert unavailable["random-reviewer"]["error"] == "no_fresh_registered_nodes"
        assert ready.status_code == 200
    finally:
        # Close the upstream even when an assertion above fails.
        asyncio.run(app.state.upstream.close())


def test_discovered_model_uses_random_strategy_when_configured(tmp_path):
    """``random-reviewer`` routes through ``random.choice`` over the reviewer nodes.

    ``random.choice`` is temporarily replaced so that it deterministically picks
    the last candidate. The request must then reach ``node-b``
    (``http://127.0.0.1:9002``). This assumes the gateway looks up
    ``random.choice`` on the ``random`` module at call time.
    """
    app = _create_gateway_app(tmp_path)
    calls = {}
    node_headers = {"x-rolemesh-node-key": "node-secret"}

    try:
        register_a = asyncio.run(
            _request(
                app,
                "POST",
                "/v1/nodes/register",
                headers=node_headers,
                json={
                    "node_id": "node-a",
                    "base_url": "http://127.0.0.1:9001",
                    "roles": ["reviewer"],
                },
            )
        )
        register_b = asyncio.run(
            _request(
                app,
                "POST",
                "/v1/nodes/register",
                headers=node_headers,
                json={
                    "node_id": "node-b",
                    "base_url": "http://127.0.0.1:9002",
                    "roles": ["reviewer"],
                },
            )
        )
        assert register_a.status_code == 200
        assert register_b.status_code == 200

        async def fake_chat(base_url, payload):
            # Record which node the gateway picked.
            calls["base_url"] = base_url
            return {"id": "cmpl-1", "choices": [{"message": {"role": "assistant", "content": "ok"}}]}

        async def fake_get_models(base_url):
            # A coroutine function like every other get_models fake in this file.
            # A plain lambda would break if the gateway awaits the result.
            return {"object": "list", "data": []}

        app.state.upstream.chat_completions = fake_chat
        app.state.upstream.get_models = fake_get_models

        original_choice = random.choice
        random.choice = lambda candidates: candidates[-1]
        try:
            response = asyncio.run(
                _request(
                    app,
                    "POST",
                    "/v1/chat/completions",
                    headers={"x-api-key": "client-secret"},
                    json={
                        "model": "random-reviewer",
                        "messages": [{"role": "user", "content": "hello"}],
                    },
                )
            )
        finally:
            # Restore the global immediately so other tests see the real function.
            random.choice = original_choice

        assert response.status_code == 200
        assert calls["base_url"] == "http://127.0.0.1:9002"
    finally:
        # Close the upstream even when an assertion above fails.
        asyncio.run(app.state.upstream.close())