Add opt-in request audit logging
This commit is contained in:
parent
960f12f92b
commit
9a1a6f49af
|
|
@ -56,10 +56,11 @@ python -m pytest -q tests
|
|||
Expected current result at baseline: all tests pass.
|
||||
|
||||
Current verification result after adding the Foundation roadmap, config profile
|
||||
scaffold, named client key storage, opt-in named auth, and admin key endpoints:
|
||||
scaffold, named client key storage, opt-in named auth, admin key endpoints, and
|
||||
request audit logging:
|
||||
|
||||
```text
|
||||
58 passed
|
||||
61 passed
|
||||
```
|
||||
|
||||
## Known Constraints
|
||||
|
|
|
|||
|
|
@ -185,6 +185,10 @@ Acceptance:
|
|||
Goal: make production requests attributable without storing prompt or completion
|
||||
content.
|
||||
|
||||
Status: implemented for chat, embeddings, and transcription request wrappers.
|
||||
Audit logging is disabled by default; set `audit.enabled` to true to turn it on. Admin audit
|
||||
read endpoints are only mounted when `admin_api.enabled` is true.
|
||||
|
||||
Tasks:
|
||||
|
||||
- Add request ID generation from `X-Request-Id` or UUID.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from pathlib import Path
|
||||
|
|
@ -17,6 +19,7 @@ from .models import BenchmarkIngestRequest, HostHeartbeat, HostRegistration, Rou
|
|||
from .probe import ServiceProber
|
||||
from .roles import load_role_catalog
|
||||
from .registry import Registry
|
||||
from .routing import choose_upstream_model_id
|
||||
from .upstream import UpstreamClient, UpstreamError
|
||||
|
||||
|
||||
|
|
@ -70,6 +73,86 @@ def create_app(
|
|||
if key != "key_hash"
|
||||
}
|
||||
|
||||
def _request_id(request: Request) -> str:
    """Return the ID used to correlate this request.

    A non-empty ``X-Request-Id`` header supplied by the caller is returned
    verbatim; otherwise a fresh ``req_<32 hex chars>`` ID is generated.
    """
    # NOTE(review): the caller-supplied value is neither validated nor
    # de-duplicated here -- confirm that whatever stores it tolerates repeats.
    supplied = request.headers.get("X-Request-Id")
    if supplied:
        return supplied
    return f"req_{uuid.uuid4().hex}"
|
||||
|
||||
def _client_context(request: Request):
    """Return ``request.state.client_context``, or ``None`` when it is not set."""
    state = request.state
    try:
        return state.client_context
    except AttributeError:
        return None
|
||||
|
||||
def _route_audit_metadata(reg: Registry, requested_model: str | None, *, kind: str) -> dict:
    """Describe how ``requested_model`` resolves, for use in an audit row.

    The result always has the keys ``requested_model``,
    ``resolved_service_id``, ``resolved_host_id``, ``upstream_model`` and
    ``provider_kind``. Every value is ``None`` when no model was requested;
    only ``requested_model`` is filled in when the registry does not resolve
    a service for it.
    """
    metadata = {
        # An empty model name is reported as None, like a missing one.
        "requested_model": requested_model or None,
        "resolved_service_id": None,
        "resolved_host_id": None,
        "upstream_model": None,
        "provider_kind": None,
    }
    if not requested_model:
        return metadata

    route = reg.resolve_route(requested_model, kind=kind)
    service = route.get("service") if route else None
    if service:
        metadata.update(
            resolved_service_id=service.get("service_id"),
            resolved_host_id=service.get("host_id"),
            upstream_model=choose_upstream_model_id(requested_model, service),
            provider_kind=service.get("protocol"),
        )
    return metadata
|
||||
|
||||
def _usage_from_response(response: object) -> dict[str, int | None]:
    """Extract token counts from a response's ``usage`` mapping.

    Returns ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``;
    each is ``None`` when the response is not a dict, has no dict-valued
    ``usage`` entry, or lacks that particular counter.
    """
    usage = response.get("usage", {}) if isinstance(response, dict) else {}
    if not isinstance(usage, dict):
        usage = {}
    counters = ("prompt_tokens", "completion_tokens", "total_tokens")
    return {counter: usage.get(counter) for counter in counters}
