Add opt-in request audit logging

This commit is contained in:
welsberr 2026-04-29 14:56:06 -04:00
parent 960f12f92b
commit 9a1a6f49af
5 changed files with 616 additions and 10 deletions

View File

@ -56,10 +56,11 @@ python -m pytest -q tests
Expected current result at baseline: all tests pass.
Current verification result after adding the Foundation roadmap, config profile
scaffold, named client key storage, opt-in named auth, and admin key endpoints:
scaffold, named client key storage, opt-in named auth, admin key endpoints, and
request audit logging:
```text
58 passed
61 passed
```
## Known Constraints

View File

@ -185,6 +185,10 @@ Acceptance:
Goal: make production requests attributable without storing prompt or completion
content.
Status: implemented for chat, embeddings, and transcription request wrappers.
Audit logging is disabled by default and enabled by `audit.enabled`. Admin audit
read endpoints are only mounted when `admin_api.enabled` is true.
Tasks:
- Add request ID generation from `X-Request-Id` or UUID.

View File

@ -1,7 +1,9 @@
from __future__ import annotations
import asyncio
import json
import os
import time
import uuid
from contextlib import asynccontextmanager, suppress
from pathlib import Path
@ -17,6 +19,7 @@ from .models import BenchmarkIngestRequest, HostHeartbeat, HostRegistration, Rou
from .probe import ServiceProber
from .roles import load_role_catalog
from .registry import Registry
from .routing import choose_upstream_model_id
from .upstream import UpstreamClient, UpstreamError
@ -70,6 +73,86 @@ def create_app(
if key != "key_hash"
}
def _request_id(request: Request) -> str:
return request.headers.get("X-Request-Id") or f"req_{uuid.uuid4().hex}"
def _client_context(request: Request):
return getattr(request.state, "client_context", None)
def _route_audit_metadata(reg: Registry, requested_model: str | None, *, kind: str) -> dict:
if not requested_model:
return {
"requested_model": None,
"resolved_service_id": None,
"resolved_host_id": None,
"upstream_model": None,
"provider_kind": None,
}
resolved = reg.resolve_route(requested_model, kind=kind)
service = resolved.get("service") if resolved else None
if not service:
return {
"requested_model": requested_model,
"resolved_service_id": None,
"resolved_host_id": None,
"upstream_model": None,
"provider_kind": None,
}
return {
"requested_model": requested_model,
"resolved_service_id": service.get("service_id"),
"resolved_host_id": service.get("host_id"),
"upstream_model": choose_upstream_model_id(requested_model, service),
"provider_kind": service.get("protocol"),
}
def _usage_from_response(response: object) -> dict[str, int | None]:
usage = response.get("usage", {}) if isinstance(response, dict) else {}
return {
"prompt_tokens": usage.get("prompt_tokens") if isinstance(usage, dict) else None,
"completion_tokens": usage.get("completion_tokens") if isinstance(usage, dict) else None,
"total_tokens": usage.get("total_tokens") if isinstance(usage, dict) else None,
}
def _audit_request(
request: Request,
*,
request_id: str,
operation: str,
route_metadata: dict,
started_at: float,
status_code: int,
success: bool,
response: object | None = None,
error_type: str | None = None,
input_bytes: int | None = None,
output_bytes: int | None = None,
) -> None:
if not cfg.audit.enabled:
return
context = _client_context(request)
usage = _usage_from_response(response)
request.app.state.registry.record_request_audit(
request_id=request_id,
key_id=getattr(context, "key_id", None),
principal_type=getattr(context, "principal_type", None),
principal_ref=getattr(context, "principal_ref", None),
operation=operation,
requested_model=route_metadata.get("requested_model"),
resolved_service_id=route_metadata.get("resolved_service_id"),
resolved_host_id=route_metadata.get("resolved_host_id"),
upstream_model=route_metadata.get("upstream_model"),
provider_kind=route_metadata.get("provider_kind"),
started_at=started_at,
finished_at=time.time(),
status_code=status_code,
success=success,
error_type=error_type,
input_bytes=input_bytes,
output_bytes=output_bytes,
**usage,
)
if cfg.admin_api.enabled:
@app.post("/v1/admin/client-keys")
async def create_client_key(request: Request, _=Depends(require_admin_auth)) -> dict:
@ -126,6 +209,41 @@ def create_app(
return JSONResponse(status_code=404, content={"error": "unknown_client_key", "key_id": key_id})
return {"status": "ok", "client_key": _public_client_key(updated)}
@app.get("/v1/admin/audit/requests")
async def list_audit_requests(
request: Request,
key_id: str | None = None,
principal_ref: str | None = None,
operation: str | None = None,
model: str | None = None,
success: bool | None = None,
limit: int = 100,
_=Depends(require_admin_auth),
) -> dict:
if not cfg.audit.enabled:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="audit logging is not enabled",
)
rows = request.app.state.registry.list_request_audit(
key_id=key_id,
principal_ref=principal_ref,
operation=operation,
model=model,
success=success,
limit=limit,
)
return {"object": "list", "data": rows}
@app.get("/v1/admin/audit/summary")
async def audit_summary(request: Request, _=Depends(require_admin_auth)) -> dict:
if not cfg.audit.enabled:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="audit logging is not enabled",
)
return {"object": "list", "data": request.app.state.registry.request_audit_summary()}
@app.post("/v1/nodes/register")
async def register_node(request: Request, _=Depends(require_node_auth)) -> dict:
payload = await request.json()
@ -155,45 +273,142 @@ def create_app(
body = await request.json()
reg: Registry = request.app.state.registry
up: UpstreamClient = request.app.state.upstream
request_id = _request_id(request)
started_at = time.time()
route_metadata = _route_audit_metadata(reg, body.get("model"), kind="chat")
input_bytes = len(json.dumps(body, separators=(",", ":")).encode("utf-8"))
try:
if body.get("stream"):
# Resolve route eagerly so ProxyError is raised before streaming starts.
service, upstream_body = _prepare_chat_upstream(body, registry=reg)
_audit_request(
request,
request_id=request_id,
operation="chat",
route_metadata=route_metadata,
started_at=started_at,
status_code=200,
success=True,
input_bytes=input_bytes,
)
return StreamingResponse(
stream_chat_completion(service, upstream_body, upstream=up),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "X-Request-Id": request_id},
)
return await proxy_chat_completion(body, registry=reg, upstream=up)
response = await proxy_chat_completion(body, registry=reg, upstream=up)
output_bytes = len(json.dumps(response, separators=(",", ":")).encode("utf-8")) if isinstance(response, dict) else None
_audit_request(
request,
request_id=request_id,
operation="chat",
route_metadata=route_metadata,
started_at=started_at,
status_code=200,
success=True,
response=response,
input_bytes=input_bytes,
output_bytes=output_bytes,
)
return JSONResponse(content=response, headers={"X-Request-Id": request_id})
except ProxyError as exc:
_audit_request(
request,
request_id=request_id,
operation="chat",
route_metadata=route_metadata,
started_at=started_at,
status_code=exc.status_code,
success=False,
error_type="proxy_error",
input_bytes=input_bytes,
)
return JSONResponse(
status_code=exc.status_code,
content={"error": {"message": str(exc), "type": "geniehive_error", "code": "chat_proxy_error"}},
headers={"X-Request-Id": request_id},
)
except UpstreamError as exc:
status_code = exc.status_code or 502
_audit_request(
request,
request_id=request_id,
operation="chat",
route_metadata=route_metadata,
started_at=started_at,
status_code=status_code,
success=False,
error_type="upstream_error",
input_bytes=input_bytes,
)
return JSONResponse(
status_code=exc.status_code or 502,
status_code=status_code,
content={"error": {"message": str(exc), "type": "geniehive_error", "code": "upstream_error"}},
headers={"X-Request-Id": request_id},
)
@app.post("/v1/embeddings")
async def embeddings(request: Request, _=Depends(require_client_auth)):
body = await request.json()
reg: Registry = request.app.state.registry
request_id = _request_id(request)
started_at = time.time()
route_metadata = _route_audit_metadata(reg, body.get("model"), kind="embeddings")
input_bytes = len(json.dumps(body, separators=(",", ":")).encode("utf-8"))
try:
return await proxy_embeddings(
response = await proxy_embeddings(
body,
registry=request.app.state.registry,
registry=reg,
upstream=request.app.state.upstream,
)
output_bytes = len(json.dumps(response, separators=(",", ":")).encode("utf-8")) if isinstance(response, dict) else None
_audit_request(
request,
request_id=request_id,
operation="embeddings",
route_metadata=route_metadata,
started_at=started_at,
status_code=200,
success=True,
response=response,
input_bytes=input_bytes,
output_bytes=output_bytes,
)
return JSONResponse(content=response, headers={"X-Request-Id": request_id})
except ProxyError as exc:
_audit_request(
request,
request_id=request_id,
operation="embeddings",
route_metadata=route_metadata,
started_at=started_at,
status_code=exc.status_code,
success=False,
error_type="proxy_error",
input_bytes=input_bytes,
)
return JSONResponse(
status_code=exc.status_code,
content={"error": {"message": str(exc), "type": "geniehive_error", "code": "embeddings_proxy_error"}},
headers={"X-Request-Id": request_id},
)
except UpstreamError as exc:
status_code = exc.status_code or 502
_audit_request(
request,
request_id=request_id,
operation="embeddings",
route_metadata=route_metadata,
started_at=started_at,
status_code=status_code,
success=False,
error_type="upstream_error",
input_bytes=input_bytes,
)
return JSONResponse(
status_code=exc.status_code or 502,
status_code=status_code,
content={"error": {"message": str(exc), "type": "geniehive_error", "code": "upstream_error"}},
headers={"X-Request-Id": request_id},
)
@app.post("/v1/audio/transcriptions")
@ -207,8 +422,11 @@ def create_app(
temperature: float | None = Form(None),
_=Depends(require_client_auth),
):
request_id = _request_id(request)
started_at = time.time()
route_metadata = _route_audit_metadata(request.app.state.registry, model, kind="transcription")
try:
return await proxy_transcription(
response = await proxy_transcription(
model=model,
file=file,
language=language,
@ -218,15 +436,51 @@ def create_app(
registry=request.app.state.registry,
upstream=request.app.state.upstream,
)
output_bytes = len(json.dumps(response, separators=(",", ":")).encode("utf-8")) if isinstance(response, dict) else None
_audit_request(
request,
request_id=request_id,
operation="transcription",
route_metadata=route_metadata,
started_at=started_at,
status_code=200,
success=True,
response=response,
output_bytes=output_bytes,
)
return JSONResponse(content=response, headers={"X-Request-Id": request_id})
except ProxyError as exc:
_audit_request(
request,
request_id=request_id,
operation="transcription",
route_metadata=route_metadata,
started_at=started_at,
status_code=exc.status_code,
success=False,
error_type="proxy_error",
)
return JSONResponse(
status_code=exc.status_code,
content={"error": {"message": str(exc), "type": "geniehive_error", "code": "transcription_proxy_error"}},
headers={"X-Request-Id": request_id},
)
except UpstreamError as exc:
status_code = exc.status_code or 502
_audit_request(
request,
request_id=request_id,
operation="transcription",
route_metadata=route_metadata,
started_at=started_at,
status_code=status_code,
success=False,
error_type="upstream_error",
)
return JSONResponse(
status_code=exc.status_code or 502,
status_code=status_code,
content={"error": {"message": str(exc), "type": "geniehive_error", "code": "upstream_error"}},
headers={"X-Request-Id": request_id},
)
@app.get("/v1/cluster/services")

View File

@ -95,6 +95,32 @@ class Registry:
last_used_at REAL,
notes TEXT
);
CREATE TABLE IF NOT EXISTS request_audit_log (
request_id TEXT PRIMARY KEY,
key_id TEXT,
principal_type TEXT,
principal_ref TEXT,
operation TEXT NOT NULL,
requested_model TEXT,
resolved_service_id TEXT,
resolved_host_id TEXT,
upstream_model TEXT,
provider_kind TEXT,
started_at REAL NOT NULL,
finished_at REAL NOT NULL,
duration_ms REAL NOT NULL,
status_code INTEGER NOT NULL,
success INTEGER NOT NULL,
error_type TEXT,
prompt_tokens INTEGER,
completion_tokens INTEGER,
total_tokens INTEGER,
estimated_cost_cents REAL,
input_bytes INTEGER,
output_bytes INTEGER,
metadata_json TEXT NOT NULL DEFAULT '{}'
);
"""
)
@ -390,6 +416,145 @@ class Registry:
(now, now, key_id),
)
def record_request_audit(
self,
*,
request_id: str,
key_id: str | None,
principal_type: str | None,
principal_ref: str | None,
operation: str,
requested_model: str | None,
resolved_service_id: str | None,
resolved_host_id: str | None,
upstream_model: str | None,
provider_kind: str | None,
started_at: float,
finished_at: float,
status_code: int,
success: bool,
error_type: str | None = None,
prompt_tokens: int | None = None,
completion_tokens: int | None = None,
total_tokens: int | None = None,
estimated_cost_cents: float | None = None,
input_bytes: int | None = None,
output_bytes: int | None = None,
metadata: dict | None = None,
) -> dict:
duration_ms = max(0.0, (finished_at - started_at) * 1000.0)
with self._connect() as conn:
conn.execute(
"""
INSERT INTO request_audit_log (
request_id, key_id, principal_type, principal_ref,
operation, requested_model, resolved_service_id,
resolved_host_id, upstream_model, provider_kind,
started_at, finished_at, duration_ms, status_code, success,
error_type, prompt_tokens, completion_tokens, total_tokens,
estimated_cost_cents, input_bytes, output_bytes,
metadata_json
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
request_id,
key_id,
principal_type,
principal_ref,
operation,
requested_model,
resolved_service_id,
resolved_host_id,
upstream_model,
provider_kind,
started_at,
finished_at,
duration_ms,
status_code,
1 if success else 0,
error_type,
prompt_tokens,
completion_tokens,
total_tokens,
estimated_cost_cents,
input_bytes,
output_bytes,
_json_dumps(metadata or {}),
),
)
row = self.get_request_audit(request_id)
if row is None:
raise RuntimeError(f"created audit row {request_id!r} could not be loaded")
return row
def get_request_audit(self, request_id: str) -> dict | None:
with self._connect() as conn:
row = conn.execute(
"SELECT * FROM request_audit_log WHERE request_id = ?",
(request_id,),
).fetchone()
return self._request_audit_row_to_dict(row) if row is not None else None
def list_request_audit(
self,
*,
key_id: str | None = None,
principal_ref: str | None = None,
operation: str | None = None,
model: str | None = None,
success: bool | None = None,
limit: int = 100,
) -> list[dict]:
query = "SELECT * FROM request_audit_log"
clauses = []
params: list[object] = []
if key_id:
clauses.append("key_id = ?")
params.append(key_id)
if principal_ref:
clauses.append("principal_ref = ?")
params.append(principal_ref)
if operation:
clauses.append("operation = ?")
params.append(operation)
if model:
clauses.append("requested_model = ?")
params.append(model)
if success is not None:
clauses.append("success = ?")
params.append(1 if success else 0)
if clauses:
query += " WHERE " + " AND ".join(clauses)
query += " ORDER BY started_at DESC LIMIT ?"
params.append(max(1, min(limit, 1000)))
with self._connect() as conn:
rows = conn.execute(query, params).fetchall()
return [self._request_audit_row_to_dict(row) for row in rows]
def request_audit_summary(self) -> list[dict]:
with self._connect() as conn:
rows = conn.execute(
"""
SELECT
key_id,
principal_ref,
operation,
requested_model,
COUNT(*) AS request_count,
SUM(success) AS success_count,
SUM(CASE WHEN success = 0 THEN 1 ELSE 0 END) AS failure_count,
SUM(COALESCE(prompt_tokens, 0)) AS prompt_tokens,
SUM(COALESCE(completion_tokens, 0)) AS completion_tokens,
SUM(COALESCE(total_tokens, 0)) AS total_tokens,
SUM(COALESCE(estimated_cost_cents, 0)) AS estimated_cost_cents
FROM request_audit_log
GROUP BY key_id, principal_ref, operation, requested_model
ORDER BY request_count DESC, requested_model
"""
).fetchall()
return [dict(row) for row in rows]
def list_client_models(self) -> list[dict]:
services = self.list_services()
roles = self.list_roles()
@ -927,6 +1092,34 @@ class Registry:
"notes": row["notes"],
}
@staticmethod
def _request_audit_row_to_dict(row: sqlite3.Row) -> dict:
return {
"request_id": row["request_id"],
"key_id": row["key_id"],
"principal_type": row["principal_type"],
"principal_ref": row["principal_ref"],
"operation": row["operation"],
"requested_model": row["requested_model"],
"resolved_service_id": row["resolved_service_id"],
"resolved_host_id": row["resolved_host_id"],
"upstream_model": row["upstream_model"],
"provider_kind": row["provider_kind"],
"started_at": row["started_at"],
"finished_at": row["finished_at"],
"duration_ms": row["duration_ms"],
"status_code": row["status_code"],
"success": bool(row["success"]),
"error_type": row["error_type"],
"prompt_tokens": row["prompt_tokens"],
"completion_tokens": row["completion_tokens"],
"total_tokens": row["total_tokens"],
"estimated_cost_cents": row["estimated_cost_cents"],
"input_bytes": row["input_bytes"],
"output_bytes": row["output_bytes"],
"metadata": json.loads(row["metadata_json"]),
}
def _tokenize_text(value: str) -> set[str]:
return {token for token in re.split(r"[^a-z0-9]+", value.lower()) if token}

