Add named client key storage primitives

This commit is contained in:
welsberr 2026-04-29 14:18:25 -04:00
parent 73629bf4f4
commit 3c34b870ac
3 changed files with 219 additions and 0 deletions

View File

@ -0,0 +1,39 @@
from __future__ import annotations
import hashlib
import hmac
import secrets
DEFAULT_KEY_PREFIX = "gh"
def generate_api_key(*, prefix: str = DEFAULT_KEY_PREFIX, token_bytes: int = 32) -> str:
"""Generate a URL-safe API key. The raw value is only shown once."""
token = secrets.token_urlsafe(token_bytes)
return f"{prefix}_{token}"
def hash_api_key(api_key: str, *, secret: str) -> str:
if not secret:
raise ValueError("key hash secret must not be empty")
digest = hmac.new(
secret.encode("utf-8"),
api_key.encode("utf-8"),
hashlib.sha256,
).hexdigest()
return f"hmac-sha256:{digest}"
def verify_api_key(api_key: str, key_hash: str, *, secret: str) -> bool:
try:
expected = hash_api_key(api_key, secret=secret)
except ValueError:
return False
return hmac.compare_digest(expected, key_hash)
def redact_api_key(api_key: str) -> str:
if len(api_key) <= 12:
return "***"
return f"{api_key[:6]}...{api_key[-4:]}"

View File

@ -77,6 +77,24 @@ class Registry:
observed_at REAL NOT NULL, observed_at REAL NOT NULL,
results_json TEXT NOT NULL results_json TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS client_keys (
key_id TEXT PRIMARY KEY,
key_hash TEXT NOT NULL UNIQUE,
display_name TEXT NOT NULL,
principal_type TEXT NOT NULL,
principal_ref TEXT NOT NULL,
role TEXT,
allowed_models_json TEXT NOT NULL DEFAULT '[]',
allowed_operations_json TEXT NOT NULL DEFAULT '[]',
monthly_budget_cents INTEGER,
monthly_token_limit INTEGER,
enabled INTEGER NOT NULL DEFAULT 1,
created_at REAL NOT NULL,
updated_at REAL NOT NULL,
last_used_at REAL,
notes TEXT
);
""" """
) )
@ -290,6 +308,88 @@ class Registry:
rows = conn.execute(query, params).fetchall() rows = conn.execute(query, params).fetchall()
return [self._benchmark_row_to_dict(row) for row in rows] return [self._benchmark_row_to_dict(row) for row in rows]
def create_client_key(
self,
*,
key_id: str,
key_hash: str,
display_name: str,
principal_type: str,
principal_ref: str,
role: str | None = None,
allowed_models: list[str] | None = None,
allowed_operations: list[str] | None = None,
monthly_budget_cents: int | None = None,
monthly_token_limit: int | None = None,
enabled: bool = True,
notes: str | None = None,
) -> dict:
now = time.time()
with self._connect() as conn:
conn.execute(
"""
INSERT INTO client_keys (
key_id, key_hash, display_name, principal_type, principal_ref,
role, allowed_models_json, allowed_operations_json,
monthly_budget_cents, monthly_token_limit, enabled,
created_at, updated_at, last_used_at, notes
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, ?)
""",
(
key_id,
key_hash,
display_name,
principal_type,
principal_ref,
role,
_json_dumps(allowed_models or []),
_json_dumps(allowed_operations or []),
monthly_budget_cents,
monthly_token_limit,
1 if enabled else 0,
now,
now,
notes,
),
)
created = self.get_client_key(key_id)
if created is None:
raise RuntimeError(f"created client key {key_id!r} could not be loaded")
return created
def get_client_key(self, key_id: str) -> dict | None:
with self._connect() as conn:
row = conn.execute("SELECT * FROM client_keys WHERE key_id = ?", (key_id,)).fetchone()
return self._client_key_row_to_dict(row) if row is not None else None
def get_client_key_by_hash(self, key_hash: str) -> dict | None:
with self._connect() as conn:
row = conn.execute("SELECT * FROM client_keys WHERE key_hash = ?", (key_hash,)).fetchone()
return self._client_key_row_to_dict(row) if row is not None else None
def list_client_keys(self) -> list[dict]:
with self._connect() as conn:
rows = conn.execute("SELECT * FROM client_keys ORDER BY created_at, key_id").fetchall()
return [self._client_key_row_to_dict(row) for row in rows]
def set_client_key_enabled(self, key_id: str, enabled: bool) -> dict | None:
now = time.time()
with self._connect() as conn:
conn.execute(
"UPDATE client_keys SET enabled = ?, updated_at = ? WHERE key_id = ?",
(1 if enabled else 0, now, key_id),
)
return self.get_client_key(key_id)
def touch_client_key(self, key_id: str) -> None:
now = time.time()
with self._connect() as conn:
conn.execute(
"UPDATE client_keys SET last_used_at = ?, updated_at = ? WHERE key_id = ?",
(now, now, key_id),
)
def list_client_models(self) -> list[dict]: def list_client_models(self) -> list[dict]:
services = self.list_services() services = self.list_services()
roles = self.list_roles() roles = self.list_roles()
@ -807,6 +907,26 @@ class Registry:
"results": json.loads(row["results_json"]), "results": json.loads(row["results_json"]),
} }
@staticmethod
def _client_key_row_to_dict(row: sqlite3.Row) -> dict:
return {
"key_id": row["key_id"],
"key_hash": row["key_hash"],
"display_name": row["display_name"],
"principal_type": row["principal_type"],
"principal_ref": row["principal_ref"],
"role": row["role"],
"allowed_models": json.loads(row["allowed_models_json"]),
"allowed_operations": json.loads(row["allowed_operations_json"]),
"monthly_budget_cents": row["monthly_budget_cents"],
"monthly_token_limit": row["monthly_token_limit"],
"enabled": bool(row["enabled"]),
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"last_used_at": row["last_used_at"],
"notes": row["notes"],
}
def _tokenize_text(value: str) -> set[str]: def _tokenize_text(value: str) -> set[str]:
return {token for token in re.split(r"[^a-z0-9]+", value.lower()) if token} return {token for token in re.split(r"[^a-z0-9]+", value.lower()) if token}