|
||||
|
||||
def _audit_request(
    request: Request,
    *,
    request_id: str,
    operation: str,
    route_metadata: dict,
    started_at: float,
    status_code: int,
    success: bool,
    response: object | None = None,
    error_type: str | None = None,
    input_bytes: int | None = None,
    output_bytes: int | None = None,
) -> None:
    """Record one audit row for a request, if auditing is enabled.

    Does nothing unless ``cfg.audit.enabled`` is truthy. The row carries the
    caller identity taken from the request's client context, the routing
    metadata, timing, status, byte counts and the token counts found in
    ``response``; the response body itself is not passed on.
    """
    if not cfg.audit.enabled:
        return

    caller = _client_context(request)
    token_usage = _usage_from_response(response)
    # getattr with a default lets an unauthenticated/anonymous request
    # (no client context) still be recorded with empty identity fields.
    identity = {
        attr: getattr(caller, attr, None)
        for attr in ("key_id", "principal_type", "principal_ref")
    }
    routing = {
        field: route_metadata.get(field)
        for field in (
            "requested_model",
            "resolved_service_id",
            "resolved_host_id",
            "upstream_model",
            "provider_kind",
        )
    }
    # NOTE(review): an exception from the registry write propagates to the
    # calling handler -- confirm that failing the request is intended.
    registry = request.app.state.registry
    registry.record_request_audit(
        request_id=request_id,
        operation=operation,
        started_at=started_at,
        finished_at=time.time(),
        status_code=status_code,
        success=success,
        error_type=error_type,
        input_bytes=input_bytes,
        output_bytes=output_bytes,
        **identity,
        **routing,
        **token_usage,
    )
|
||||
|
||||
if cfg.admin_api.enabled:
|
||||
@app.post("/v1/admin/client-keys")
|
||||
async def create_client_key(request: Request, _=Depends(require_admin_auth)) -> dict:
|
||||
|
|
@ -126,6 +209,41 @@ def create_app(
|
|||
return JSONResponse(status_code=404, content={"error": "unknown_client_key", "key_id": key_id})
|
||||
return {"status": "ok", "client_key": _public_client_key(updated)}
|
||||
|
||||
@app.get("/v1/admin/audit/requests")
async def list_audit_requests(
    request: Request,
    key_id: str | None = None,
    principal_ref: str | None = None,
    operation: str | None = None,
    model: str | None = None,
    success: bool | None = None,
    limit: int = 100,
    _=Depends(require_admin_auth),
) -> dict:
    """List recorded audit rows, optionally filtered by the query parameters.

    Responds with HTTP 400 when audit logging is not enabled; otherwise
    returns ``{"object": "list", "data": [...]}`` with the rows the registry
    selects for the given filters and limit.
    """
    if cfg.audit.enabled:
        filters = {
            "key_id": key_id,
            "principal_ref": principal_ref,
            "operation": operation,
            "model": model,
            "success": success,
            "limit": limit,
        }
        registry = request.app.state.registry
        return {"object": "list", "data": registry.list_request_audit(**filters)}

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="audit logging is not enabled",
    )
|
||||
|
||||
@app.get("/v1/admin/audit/summary")
async def audit_summary(request: Request, _=Depends(require_admin_auth)) -> dict:
    """Return the registry's aggregated audit summary.

    Responds with HTTP 400 when audit logging is not enabled; otherwise
    returns ``{"object": "list", "data": [...]}``.
    """
    if cfg.audit.enabled:
        registry = request.app.state.registry
        return {"object": "list", "data": registry.request_audit_summary()}

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="audit logging is not enabled",
    )