154
tests/test_control_audit.py Normal file
View File

@ -0,0 +1,154 @@
import json
from pathlib import Path
from fastapi.testclient import TestClient
from geniehive_control.main import create_app
from geniehive_control.models import HostRegistration, RegisteredService
from geniehive_control.upstream import UpstreamClient
class _FakeResponse:
def __init__(self, payload: dict, status_code: int = 200) -> None:
self._payload = payload
self.status_code = status_code
self.text = str(payload)
def json(self) -> dict:
return self._payload
class _UsagePoster:
async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse:
return _FakeResponse(
{
"object": "chat.completion",
"model": json["model"],
"choices": [{"index": 0, "message": {"role": "assistant", "content": "done"}}],
"usage": {
"prompt_tokens": 7,
"completion_tokens": 3,
"total_tokens": 10,
},
}
)
def _write_audit_config(tmp_path: Path) -> Path:
config_path = tmp_path / "control.yaml"
config_path.write_text(
f"""
auth:
client_api_keys:
- audit-key
audit:
enabled: true
admin_api:
enabled: true
storage:
sqlite_path: "{tmp_path / 'geniehive.sqlite3'}"
"""
)
return config_path
def _register_chat_service(app) -> None:
app.state.registry.register_host(
HostRegistration(
host_id="atlas-01",
address="127.0.0.1",
services=[
RegisteredService(
service_id="atlas-01/chat/qwen",
host_id="atlas-01",
kind="chat",
protocol="openai",
endpoint="http://127.0.0.1:18091",
assets=[{"asset_id": "qwen-test", "loaded": True}],
state={"health": "healthy", "accept_requests": True},
observed={"p50_latency_ms": 100},
)
],
)
)
def test_successful_chat_request_is_audited_without_prompt_content(tmp_path: Path) -> None:
app = create_app(_write_audit_config(tmp_path), upstream_client=UpstreamClient(client=_UsagePoster()))
_register_chat_service(app)
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
headers={"X-Api-Key": "audit-key", "X-Request-Id": "req-test-success"},
json={
"model": "qwen-test",
"messages": [{"role": "user", "content": "private prompt text"}],
},
)
assert response.status_code == 200
assert response.headers["x-request-id"] == "req-test-success"
row = app.state.registry.get_request_audit("req-test-success")
assert row is not None
assert row["operation"] == "chat"
assert row["requested_model"] == "qwen-test"
assert row["resolved_service_id"] == "atlas-01/chat/qwen"
assert row["upstream_model"] == "qwen-test"
assert row["provider_kind"] == "openai"
assert row["success"] is True
assert row["status_code"] == 200
assert row["prompt_tokens"] == 7
assert row["completion_tokens"] == 3
assert row["total_tokens"] == 10
assert "private prompt text" not in json.dumps(row)
def test_failed_chat_route_is_audited(tmp_path: Path) -> None:
app = create_app(_write_audit_config(tmp_path), upstream_client=UpstreamClient(client=_UsagePoster()))
client = TestClient(app)
response = client.post(
"/v1/chat/completions",
headers={"X-Api-Key": "audit-key", "X-Request-Id": "req-test-failure"},
json={
"model": "missing-model",
"messages": [{"role": "user", "content": "private failure prompt"}],
},
)
assert response.status_code == 404
assert response.headers["x-request-id"] == "req-test-failure"
row = app.state.registry.get_request_audit("req-test-failure")
assert row is not None
assert row["operation"] == "chat"
assert row["requested_model"] == "missing-model"
assert row["success"] is False
assert row["status_code"] == 404
assert row["error_type"] == "proxy_error"
assert "private failure prompt" not in json.dumps(row)
def test_admin_audit_endpoints_list_and_summarize_requests(tmp_path: Path) -> None:
app = create_app(_write_audit_config(tmp_path), upstream_client=UpstreamClient(client=_UsagePoster()))
_register_chat_service(app)
client = TestClient(app)
client.post(
"/v1/chat/completions",
headers={"X-Api-Key": "audit-key"},
json={"model": "qwen-test", "messages": [{"role": "user", "content": "hello"}]},
)
listed = client.get("/v1/admin/audit/requests", headers={"X-Api-Key": "audit-key"})
assert listed.status_code == 200
assert listed.json()["data"][0]["requested_model"] == "qwen-test"
summary = client.get("/v1/admin/audit/summary", headers={"X-Api-Key": "audit-key"})
assert summary.status_code == 200
summary_row = summary.json()["data"][0]
assert summary_row["requested_model"] == "qwen-test"
assert summary_row["request_count"] == 1
assert summary_row["success_count"] == 1
assert summary_row["total_tokens"] == 10