diff --git a/src/geniehive_control/keys.py b/src/geniehive_control/keys.py new file mode 100644 index 0000000..750383c --- /dev/null +++ b/src/geniehive_control/keys.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import hashlib +import hmac +import secrets + + +DEFAULT_KEY_PREFIX = "gh" + + +def generate_api_key(*, prefix: str = DEFAULT_KEY_PREFIX, token_bytes: int = 32) -> str: + """Generate a URL-safe API key. The raw value is only shown once.""" + token = secrets.token_urlsafe(token_bytes) + return f"{prefix}_{token}" + + +def hash_api_key(api_key: str, *, secret: str) -> str: + if not secret: + raise ValueError("key hash secret must not be empty") + digest = hmac.new( + secret.encode("utf-8"), + api_key.encode("utf-8"), + hashlib.sha256, + ).hexdigest() + return f"hmac-sha256:{digest}" + + +def verify_api_key(api_key: str, key_hash: str, *, secret: str) -> bool: + try: + expected = hash_api_key(api_key, secret=secret) + except ValueError: + return False + return hmac.compare_digest(expected, key_hash) + + +def redact_api_key(api_key: str) -> str: + if len(api_key) <= 12: + return "***" + return f"{api_key[:6]}...{api_key[-4:]}" diff --git a/src/geniehive_control/registry.py b/src/geniehive_control/registry.py index dc79cf3..7051551 100644 --- a/src/geniehive_control/registry.py +++ b/src/geniehive_control/registry.py @@ -77,6 +77,24 @@ class Registry: observed_at REAL NOT NULL, results_json TEXT NOT NULL ); + + CREATE TABLE IF NOT EXISTS client_keys ( + key_id TEXT PRIMARY KEY, + key_hash TEXT NOT NULL UNIQUE, + display_name TEXT NOT NULL, + principal_type TEXT NOT NULL, + principal_ref TEXT NOT NULL, + role TEXT, + allowed_models_json TEXT NOT NULL DEFAULT '[]', + allowed_operations_json TEXT NOT NULL DEFAULT '[]', + monthly_budget_cents INTEGER, + monthly_token_limit INTEGER, + enabled INTEGER NOT NULL DEFAULT 1, + created_at REAL NOT NULL, + updated_at REAL NOT NULL, + last_used_at REAL, + notes TEXT + ); """ ) @@ -290,6 +308,88 @@ class Registry: rows = conn.execute(query, params).fetchall() return [self._benchmark_row_to_dict(row) for row in rows] + def create_client_key( + self, + *, + key_id: str, + key_hash: str, + display_name: str, + principal_type: str, + principal_ref: str, + role: str | None = None, + allowed_models: list[str] | None = None, + allowed_operations: list[str] | None = None, + monthly_budget_cents: int | None = None, + monthly_token_limit: int | None = None, + enabled: bool = True, + notes: str | None = None, + ) -> dict: + now = time.time() + with self._connect() as conn: + conn.execute( + """ + INSERT INTO client_keys ( + key_id, key_hash, display_name, principal_type, principal_ref, + role, allowed_models_json, allowed_operations_json, + monthly_budget_cents, monthly_token_limit, enabled, + created_at, updated_at, last_used_at, notes + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, NULL, ?) + """, + ( + key_id, + key_hash, + display_name, + principal_type, + principal_ref, + role, + _json_dumps(allowed_models or []), + _json_dumps(allowed_operations or []), + monthly_budget_cents, + monthly_token_limit, + 1 if enabled else 0, + now, + now, + notes, + ), + ) + created = self.get_client_key(key_id) + if created is None: + raise RuntimeError(f"created client key {key_id!r} could not be loaded") + return created + + def get_client_key(self, key_id: str) -> dict | None: + with self._connect() as conn: + row = conn.execute("SELECT * FROM client_keys WHERE key_id = ?", (key_id,)).fetchone() + return self._client_key_row_to_dict(row) if row is not None else None + + def get_client_key_by_hash(self, key_hash: str) -> dict | None: + with self._connect() as conn: + row = conn.execute("SELECT * FROM client_keys WHERE key_hash = ?", (key_hash,)).fetchone() + return self._client_key_row_to_dict(row) if row is not None else None + + def list_client_keys(self) -> list[dict]: + with self._connect() as conn: + rows = conn.execute("SELECT * FROM client_keys ORDER BY created_at, key_id").fetchall() + return [self._client_key_row_to_dict(row) for row in rows] + + def set_client_key_enabled(self, key_id: str, enabled: bool) -> dict | None: + now = time.time() + with self._connect() as conn: + conn.execute( + "UPDATE client_keys SET enabled = ?, updated_at = ? WHERE key_id = ?", + (1 if enabled else 0, now, key_id), + ) + return self.get_client_key(key_id) + + def touch_client_key(self, key_id: str) -> None: + now = time.time() + with self._connect() as conn: + conn.execute( + "UPDATE client_keys SET last_used_at = ?, updated_at = ? WHERE key_id = ?", + (now, now, key_id), + ) + def list_client_models(self) -> list[dict]: services = self.list_services() roles = self.list_roles() @@ -807,6 +907,26 @@ class Registry: "results": json.loads(row["results_json"]), } + @staticmethod + def _client_key_row_to_dict(row: sqlite3.Row) -> dict: + return { + "key_id": row["key_id"], + "key_hash": row["key_hash"], + "display_name": row["display_name"], + "principal_type": row["principal_type"], + "principal_ref": row["principal_ref"], + "role": row["role"], + "allowed_models": json.loads(row["allowed_models_json"]), + "allowed_operations": json.loads(row["allowed_operations_json"]), + "monthly_budget_cents": row["monthly_budget_cents"], + "monthly_token_limit": row["monthly_token_limit"], + "enabled": bool(row["enabled"]), + "created_at": row["created_at"], + "updated_at": row["updated_at"], + "last_used_at": row["last_used_at"], + "notes": row["notes"], + } + def _tokenize_text(value: str) -> set[str]: return {token for token in re.split(r"[^a-z0-9]+", value.lower()) if token} diff --git a/tests/test_control_keys.py b/tests/test_control_keys.py new file mode 100644 index 0000000..f8fb230 --- /dev/null +++ b/tests/test_control_keys.py @@ -0,0 +1,60 @@ +from pathlib import Path + +from geniehive_control.keys import generate_api_key, hash_api_key, redact_api_key, verify_api_key +from geniehive_control.registry import Registry + + +def test_api_key_hash_verify_and_redact() -> None: + raw_key = generate_api_key(prefix="gh_test") + key_hash = hash_api_key(raw_key, secret="test-secret") + + assert raw_key.startswith("gh_test_") + assert key_hash.startswith("hmac-sha256:") + assert verify_api_key(raw_key, key_hash, secret="test-secret") is True + assert verify_api_key(raw_key + "-wrong", key_hash, secret="test-secret") is False + assert verify_api_key(raw_key, key_hash, secret="other-secret") is False + assert raw_key not in redact_api_key(raw_key) + + +def test_registry_client_key_lifecycle(tmp_path: Path) -> None: + registry = Registry(tmp_path / "geniehive.sqlite3") + raw_key = "gh_test_secret" + key_hash = hash_api_key(raw_key, secret="test-secret") + + created = registry.create_client_key( + key_id="ck_test", + key_hash=key_hash, + display_name="Test User", + principal_type="person", + principal_ref="test-user", + role="developer", + allowed_models=["archive_migrator"], + allowed_operations=["chat"], + monthly_budget_cents=1000, + monthly_token_limit=20000, + notes="created by test", + ) + + assert created["key_id"] == "ck_test" + assert created["key_hash"] == key_hash + assert created["display_name"] == "Test User" + assert created["allowed_models"] == ["archive_migrator"] + assert created["allowed_operations"] == ["chat"] + assert created["enabled"] is True + assert created["last_used_at"] is None + + listed = registry.list_client_keys() + assert [item["key_id"] for item in listed] == ["ck_test"] + + by_hash = registry.get_client_key_by_hash(key_hash) + assert by_hash is not None + assert by_hash["principal_ref"] == "test-user" + + disabled = registry.set_client_key_enabled("ck_test", False) + assert disabled is not None + assert disabled["enabled"] is False + + registry.touch_client_key("ck_test") + touched = registry.get_client_key("ck_test") + assert touched is not None + assert touched["last_used_at"] is not None