|
||||
|
||||
@app.post("/v1/nodes/register")
|
||||
async def register_node(request: Request, _=Depends(require_node_auth)) -> dict:
|
||||
payload = await request.json()
|
||||
|
|
@ -155,45 +273,142 @@ def create_app(
|
|||
body = await request.json()
|
||||
reg: Registry = request.app.state.registry
|
||||
up: UpstreamClient = request.app.state.upstream
|
||||
request_id = _request_id(request)
|
||||
started_at = time.time()
|
||||
route_metadata = _route_audit_metadata(reg, body.get("model"), kind="chat")
|
||||
input_bytes = len(json.dumps(body, separators=(",", ":")).encode("utf-8"))
|
||||
try:
|
||||
if body.get("stream"):
|
||||
# Resolve route eagerly so ProxyError is raised before streaming starts.
|
||||
service, upstream_body = _prepare_chat_upstream(body, registry=reg)
|
||||
_audit_request(
|
||||
request,
|
||||
request_id=request_id,
|
||||
operation="chat",
|
||||
route_metadata=route_metadata,
|
||||
started_at=started_at,
|
||||
status_code=200,
|
||||
success=True,
|
||||
input_bytes=input_bytes,
|
||||
)
|
||||
return StreamingResponse(
|
||||
stream_chat_completion(service, upstream_body, upstream=up),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "X-Request-Id": request_id},
|
||||
)
|
||||
return await proxy_chat_completion(body, registry=reg, upstream=up)
|
||||
response = await proxy_chat_completion(body, registry=reg, upstream=up)
|
||||
output_bytes = len(json.dumps(response, separators=(",", ":")).encode("utf-8")) if isinstance(response, dict) else None
|
||||
_audit_request(
|
||||
request,
|
||||
request_id=request_id,
|
||||
operation="chat",
|
||||
route_metadata=route_metadata,
|
||||
started_at=started_at,
|
||||
status_code=200,
|
||||
success=True,
|
||||
response=response,
|
||||
input_bytes=input_bytes,
|
||||
output_bytes=output_bytes,
|
||||
)
|
||||
return JSONResponse(content=response, headers={"X-Request-Id": request_id})
|
||||
except ProxyError as exc:
|
||||
_audit_request(
|
||||
request,
|
||||
request_id=request_id,
|
||||
operation="chat",
|
||||
route_metadata=route_metadata,
|
||||
started_at=started_at,
|
||||
status_code=exc.status_code,
|
||||
success=False,
|
||||
error_type="proxy_error",
|
||||
input_bytes=input_bytes,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"error": {"message": str(exc), "type": "geniehive_error", "code": "chat_proxy_error"}},
|
||||
headers={"X-Request-Id": request_id},
|
||||
)
|
||||
except UpstreamError as exc:
|
||||
status_code = exc.status_code or 502
|
||||
_audit_request(
|
||||
request,
|
||||
request_id=request_id,
|
||||
operation="chat",
|
||||
route_metadata=route_metadata,
|
||||
started_at=started_at,
|
||||
status_code=status_code,
|
||||
success=False,
|
||||
error_type="upstream_error",
|
||||
input_bytes=input_bytes,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code or 502,
|
||||
status_code=status_code,
|
||||
content={"error": {"message": str(exc), "type": "geniehive_error", "code": "upstream_error"}},
|
||||
headers={"X-Request-Id": request_id},
|
||||
)
|
||||
|
||||
@app.post("/v1/embeddings")
async def embeddings(request: Request, _=Depends(require_client_auth)):
    """Proxy an embeddings request and audit its outcome.

    Every response, success or error, carries the request ID in an
    ``X-Request-Id`` header. ``ProxyError`` is reported with its own status
    code; ``UpstreamError`` with its status code, or 502 when it has none.
    """
    body = await request.json()
    reg: Registry = request.app.state.registry
    request_id = _request_id(request)
    started_at = time.time()
    route_metadata = _route_audit_metadata(reg, body.get("model"), kind="embeddings")
    # Size of the compact JSON re-encoding of the body; only the byte count
    # is handed to the audit call, never the body itself.
    input_bytes = len(json.dumps(body, separators=(",", ":")).encode("utf-8"))
    try:
        response = await proxy_embeddings(
            body,
            registry=reg,
            upstream=request.app.state.upstream,
        )
        output_bytes = (
            len(json.dumps(response, separators=(",", ":")).encode("utf-8"))
            if isinstance(response, dict)
            else None
        )
        _audit_request(
            request,
            request_id=request_id,
            operation="embeddings",
            route_metadata=route_metadata,
            started_at=started_at,
            status_code=200,
            success=True,
            response=response,
            input_bytes=input_bytes,
            output_bytes=output_bytes,
        )
        return JSONResponse(content=response, headers={"X-Request-Id": request_id})
    except ProxyError as exc:
        _audit_request(
            request,
            request_id=request_id,
            operation="embeddings",
            route_metadata=route_metadata,
            started_at=started_at,
            status_code=exc.status_code,
            success=False,
            error_type="proxy_error",
            input_bytes=input_bytes,
        )
        return JSONResponse(
            status_code=exc.status_code,
            content={"error": {"message": str(exc), "type": "geniehive_error", "code": "embeddings_proxy_error"}},
            headers={"X-Request-Id": request_id},
        )
    except UpstreamError as exc:
        # Audit row and HTTP response must report the same status code.
        status_code = exc.status_code or 502
        _audit_request(
            request,
            request_id=request_id,
            operation="embeddings",
            route_metadata=route_metadata,
            started_at=started_at,
            status_code=status_code,
            success=False,
            error_type="upstream_error",
            input_bytes=input_bytes,
        )
        return JSONResponse(
            status_code=status_code,
            content={"error": {"message": str(exc), "type": "geniehive_error", "code": "upstream_error"}},
            headers={"X-Request-Id": request_id},
        )