View File

@ -0,0 +1,60 @@
from pathlib import Path
from geniehive_control.keys import generate_api_key, hash_api_key, redact_api_key, verify_api_key
from geniehive_control.registry import Registry
def test_api_key_hash_verify_and_redact() -> None:
raw_key = generate_api_key(prefix="gh_test")
key_hash = hash_api_key(raw_key, secret="test-secret")
assert raw_key.startswith("gh_test_")
assert key_hash.startswith("hmac-sha256:")
assert verify_api_key(raw_key, key_hash, secret="test-secret") is True
assert verify_api_key(raw_key + "-wrong", key_hash, secret="test-secret") is False
assert verify_api_key(raw_key, key_hash, secret="other-secret") is False
assert raw_key not in redact_api_key(raw_key)
def test_registry_client_key_lifecycle(tmp_path: Path) -> None:
registry = Registry(tmp_path / "geniehive.sqlite3")
raw_key = "gh_test_secret"
key_hash = hash_api_key(raw_key, secret="test-secret")
created = registry.create_client_key(
key_id="ck_test",
key_hash=key_hash,
display_name="Test User",
principal_type="person",
principal_ref="test-user",
role="developer",
allowed_models=["archive_migrator"],
allowed_operations=["chat"],
monthly_budget_cents=1000,
monthly_token_limit=20000,
notes="created by test",
)
assert created["key_id"] == "ck_test"
assert created["key_hash"] == key_hash
assert created["display_name"] == "Test User"
assert created["allowed_models"] == ["archive_migrator"]
assert created["allowed_operations"] == ["chat"]
assert created["enabled"] is True
assert created["last_used_at"] is None
listed = registry.list_client_keys()
assert [item["key_id"] for item in listed] == ["ck_test"]
by_hash = registry.get_client_key_by_hash(key_hash)
assert by_hash is not None
assert by_hash["principal_ref"] == "test-user"
disabled = registry.set_client_key_enabled("ck_test", False)
assert disabled is not None
assert disabled["enabled"] is False
registry.touch_client_key("ck_test")
touched = registry.get_client_key("ck_test")
assert touched is not None
assert touched["last_used_at"] is not None