Enable opt-in named client key auth
This commit is contained in:
parent
3c34b870ac
commit
04b9bec83a
|
|
@ -1,7 +1,24 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
from .keys import hash_api_key
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ClientContext:
|
||||
auth_kind: str
|
||||
key_id: str | None = None
|
||||
display_name: str | None = None
|
||||
principal_type: str | None = None
|
||||
principal_ref: str | None = None
|
||||
role: str | None = None
|
||||
allowed_models: tuple[str, ...] = ()
|
||||
allowed_operations: tuple[str, ...] = ()
|
||||
|
||||
|
||||
def _check_key(request: Request, allowed_keys: list[str], header_name: str) -> None:
|
||||
if not allowed_keys:
|
||||
|
|
@ -15,9 +32,61 @@ def _check_key(request: Request, allowed_keys: list[str], header_name: str) -> N
|
|||
)
|
||||
|
||||
|
||||
def require_client_auth(request: Request) -> None:
|
||||
def _set_client_context(request: Request, context: ClientContext) -> None:
|
||||
request.state.client_context = context
|
||||
|
||||
|
||||
def require_client_auth(request: Request) -> ClientContext:
|
||||
cfg = request.app.state.cfg
|
||||
_check_key(request, cfg.auth.client_api_keys, "X-Api-Key")
|
||||
provided = request.headers.get("X-Api-Key")
|
||||
|
||||
if cfg.auth.client_api_keys and provided in cfg.auth.client_api_keys:
|
||||
context = ClientContext(auth_kind="static")
|
||||
_set_client_context(request, context)
|
||||
return context
|
||||
|
||||
if cfg.auth.enable_named_client_keys:
|
||||
if not provided:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="unauthorized",
|
||||
)
|
||||
secret = os.environ.get(cfg.auth.key_hash_secret_env)
|
||||
if not secret:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"{cfg.auth.key_hash_secret_env} is required for named client keys",
|
||||
)
|
||||
key_hash = hash_api_key(provided, secret=secret)
|
||||
key_row = request.app.state.registry.get_client_key_by_hash(key_hash)
|
||||
if key_row is None or not key_row["enabled"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="unauthorized",
|
||||
)
|
||||
request.app.state.registry.touch_client_key(key_row["key_id"])
|
||||
context = ClientContext(
|
||||
auth_kind="named",
|
||||
key_id=key_row["key_id"],
|
||||
display_name=key_row["display_name"],
|
||||
principal_type=key_row["principal_type"],
|
||||
principal_ref=key_row["principal_ref"],
|
||||
role=key_row["role"],
|
||||
allowed_models=tuple(key_row["allowed_models"]),
|
||||
allowed_operations=tuple(key_row["allowed_operations"]),
|
||||
)
|
||||
_set_client_context(request, context)
|
||||
return context
|
||||
|
||||
if cfg.auth.client_api_keys:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="unauthorized",
|
||||
)
|
||||
|
||||
context = ClientContext(auth_kind="development")
|
||||
_set_client_context(request, context)
|
||||
return context
|
||||
|
||||
|
||||
def require_node_auth(request: Request) -> None:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,135 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from geniehive_control.auth import require_client_auth
|
||||
from geniehive_control.keys import hash_api_key
|
||||
from geniehive_control.main import create_app
|
||||
|
||||
|
||||
def _write_config(tmp_path: Path, body: str) -> Path:
|
||||
config_path = tmp_path / "control.yaml"
|
||||
config_path.write_text(body)
|
||||
return config_path
|
||||
|
||||
|
||||
def test_static_client_key_auth_still_works(tmp_path: Path) -> None:
|
||||
config_path = _write_config(
|
||||
tmp_path,
|
||||
f"""
|
||||
auth:
|
||||
client_api_keys:
|
||||
- static-key
|
||||
storage:
|
||||
sqlite_path: "{tmp_path / 'geniehive.sqlite3'}"
|
||||
""",
|
||||
)
|
||||
app = create_app(config_path)
|
||||
client = TestClient(app)
|
||||
|
||||
assert client.get("/v1/models").status_code == 401
|
||||
ok = client.get("/v1/models", headers={"X-Api-Key": "static-key"})
|
||||
assert ok.status_code == 200
|
||||
|
||||
|
||||
def test_empty_static_keys_still_allow_development_access(tmp_path: Path) -> None:
|
||||
config_path = _write_config(
|
||||
tmp_path,
|
||||
f"""
|
||||
storage:
|
||||
sqlite_path: "{tmp_path / 'geniehive.sqlite3'}"
|
||||
""",
|
||||
)
|
||||
app = create_app(config_path)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/v1/models")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_named_client_key_auth_when_enabled(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("GENIEHIVE_KEY_HASH_SECRET", "test-secret")
|
||||
db_path = tmp_path / "geniehive.sqlite3"
|
||||
config_path = _write_config(
|
||||
tmp_path,
|
||||
f"""
|
||||
auth:
|
||||
enable_named_client_keys: true
|
||||
storage:
|
||||
sqlite_path: "{db_path}"
|
||||
""",
|
||||
)
|
||||
app = create_app(config_path)
|
||||
raw_key = "gh_test_named"
|
||||
app.state.registry.create_client_key(
|
||||
key_id="ck_named",
|
||||
key_hash=hash_api_key(raw_key, secret="test-secret"),
|
||||
display_name="Named User",
|
||||
principal_type="person",
|
||||
principal_ref="named-user",
|
||||
role="developer",
|
||||
allowed_models=["archive_migrator"],
|
||||
allowed_operations=["chat"],
|
||||
)
|
||||
|
||||
@app.get("/_test/client-context")
|
||||
async def client_context(request: Request, _=Depends(require_client_auth)) -> dict:
|
||||
context = request.state.client_context
|
||||
return {
|
||||
"auth_kind": context.auth_kind,
|
||||
"key_id": context.key_id,
|
||||
"principal_ref": context.principal_ref,
|
||||
"allowed_models": list(context.allowed_models),
|
||||
"allowed_operations": list(context.allowed_operations),
|
||||
}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
missing = client.get("/_test/client-context")
|
||||
assert missing.status_code == 401
|
||||
|
||||
bad = client.get("/_test/client-context", headers={"X-Api-Key": "wrong"})
|
||||
assert bad.status_code == 401
|
||||
|
||||
ok = client.get("/_test/client-context", headers={"X-Api-Key": raw_key})
|
||||
assert ok.status_code == 200
|
||||
assert ok.json() == {
|
||||
"auth_kind": "named",
|
||||
"key_id": "ck_named",
|
||||
"principal_ref": "named-user",
|
||||
"allowed_models": ["archive_migrator"],
|
||||
"allowed_operations": ["chat"],
|
||||
}
|
||||
touched = app.state.registry.get_client_key("ck_named")
|
||||
assert touched is not None
|
||||
assert touched["last_used_at"] is not None
|
||||
|
||||
|
||||
def test_disabled_named_client_key_fails(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("GENIEHIVE_KEY_HASH_SECRET", "test-secret")
|
||||
db_path = tmp_path / "geniehive.sqlite3"
|
||||
config_path = _write_config(
|
||||
tmp_path,
|
||||
f"""
|
||||
auth:
|
||||
enable_named_client_keys: true
|
||||
storage:
|
||||
sqlite_path: "{db_path}"
|
||||
""",
|
||||
)
|
||||
app = create_app(config_path)
|
||||
raw_key = "gh_test_disabled"
|
||||
app.state.registry.create_client_key(
|
||||
key_id="ck_disabled",
|
||||
key_hash=hash_api_key(raw_key, secret="test-secret"),
|
||||
display_name="Disabled User",
|
||||
principal_type="person",
|
||||
principal_ref="disabled-user",
|
||||
enabled=False,
|
||||
)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/v1/models", headers={"X-Api-Key": raw_key})
|
||||
assert response.status_code == 401
|
||||
Loading…
Reference in New Issue