|
||||
|
||||
@app.post("/v1/audio/transcriptions")
|
||||
|
|
@ -207,8 +422,11 @@ def create_app(
|
|||
temperature: float | None = Form(None),
|
||||
_=Depends(require_client_auth),
|
||||
):
|
||||
request_id = _request_id(request)
|
||||
started_at = time.time()
|
||||
route_metadata = _route_audit_metadata(request.app.state.registry, model, kind="transcription")
|
||||
try:
|
||||
return await proxy_transcription(
|
||||
response = await proxy_transcription(
|
||||
model=model,
|
||||
file=file,
|
||||
language=language,
|
||||
|
|
@ -218,15 +436,51 @@ def create_app(
|
|||
registry=request.app.state.registry,
|
||||
upstream=request.app.state.upstream,
|
||||
)
|
||||
output_bytes = len(json.dumps(response, separators=(",", ":")).encode("utf-8")) if isinstance(response, dict) else None
|
||||
_audit_request(
|
||||
request,
|
||||
request_id=request_id,
|
||||
operation="transcription",
|
||||
route_metadata=route_metadata,
|
||||
started_at=started_at,
|
||||
status_code=200,
|
||||
success=True,
|
||||
response=response,
|
||||
output_bytes=output_bytes,
|
||||
)
|
||||
return JSONResponse(content=response, headers={"X-Request-Id": request_id})
|
||||
except ProxyError as exc:
|
||||
_audit_request(
|
||||
request,
|
||||
request_id=request_id,
|
||||
operation="transcription",
|
||||
route_metadata=route_metadata,
|
||||
started_at=started_at,
|
||||
status_code=exc.status_code,
|
||||
success=False,
|
||||
error_type="proxy_error",
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"error": {"message": str(exc), "type": "geniehive_error", "code": "transcription_proxy_error"}},
|
||||
headers={"X-Request-Id": request_id},
|
||||
)
|
||||
except UpstreamError as exc:
|
||||
status_code = exc.status_code or 502
|
||||
_audit_request(
|
||||
request,
|
||||
request_id=request_id,
|
||||
operation="transcription",
|
||||
route_metadata=route_metadata,
|
||||
started_at=started_at,
|
||||
status_code=status_code,
|
||||
success=False,
|
||||
error_type="upstream_error",
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code or 502,
|
||||
status_code=status_code,
|
||||
content={"error": {"message": str(exc), "type": "geniehive_error", "code": "upstream_error"}},
|
||||
headers={"X-Request-Id": request_id},
|
||||
)
|
||||
|
||||
@app.get("/v1/cluster/services")
|
||||
|
|
|
|||
|
|
@ -95,6 +95,32 @@ class Registry:
|
|||
last_used_at REAL,
|
||||
notes TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS request_audit_log (
|
||||
request_id TEXT PRIMARY KEY,
|
||||
key_id TEXT,
|
||||
principal_type TEXT,
|
||||
principal_ref TEXT,
|
||||
operation TEXT NOT NULL,
|
||||
requested_model TEXT,
|
||||
resolved_service_id TEXT,
|
||||
resolved_host_id TEXT,
|
||||
upstream_model TEXT,
|
||||
provider_kind TEXT,
|
||||
started_at REAL NOT NULL,
|
||||
finished_at REAL NOT NULL,
|
||||
duration_ms REAL NOT NULL,
|
||||
status_code INTEGER NOT NULL,
|
||||
success INTEGER NOT NULL,
|
||||
error_type TEXT,
|
||||
prompt_tokens INTEGER,
|
||||
completion_tokens INTEGER,
|
||||
total_tokens INTEGER,
|
||||
estimated_cost_cents REAL,
|
||||
input_bytes INTEGER,
|
||||
output_bytes INTEGER,
|
||||
metadata_json TEXT NOT NULL DEFAULT '{}'
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
|
|
@ -390,6 +416,145 @@ class Registry:
|
|||
(now, now, key_id),
|
||||
)
|
||||
|
||||
def record_request_audit(
    self,
    *,
    request_id: str,
    key_id: str | None,
    principal_type: str | None,
    principal_ref: str | None,
    operation: str,
    requested_model: str | None,
    resolved_service_id: str | None,
    resolved_host_id: str | None,
    upstream_model: str | None,
    provider_kind: str | None,
    started_at: float,
    finished_at: float,
    status_code: int,
    success: bool,
    error_type: str | None = None,
    prompt_tokens: int | None = None,
    completion_tokens: int | None = None,
    total_tokens: int | None = None,
    estimated_cost_cents: float | None = None,
    input_bytes: int | None = None,
    output_bytes: int | None = None,
    metadata: dict | None = None,
) -> dict:
    """Insert one row into ``request_audit_log`` and return it as a dict.

    ``duration_ms`` is derived from the two timestamps and clamped at zero;
    ``success`` is stored as 1/0 and ``metadata`` as JSON (``{}`` when
    omitted).

    Raises:
        RuntimeError: if the freshly inserted row cannot be read back.
    """
    # NOTE(review): this is a plain INSERT keyed on request_id -- presumably
    # a repeated ID makes it fail; confirm callers expect that.
    elapsed_ms = max(0.0, (finished_at - started_at) * 1000.0)
    # Order must match the column list in the INSERT statement below.
    values = (
        request_id,
        key_id,
        principal_type,
        principal_ref,
        operation,
        requested_model,
        resolved_service_id,
        resolved_host_id,
        upstream_model,
        provider_kind,
        started_at,
        finished_at,
        elapsed_ms,
        status_code,
        1 if success else 0,
        error_type,
        prompt_tokens,
        completion_tokens,
        total_tokens,
        estimated_cost_cents,
        input_bytes,
        output_bytes,
        _json_dumps(metadata or {}),
    )
    with self._connect() as conn:
        conn.execute(
            """
            INSERT INTO request_audit_log (
                request_id, key_id, principal_type, principal_ref,
                operation, requested_model, resolved_service_id,
                resolved_host_id, upstream_model, provider_kind,
                started_at, finished_at, duration_ms, status_code, success,
                error_type, prompt_tokens, completion_tokens, total_tokens,
                estimated_cost_cents, input_bytes, output_bytes,
                metadata_json
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            values,
        )
    stored = self.get_request_audit(request_id)
    if stored is not None:
        return stored
    raise RuntimeError(f"created audit row {request_id!r} could not be loaded")
|
||||
|
||||
def get_request_audit(self, request_id: str) -> dict | None:
    """Return the audit row stored under ``request_id`` as a dict, or ``None``."""
    with self._connect() as conn:
        cursor = conn.execute(
            "SELECT * FROM request_audit_log WHERE request_id = ?",
            (request_id,),
        )
        found = cursor.fetchone()
    if found is None:
        return None
    return self._request_audit_row_to_dict(found)
|
||||
|
||||
def list_request_audit(
    self,
    *,
    key_id: str | None = None,
    principal_ref: str | None = None,
    operation: str | None = None,
    model: str | None = None,
    success: bool | None = None,
    limit: int = 100,
) -> list[dict]:
    """Return audit rows, newest ``started_at`` first.

    Each truthy string filter adds an equality condition (``model`` matches
    the ``requested_model`` column); ``success`` filters only when it is not
    ``None``. ``limit`` is clamped to the range 1..1000.
    """
    equality_filters = (
        ("key_id", key_id),
        ("principal_ref", principal_ref),
        ("operation", operation),
        ("requested_model", model),
    )
    conditions: list[str] = []
    params: list[object] = []
    for column, wanted in equality_filters:
        if wanted:
            # Column names come from the fixed tuple above; values are bound.
            conditions.append(f"{column} = ?")
            params.append(wanted)
    if success is not None:
        conditions.append("success = ?")
        params.append(1 if success else 0)

    sql = "SELECT * FROM request_audit_log"
    if conditions:
        sql += " WHERE " + " AND ".join(conditions)
    sql += " ORDER BY started_at DESC LIMIT ?"
    params.append(max(1, min(limit, 1000)))

    with self._connect() as conn:
        fetched = conn.execute(sql, params).fetchall()
    return [self._request_audit_row_to_dict(item) for item in fetched]
|
||||
|
||||
def request_audit_summary(self) -> list[dict]:
    """Aggregate audit rows per (key_id, principal_ref, operation, requested_model).

    Each result dict holds the four grouping columns plus request, success
    and failure counts and summed token/cost columns (missing values count
    as 0). Groups are ordered by request count, largest first, then by
    requested model.
    """
    summary_sql = """
        SELECT
            key_id,
            principal_ref,
            operation,
            requested_model,
            COUNT(*) AS request_count,
            SUM(success) AS success_count,
            SUM(CASE WHEN success = 0 THEN 1 ELSE 0 END) AS failure_count,
            SUM(COALESCE(prompt_tokens, 0)) AS prompt_tokens,
            SUM(COALESCE(completion_tokens, 0)) AS completion_tokens,
            SUM(COALESCE(total_tokens, 0)) AS total_tokens,
            SUM(COALESCE(estimated_cost_cents, 0)) AS estimated_cost_cents
        FROM request_audit_log
        GROUP BY key_id, principal_ref, operation, requested_model
        ORDER BY request_count DESC, requested_model
        """
    with self._connect() as conn:
        groups = conn.execute(summary_sql).fetchall()
    return list(map(dict, groups))
|
||||
|
||||
def list_client_models(self) -> list[dict]:
|
||||
services = self.list_services()
|
||||
roles = self.list_roles()
|
||||
|
|
@ -927,6 +1092,34 @@ class Registry:
|
|||
"notes": row["notes"],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _request_audit_row_to_dict(row: sqlite3.Row) -> dict:
|
||||
return {
|
||||
"request_id": row["request_id"],
|
||||
"key_id": row["key_id"],
|
||||
"principal_type": row["principal_type"],
|
||||
"principal_ref": row["principal_ref"],
|
||||
"operation": row["operation"],
|
||||
"requested_model": row["requested_model"],
|
||||
"resolved_service_id": row["resolved_service_id"],
|
||||
"resolved_host_id": row["resolved_host_id"],
|
||||
"upstream_model": row["upstream_model"],
|
||||
"provider_kind": row["provider_kind"],
|
||||
"started_at": row["started_at"],
|
||||
"finished_at": row["finished_at"],
|
||||
"duration_ms": row["duration_ms"],
|
||||
"status_code": row["status_code"],
|
||||
"success": bool(row["success"]),
|
||||
"error_type": row["error_type"],
|
||||
"prompt_tokens": row["prompt_tokens"],
|
||||
"completion_tokens": row["completion_tokens"],
|
||||
"total_tokens": row["total_tokens"],
|
||||
"estimated_cost_cents": row["estimated_cost_cents"],
|
||||
"input_bytes": row["input_bytes"],
|
||||
"output_bytes": row["output_bytes"],
|
||||
"metadata": json.loads(row["metadata_json"]),
|
||||
}
|
||||
|
||||
|
||||
def _tokenize_text(value: str) -> set[str]:
    """Return the distinct lowercase alphanumeric tokens found in ``value``."""
    pieces = re.split(r"[^a-z0-9]+", value.lower())
    # Splitting leaves empty strings at the edges; drop them.
    return set(filter(None, pieces))
|
||||
|
|
|
|||
|
|
@ -0,0 +1,154 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from geniehive_control.main import create_app
|
||||
from geniehive_control.models import HostRegistration, RegisteredService
|
||||
from geniehive_control.upstream import UpstreamClient
|
||||
|
||||
|
||||
class _FakeResponse:
    """Minimal response stand-in exposing ``status_code``, ``text`` and ``json()``."""

    def __init__(self, payload: dict, status_code: int = 200) -> None:
        """Keep ``payload`` for ``json()``; ``text`` is its ``str()`` form."""
        self.status_code = status_code
        self.text = str(payload)
        self._payload = payload

    def json(self) -> dict:
        """Return the payload passed to the constructor."""
        return self._payload
|
||||
|
||||
|
||||
class _UsagePoster:
    """Fake HTTP client whose ``post`` always returns a completion with usage."""

    async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse:
        """Return a canned chat completion that echoes ``json["model"]``.

        ``url`` and ``headers`` are accepted but ignored; the reported usage
        is fixed at 7 prompt, 3 completion and 10 total tokens.
        """
        model_name = json["model"]
        usage = {
            "prompt_tokens": 7,
            "completion_tokens": 3,
            "total_tokens": 10,
        }
        choice = {"index": 0, "message": {"role": "assistant", "content": "done"}}
        payload = {
            "object": "chat.completion",
            "model": model_name,
            "choices": [choice],
            "usage": usage,
        }
        return _FakeResponse(payload)
|
||||
|
||||
|
||||
def _write_audit_config(tmp_path: Path) -> Path:
    """Write a control config with audit and admin API enabled; return its path.

    The file is ``tmp_path / "control.yaml"``. It declares the client API key
    ``audit-key`` and points SQLite storage at a file inside ``tmp_path``.
    """
    # Forward slashes keep the path valid inside a double-quoted YAML scalar,
    # where backslashes would be read as escape sequences.
    sqlite_path = (tmp_path / "geniehive.sqlite3").as_posix()
    config_path = tmp_path / "control.yaml"
    config_path.write_text(
        f"""
auth:
  client_api_keys:
    - audit-key
audit:
  enabled: true
admin_api:
  enabled: true
storage:
  sqlite_path: "{sqlite_path}"
""",
        encoding="utf-8",
    )
    return config_path
|
||||
|
||||
|
||||
def _register_chat_service(app) -> None:
    """Register host ``atlas-01`` with one chat service on the app's registry.

    The service advertises the asset ``qwen-test`` over the ``openai``
    protocol and is marked healthy and accepting requests.
    """
    app.state.registry.register_host(
        HostRegistration(
            host_id="atlas-01",
            address="127.0.0.1",
            services=[
                RegisteredService(
                    service_id="atlas-01/chat/qwen",
                    host_id="atlas-01",
                    kind="chat",
                    protocol="openai",
                    endpoint="http://127.0.0.1:18091",
                    assets=[{"asset_id": "qwen-test", "loaded": True}],
                    state={"health": "healthy", "accept_requests": True},
                    observed={"p50_latency_ms": 100},
                )
            ],
        )
    )
|
||||
|
||||
|
||||
def test_successful_chat_request_is_audited_without_prompt_content(tmp_path: Path) -> None:
    """A successful chat call is audited with routing and usage, not the prompt."""
    app = create_app(_write_audit_config(tmp_path), upstream_client=UpstreamClient(client=_UsagePoster()))
    _register_chat_service(app)
    client = TestClient(app)

    response = client.post(
        "/v1/chat/completions",
        headers={"X-Api-Key": "audit-key", "X-Request-Id": "req-test-success"},
        json={
            "model": "qwen-test",
            "messages": [{"role": "user", "content": "private prompt text"}],
        },
    )

    assert response.status_code == 200
    # The caller-supplied request ID is echoed back on the response.
    assert response.headers["x-request-id"] == "req-test-success"

    # The same ID is the lookup key for the stored audit row.
    row = app.state.registry.get_request_audit("req-test-success")
    assert row is not None
    assert row["operation"] == "chat"
    assert row["requested_model"] == "qwen-test"
    assert row["resolved_service_id"] == "atlas-01/chat/qwen"
    assert row["upstream_model"] == "qwen-test"
    assert row["provider_kind"] == "openai"
    assert row["success"] is True
    assert row["status_code"] == 200
    # Token counts come from the usage block the fake upstream returns.
    assert row["prompt_tokens"] == 7
    assert row["completion_tokens"] == 3
    assert row["total_tokens"] == 10
    # No field of the audit row may contain the prompt text.
    assert "private prompt text" not in json.dumps(row)
|
||||
|
||||
|
||||
def test_failed_chat_route_is_audited(tmp_path: Path) -> None:
    """A chat call for an unregistered model yields 404 and a failure audit row."""
    # No service is registered, so "missing-model" cannot be routed.
    app = create_app(_write_audit_config(tmp_path), upstream_client=UpstreamClient(client=_UsagePoster()))
    client = TestClient(app)

    response = client.post(
        "/v1/chat/completions",
        headers={"X-Api-Key": "audit-key", "X-Request-Id": "req-test-failure"},
        json={
            "model": "missing-model",
            "messages": [{"role": "user", "content": "private failure prompt"}],
        },
    )

    assert response.status_code == 404
    # Error responses carry the request ID too.
    assert response.headers["x-request-id"] == "req-test-failure"

    row = app.state.registry.get_request_audit("req-test-failure")
    assert row is not None
    assert row["operation"] == "chat"
    assert row["requested_model"] == "missing-model"
    assert row["success"] is False
    assert row["status_code"] == 404
    assert row["error_type"] == "proxy_error"
    # Failure rows must not contain the prompt text either.
    assert "private failure prompt" not in json.dumps(row)
|
||||
|
||||
|
||||
def test_admin_audit_endpoints_list_and_summarize_requests(tmp_path: Path) -> None:
    """The admin audit endpoints expose a recorded chat request and its summary."""
    app = create_app(_write_audit_config(tmp_path), upstream_client=UpstreamClient(client=_UsagePoster()))
    _register_chat_service(app)
    client = TestClient(app)
    # Produce exactly one audited request; no X-Request-Id is sent, so the
    # server generates the ID.
    client.post(
        "/v1/chat/completions",
        headers={"X-Api-Key": "audit-key"},
        json={"model": "qwen-test", "messages": [{"role": "user", "content": "hello"}]},
    )

    # NOTE(review): the admin endpoints are called with the client API key --
    # presumably admin auth accepts it under this config; confirm.
    listed = client.get("/v1/admin/audit/requests", headers={"X-Api-Key": "audit-key"})
    assert listed.status_code == 200
    assert listed.json()["data"][0]["requested_model"] == "qwen-test"

    summary = client.get("/v1/admin/audit/summary", headers={"X-Api-Key": "audit-key"})
    assert summary.status_code == 200
    summary_row = summary.json()["data"][0]
    assert summary_row["requested_model"] == "qwen-test"
    assert summary_row["request_count"] == 1
    assert summary_row["success_count"] == 1
    # Matches the total_tokens value reported by the fake upstream.
    assert summary_row["total_tokens"] == 10
|
||||
Loading…
Reference in New Issue