diff --git a/docs/foundation_gateway_baseline.md b/docs/foundation_gateway_baseline.md index 6873b74..6e41628 100644 --- a/docs/foundation_gateway_baseline.md +++ b/docs/foundation_gateway_baseline.md @@ -56,10 +56,11 @@ python -m pytest -q tests Expected current result at baseline: all tests pass. Current verification result after adding the Foundation roadmap, config profile -scaffold, named client key storage, opt-in named auth, and admin key endpoints: +scaffold, named client key storage, opt-in named auth, admin key endpoints, and +request audit logging: ```text -58 passed +61 passed ``` ## Known Constraints diff --git a/docs/foundation_gateway_roadmap.md b/docs/foundation_gateway_roadmap.md index e2af70e..61107ce 100644 --- a/docs/foundation_gateway_roadmap.md +++ b/docs/foundation_gateway_roadmap.md @@ -185,6 +185,10 @@ Acceptance: Goal: make production requests attributable without storing prompt or completion content. +Status: implemented for chat, embeddings, and transcription request wrappers. +Audit logging is disabled by default and enabled by `audit.enabled`. Admin audit +read endpoints are only mounted when `admin_api.enabled` is true. + Tasks: - Add request ID generation from `X-Request-Id` or UUID. diff --git a/src/geniehive_control/main.py b/src/geniehive_control/main.py index 7164f68..9501940 100644 --- a/src/geniehive_control/main.py +++ b/src/geniehive_control/main.py @@ -1,7 +1,9 @@ from __future__ import annotations import asyncio +import json import os +import time import uuid from contextlib import asynccontextmanager, suppress from pathlib import Path @@ -17,6 +19,7 @@ from .models import BenchmarkIngestRequest, HostHeartbeat, HostRegistration, Rou from .probe import ServiceProber from .roles import load_role_catalog from .registry import Registry +from .routing import choose_upstream_model_id from .upstream import UpstreamClient, UpstreamError @@ -70,6 +73,86 @@ def create_app( if key != "key_hash" } + def _request_id(request: Request) -> str: + return request.headers.get("X-Request-Id") or f"req_{uuid.uuid4().hex}" + + def _client_context(request: Request): + return getattr(request.state, "client_context", None) + + def _route_audit_metadata(reg: Registry, requested_model: str | None, *, kind: str) -> dict: + if not requested_model: + return { + "requested_model": None, + "resolved_service_id": None, + "resolved_host_id": None, + "upstream_model": None, + "provider_kind": None, + } + resolved = reg.resolve_route(requested_model, kind=kind) + service = resolved.get("service") if resolved else None + if not service: + return { + "requested_model": requested_model, + "resolved_service_id": None, + "resolved_host_id": None, + "upstream_model": None, + "provider_kind": None, + } + return { + "requested_model": requested_model, + "resolved_service_id": service.get("service_id"), + "resolved_host_id": service.get("host_id"), + "upstream_model": choose_upstream_model_id(requested_model, service), + "provider_kind": service.get("protocol"), + } + + def _usage_from_response(response: object) -> dict[str, int | None]: + usage = response.get("usage", {}) if isinstance(response, dict) else {} + return { + "prompt_tokens": usage.get("prompt_tokens") if isinstance(usage, dict) else None, + "completion_tokens": usage.get("completion_tokens") if isinstance(usage, dict) else None, + "total_tokens": usage.get("total_tokens") if isinstance(usage, dict) else None, + } + + def _audit_request( + request: Request, + *, + request_id: str, + operation: str, + route_metadata: dict, + started_at: float, + status_code: int, + success: bool, + response: object | None = None, + error_type: str | None = None, + input_bytes: int | None = None, + output_bytes: int | None = None, + ) -> None: + if not cfg.audit.enabled: + return + context = _client_context(request) + usage = _usage_from_response(response) + request.app.state.registry.record_request_audit( + request_id=request_id, + key_id=getattr(context, "key_id", None), + principal_type=getattr(context, "principal_type", None), + principal_ref=getattr(context, "principal_ref", None), + operation=operation, + requested_model=route_metadata.get("requested_model"), + resolved_service_id=route_metadata.get("resolved_service_id"), + resolved_host_id=route_metadata.get("resolved_host_id"), + upstream_model=route_metadata.get("upstream_model"), + provider_kind=route_metadata.get("provider_kind"), + started_at=started_at, + finished_at=time.time(), + status_code=status_code, + success=success, + error_type=error_type, + input_bytes=input_bytes, + output_bytes=output_bytes, + **usage, + ) + if cfg.admin_api.enabled: @app.post("/v1/admin/client-keys") async def create_client_key(request: Request, _=Depends(require_admin_auth)) -> dict: @@ -126,6 +209,41 @@ def create_app( return JSONResponse(status_code=404, content={"error": "unknown_client_key", "key_id": key_id}) return {"status": "ok", "client_key": _public_client_key(updated)} + @app.get("/v1/admin/audit/requests") + async def list_audit_requests( + request: Request, + key_id: str | None = None, + principal_ref: str | None = None, + operation: str | None = None, + model: str | None = None, + success: bool | None = None, + limit: int = 100, + _=Depends(require_admin_auth), + ) -> dict: + if not cfg.audit.enabled: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="audit logging is not enabled", + ) + rows = request.app.state.registry.list_request_audit( + key_id=key_id, + principal_ref=principal_ref, + operation=operation, + model=model, + success=success, + limit=limit, + ) + return {"object": "list", "data": rows} + + @app.get("/v1/admin/audit/summary") + async def audit_summary(request: Request, _=Depends(require_admin_auth)) -> dict: + if not cfg.audit.enabled: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="audit logging is not enabled", + ) + return {"object": "list", "data": request.app.state.registry.request_audit_summary()} + @app.post("/v1/nodes/register") async def register_node(request: Request, _=Depends(require_node_auth)) -> dict: payload = await request.json() @@ -155,45 +273,142 @@ def create_app( body = await request.json() reg: Registry = request.app.state.registry up: UpstreamClient = request.app.state.upstream + request_id = _request_id(request) + started_at = time.time() + route_metadata = _route_audit_metadata(reg, body.get("model"), kind="chat") + input_bytes = len(json.dumps(body, separators=(",", ":")).encode("utf-8")) try: if body.get("stream"): # Resolve route eagerly so ProxyError is raised before streaming starts. service, upstream_body = _prepare_chat_upstream(body, registry=reg) + _audit_request( + request, + request_id=request_id, + operation="chat", + route_metadata=route_metadata, + started_at=started_at, + status_code=200, + success=True, + input_bytes=input_bytes, + ) return StreamingResponse( stream_chat_completion(service, upstream_body, upstream=up), media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "X-Request-Id": request_id}, ) - return await proxy_chat_completion(body, registry=reg, upstream=up) + response = await proxy_chat_completion(body, registry=reg, upstream=up) + output_bytes = len(json.dumps(response, separators=(",", ":")).encode("utf-8")) if isinstance(response, dict) else None + _audit_request( + request, + request_id=request_id, + operation="chat", + route_metadata=route_metadata, + started_at=started_at, + status_code=200, + success=True, + response=response, + input_bytes=input_bytes, + output_bytes=output_bytes, + ) + return JSONResponse(content=response, headers={"X-Request-Id": request_id}) except ProxyError as exc: + _audit_request( + request, + request_id=request_id, + operation="chat", + route_metadata=route_metadata, + started_at=started_at, + status_code=exc.status_code, + success=False, + error_type="proxy_error", + input_bytes=input_bytes, + ) return JSONResponse( status_code=exc.status_code, content={"error": {"message": str(exc), "type": "geniehive_error", "code": "chat_proxy_error"}}, + headers={"X-Request-Id": request_id}, ) except UpstreamError as exc: + status_code = exc.status_code or 502 + _audit_request( + request, + request_id=request_id, + operation="chat", + route_metadata=route_metadata, + started_at=started_at, + status_code=status_code, + success=False, + error_type="upstream_error", + input_bytes=input_bytes, + ) return JSONResponse( - status_code=exc.status_code or 502, + status_code=status_code, content={"error": {"message": str(exc), "type": "geniehive_error", "code": "upstream_error"}}, + headers={"X-Request-Id": request_id}, ) @app.post("/v1/embeddings") async def embeddings(request: Request, _=Depends(require_client_auth)): body = await request.json() + reg: Registry = request.app.state.registry + request_id = _request_id(request) + started_at = time.time() + route_metadata = _route_audit_metadata(reg, body.get("model"), kind="embeddings") + input_bytes = len(json.dumps(body, separators=(",", ":")).encode("utf-8")) try: - return await proxy_embeddings( + response = await proxy_embeddings( body, - registry=request.app.state.registry, + registry=reg, upstream=request.app.state.upstream, ) + output_bytes = len(json.dumps(response, separators=(",", ":")).encode("utf-8")) if isinstance(response, dict) else None + _audit_request( + request, + request_id=request_id, + operation="embeddings", + route_metadata=route_metadata, + started_at=started_at, + status_code=200, + success=True, + response=response, + input_bytes=input_bytes, + output_bytes=output_bytes, + ) + return JSONResponse(content=response, headers={"X-Request-Id": request_id}) except ProxyError as exc: + _audit_request( + request, + request_id=request_id, + operation="embeddings", + route_metadata=route_metadata, + started_at=started_at, + status_code=exc.status_code, + success=False, + error_type="proxy_error", + input_bytes=input_bytes, + ) return JSONResponse( status_code=exc.status_code, content={"error": {"message": str(exc), "type": "geniehive_error", "code": "embeddings_proxy_error"}}, + headers={"X-Request-Id": request_id}, ) except UpstreamError as exc: + status_code = exc.status_code or 502 + _audit_request( + request, + request_id=request_id, + operation="embeddings", + route_metadata=route_metadata, + started_at=started_at, + status_code=status_code, + success=False, + error_type="upstream_error", + input_bytes=input_bytes, + ) return JSONResponse( - status_code=exc.status_code or 502, + status_code=status_code, content={"error": {"message": str(exc), "type": "geniehive_error", "code": "upstream_error"}}, + headers={"X-Request-Id": request_id}, ) @app.post("/v1/audio/transcriptions") @@ -207,8 +422,11 @@ def create_app( temperature: float | None = Form(None), _=Depends(require_client_auth), ): + request_id = _request_id(request) + started_at = time.time() + route_metadata = _route_audit_metadata(request.app.state.registry, model, kind="transcription") try: - return await proxy_transcription( + response = await proxy_transcription( model=model, file=file, language=language, @@ -218,15 +436,51 @@ def create_app( registry=request.app.state.registry, upstream=request.app.state.upstream, ) + output_bytes = len(json.dumps(response, separators=(",", ":")).encode("utf-8")) if isinstance(response, dict) else None + _audit_request( + request, + request_id=request_id, + operation="transcription", + route_metadata=route_metadata, + started_at=started_at, + status_code=200, + success=True, + response=response, + output_bytes=output_bytes, + ) + return JSONResponse(content=response, headers={"X-Request-Id": request_id}) except ProxyError as exc: + _audit_request( + request, + request_id=request_id, + operation="transcription", + route_metadata=route_metadata, + started_at=started_at, + status_code=exc.status_code, + success=False, + error_type="proxy_error", + ) return JSONResponse( status_code=exc.status_code, content={"error": {"message": str(exc), "type": "geniehive_error", "code": "transcription_proxy_error"}}, + headers={"X-Request-Id": request_id}, ) except UpstreamError as exc: + status_code = exc.status_code or 502 + _audit_request( + request, + request_id=request_id, + operation="transcription", + route_metadata=route_metadata, + started_at=started_at, + status_code=status_code, + success=False, + error_type="upstream_error", + ) return JSONResponse( - status_code=exc.status_code or 502, + status_code=status_code, content={"error": {"message": str(exc), "type": "geniehive_error", "code": "upstream_error"}}, + headers={"X-Request-Id": request_id}, ) @app.get("/v1/cluster/services") diff --git a/src/geniehive_control/registry.py b/src/geniehive_control/registry.py index 7051551..ffe7e9d 100644 --- a/src/geniehive_control/registry.py +++ b/src/geniehive_control/registry.py @@ -95,6 +95,32 @@ class Registry: last_used_at REAL, notes TEXT ); + + CREATE TABLE IF NOT EXISTS request_audit_log ( + request_id TEXT PRIMARY KEY, + key_id TEXT, + principal_type TEXT, + principal_ref TEXT, + operation TEXT NOT NULL, + requested_model TEXT, + resolved_service_id TEXT, + resolved_host_id TEXT, + upstream_model TEXT, + provider_kind TEXT, + started_at REAL NOT NULL, + finished_at REAL NOT NULL, + duration_ms REAL NOT NULL, + status_code INTEGER NOT NULL, + success INTEGER NOT NULL, + error_type TEXT, + prompt_tokens INTEGER, + completion_tokens INTEGER, + total_tokens INTEGER, + estimated_cost_cents REAL, + input_bytes INTEGER, + output_bytes INTEGER, + metadata_json TEXT NOT NULL DEFAULT '{}' + ); """ ) @@ -390,6 +416,145 @@ class Registry: (now, now, key_id), ) + def record_request_audit( + self, + *, + request_id: str, + key_id: str | None, + principal_type: str | None, + principal_ref: str | None, + operation: str, + requested_model: str | None, + resolved_service_id: str | None, + resolved_host_id: str | None, + upstream_model: str | None, + provider_kind: str | None, + started_at: float, + finished_at: float, + status_code: int, + success: bool, + error_type: str | None = None, + prompt_tokens: int | None = None, + completion_tokens: int | None = None, + total_tokens: int | None = None, + estimated_cost_cents: float | None = None, + input_bytes: int | None = None, + output_bytes: int | None = None, + metadata: dict | None = None, + ) -> dict: + duration_ms = max(0.0, (finished_at - started_at) * 1000.0) + with self._connect() as conn: + conn.execute( + """ + INSERT INTO request_audit_log ( + request_id, key_id, principal_type, principal_ref, + operation, requested_model, resolved_service_id, + resolved_host_id, upstream_model, provider_kind, + started_at, finished_at, duration_ms, status_code, success, + error_type, prompt_tokens, completion_tokens, total_tokens, + estimated_cost_cents, input_bytes, output_bytes, + metadata_json + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + request_id, + key_id, + principal_type, + principal_ref, + operation, + requested_model, + resolved_service_id, + resolved_host_id, + upstream_model, + provider_kind, + started_at, + finished_at, + duration_ms, + status_code, + 1 if success else 0, + error_type, + prompt_tokens, + completion_tokens, + total_tokens, + estimated_cost_cents, + input_bytes, + output_bytes, + _json_dumps(metadata or {}), + ), + ) + row = self.get_request_audit(request_id) + if row is None: + raise RuntimeError(f"created audit row {request_id!r} could not be loaded") + return row + + def get_request_audit(self, request_id: str) -> dict | None: + with self._connect() as conn: + row = conn.execute( + "SELECT * FROM request_audit_log WHERE request_id = ?", + (request_id,), + ).fetchone() + return self._request_audit_row_to_dict(row) if row is not None else None + + def list_request_audit( + self, + *, + key_id: str | None = None, + principal_ref: str | None = None, + operation: str | None = None, + model: str | None = None, + success: bool | None = None, + limit: int = 100, + ) -> list[dict]: + query = "SELECT * FROM request_audit_log" + clauses = [] + params: list[object] = [] + if key_id: + clauses.append("key_id = ?") + params.append(key_id) + if principal_ref: + clauses.append("principal_ref = ?") + params.append(principal_ref) + if operation: + clauses.append("operation = ?") + params.append(operation) + if model: + clauses.append("requested_model = ?") + params.append(model) + if success is not None: + clauses.append("success = ?") + params.append(1 if success else 0) + if clauses: + query += " WHERE " + " AND ".join(clauses) + query += " ORDER BY started_at DESC LIMIT ?" + params.append(max(1, min(limit, 1000))) + with self._connect() as conn: + rows = conn.execute(query, params).fetchall() + return [self._request_audit_row_to_dict(row) for row in rows] + + def request_audit_summary(self) -> list[dict]: + with self._connect() as conn: + rows = conn.execute( + """ + SELECT + key_id, + principal_ref, + operation, + requested_model, + COUNT(*) AS request_count, + SUM(success) AS success_count, + SUM(CASE WHEN success = 0 THEN 1 ELSE 0 END) AS failure_count, + SUM(COALESCE(prompt_tokens, 0)) AS prompt_tokens, + SUM(COALESCE(completion_tokens, 0)) AS completion_tokens, + SUM(COALESCE(total_tokens, 0)) AS total_tokens, + SUM(COALESCE(estimated_cost_cents, 0)) AS estimated_cost_cents + FROM request_audit_log + GROUP BY key_id, principal_ref, operation, requested_model + ORDER BY request_count DESC, requested_model + """ + ).fetchall() + return [dict(row) for row in rows] + def list_client_models(self) -> list[dict]: services = self.list_services() roles = self.list_roles() @@ -927,6 +1092,34 @@ class Registry: "notes": row["notes"], } + @staticmethod + def _request_audit_row_to_dict(row: sqlite3.Row) -> dict: + return { + "request_id": row["request_id"], + "key_id": row["key_id"], + "principal_type": row["principal_type"], + "principal_ref": row["principal_ref"], + "operation": row["operation"], + "requested_model": row["requested_model"], + "resolved_service_id": row["resolved_service_id"], + "resolved_host_id": row["resolved_host_id"], + "upstream_model": row["upstream_model"], + "provider_kind": row["provider_kind"], + "started_at": row["started_at"], + "finished_at": row["finished_at"], + "duration_ms": row["duration_ms"], + "status_code": row["status_code"], + "success": bool(row["success"]), + "error_type": row["error_type"], + "prompt_tokens": row["prompt_tokens"], + "completion_tokens": row["completion_tokens"], + "total_tokens": row["total_tokens"], + "estimated_cost_cents": row["estimated_cost_cents"], + "input_bytes": row["input_bytes"], + "output_bytes": row["output_bytes"], + "metadata": json.loads(row["metadata_json"]), + } + def _tokenize_text(value: str) -> set[str]: return {token for token in re.split(r"[^a-z0-9]+", value.lower()) if token} diff --git a/tests/test_control_audit.py b/tests/test_control_audit.py new file mode 100644 index 0000000..c04ae7d --- /dev/null +++ b/tests/test_control_audit.py @@ -0,0 +1,154 @@ +import json +from pathlib import Path + +from fastapi.testclient import TestClient + +from geniehive_control.main import create_app +from geniehive_control.models import HostRegistration, RegisteredService +from geniehive_control.upstream import UpstreamClient + + +class _FakeResponse: + def __init__(self, payload: dict, status_code: int = 200) -> None: + self._payload = payload + self.status_code = status_code + self.text = str(payload) + + def json(self) -> dict: + return self._payload + + +class _UsagePoster: + async def post(self, url: str, *, json: dict, headers: dict[str, str] | None = None) -> _FakeResponse: + return _FakeResponse( + { + "object": "chat.completion", + "model": json["model"], + "choices": [{"index": 0, "message": {"role": "assistant", "content": "done"}}], + "usage": { + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10, + }, + } + ) + + +def _write_audit_config(tmp_path: Path) -> Path: + config_path = tmp_path / "control.yaml" + config_path.write_text( + f""" +auth: + client_api_keys: + - audit-key +audit: + enabled: true +admin_api: + enabled: true +storage: + sqlite_path: "{tmp_path / 'geniehive.sqlite3'}" +""" + ) + return config_path + + +def _register_chat_service(app) -> None: + app.state.registry.register_host( + HostRegistration( + host_id="atlas-01", + address="127.0.0.1", + services=[ + RegisteredService( + service_id="atlas-01/chat/qwen", + host_id="atlas-01", + kind="chat", + protocol="openai", + endpoint="http://127.0.0.1:18091", + assets=[{"asset_id": "qwen-test", "loaded": True}], + state={"health": "healthy", "accept_requests": True}, + observed={"p50_latency_ms": 100}, + ) + ], + ) + ) + + +def test_successful_chat_request_is_audited_without_prompt_content(tmp_path: Path) -> None: + app = create_app(_write_audit_config(tmp_path), upstream_client=UpstreamClient(client=_UsagePoster())) + _register_chat_service(app) + client = TestClient(app) + + response = client.post( + "/v1/chat/completions", + headers={"X-Api-Key": "audit-key", "X-Request-Id": "req-test-success"}, + json={ + "model": "qwen-test", + "messages": [{"role": "user", "content": "private prompt text"}], + }, + ) + + assert response.status_code == 200 + assert response.headers["x-request-id"] == "req-test-success" + + row = app.state.registry.get_request_audit("req-test-success") + assert row is not None + assert row["operation"] == "chat" + assert row["requested_model"] == "qwen-test" + assert row["resolved_service_id"] == "atlas-01/chat/qwen" + assert row["upstream_model"] == "qwen-test" + assert row["provider_kind"] == "openai" + assert row["success"] is True + assert row["status_code"] == 200 + assert row["prompt_tokens"] == 7 + assert row["completion_tokens"] == 3 + assert row["total_tokens"] == 10 + assert "private prompt text" not in json.dumps(row) + + +def test_failed_chat_route_is_audited(tmp_path: Path) -> None: + app = create_app(_write_audit_config(tmp_path), upstream_client=UpstreamClient(client=_UsagePoster())) + client = TestClient(app) + + response = client.post( + "/v1/chat/completions", + headers={"X-Api-Key": "audit-key", "X-Request-Id": "req-test-failure"}, + json={ + "model": "missing-model", + "messages": [{"role": "user", "content": "private failure prompt"}], + }, + ) + + assert response.status_code == 404 + assert response.headers["x-request-id"] == "req-test-failure" + + row = app.state.registry.get_request_audit("req-test-failure") + assert row is not None + assert row["operation"] == "chat" + assert row["requested_model"] == "missing-model" + assert row["success"] is False + assert row["status_code"] == 404 + assert row["error_type"] == "proxy_error" + assert "private failure prompt" not in json.dumps(row) + + +def test_admin_audit_endpoints_list_and_summarize_requests(tmp_path: Path) -> None: + app = create_app(_write_audit_config(tmp_path), upstream_client=UpstreamClient(client=_UsagePoster())) + _register_chat_service(app) + client = TestClient(app) + client.post( + "/v1/chat/completions", + headers={"X-Api-Key": "audit-key"}, + json={"model": "qwen-test", "messages": [{"role": "user", "content": "hello"}]}, + ) + + listed = client.get("/v1/admin/audit/requests", headers={"X-Api-Key": "audit-key"}) + assert listed.status_code == 200 + assert listed.json()["data"][0]["requested_model"] == "qwen-test" + + summary = client.get("/v1/admin/audit/summary", headers={"X-Api-Key": "audit-key"}) + assert summary.status_code == 200 + summary_row = summary.json()["data"][0] + assert summary_row["requested_model"] == "qwen-test" + assert summary_row["request_count"] == 1 + assert summary_row["success_count"] == 1 + assert summary_row["total_tokens"] == 10