Update for auth step 1

This commit is contained in:
welsberr 2026-02-07 15:04:27 -05:00
parent b2e4518b3a
commit d7042b4a2b
9 changed files with 134 additions and 7 deletions

View File

@ -5,6 +5,15 @@ gateway:
host: 0.0.0.0
port: 8000
auth:
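# Set these to enable auth. Each list is checked independently: if a list is empty, auth for that
# class of endpoint is disabled (scaffold/back-compat). Replace the placeholder keys before deploying.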
client_api_keys:
- "change-me-client-key-1"
node_api_keys:
- "change-me-node-key-1"
# Models may be:
# - type: proxy (static URL to an OpenAI-compatible upstream)
# - type: discovered (resolved from registered nodes by role)

View File

@ -4,6 +4,8 @@ listen_port: 8091
# Set to the dispatcher gateway URL if you want auto-registration/heartbeat.
dispatcher_base_url: "http://127.0.0.1:8080"
# Optional auth key presented to dispatcher for /v1/nodes/* endpoints
dispatcher_node_key: "change-me-node-key-1"
dispatcher_roles: ["planner", "coder"]
heartbeat_interval_sec: 5

View File

@ -11,6 +11,9 @@ default_model: writer
gateway:
host: 0.0.0.0
port: 8000
auth:
client_api_keys: ["..."]
node_api_keys: ["..."]
models:
<alias>:
type: proxy | discovered
@ -61,4 +64,10 @@ Nodes register to `POST /v1/nodes/register`:
}
```
Authentication is disabled by default in this scaffold — configure API keys (described below) or add mTLS if the gateway is exposed beyond localhost.
If `auth.client_api_keys` is set (non-empty), callers of `/v1/models` and `/v1/chat/completions` must provide an API key.
If `auth.node_api_keys` is set (non-empty), node agents calling `/v1/nodes/register` and `/v1/nodes/heartbeat` must provide a node key.
Supported headers:
- Clients: `Authorization: Bearer <key>` or `X-Api-Key: <key>`
- Nodes: `Authorization: Bearer <node_key>` or `X-RoleMesh-Node-Key: <node_key>`

View File

@ -26,6 +26,7 @@ This scaffold supports two patterns.
```bash
curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \
-H 'Content-Type: application/json' \
-H 'X-RoleMesh-Node-Key: <node-key>' \
-d '{"node_id":"gpu-box-1","base_url":"http://10.0.0.12:8012","roles":["writer"]}'
```

View File

@ -5,12 +5,13 @@ import time
import uuid
from typing import Any, Dict
from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError
from rolemesh_gateway.auth import require_client_auth, require_node_auth
router = APIRouter()
@ -30,7 +31,7 @@ def _openai_error(message: str, code: str = "upstream_error", status_code: int =
@router.get("/v1/models")
async def list_models(request: Request) -> Dict[str, Any]:
async def list_models(request: Request, _=Depends(require_client_auth)) -> Dict[str, Any]:
cfg: Config = request.app.state.cfg
registry: Registry = request.app.state.registry
@ -54,7 +55,7 @@ async def list_models(request: Request) -> Dict[str, Any]:
@router.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Any:
async def chat_completions(request: Request, _=Depends(require_client_auth)) -> Any:
cfg: Config = request.app.state.cfg
upstream: UpstreamClient = request.app.state.upstream
registry: Registry = request.app.state.registry
@ -133,7 +134,7 @@ async def ready(request: Request) -> JSONResponse:
@router.post("/v1/nodes/register")
async def register_node(request: Request) -> Dict[str, Any]:
async def register_node(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
"""
Allow a remote machine to register which roles it serves.
@ -148,7 +149,7 @@ async def register_node(request: Request) -> Dict[str, Any]:
@router.post("/v1/nodes/heartbeat")
async def nodes_heartbeat(request: Request) -> Dict[str, Any]:
async def nodes_heartbeat(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
"""
Allow a node agent to push status/metrics updates.

View File

@ -0,0 +1,87 @@
from __future__ import annotations

import hmac
from dataclasses import dataclass
from typing import Optional, Sequence

from fastapi import Header, HTTPException, Request, status
@dataclass(frozen=True)
class Principal:
    """Immutable record describing the caller that passed an auth check."""

    # Which credential class was accepted: "client" or "node".
    kind: str
    # Opaque identifier for the accepted key (e.g., its last 6 characters).
    key_id: str
def _extract_bearer(authorization: Optional[str]) -> Optional[str]:
if not authorization:
return None
parts = authorization.strip().split()
if len(parts) == 2 and parts[0].lower() == "bearer":
return parts[1]
return None
def _key_id(token: str) -> str:
return token[-6:] if len(token) >= 6 else token
def require_client_auth(
    request: Request,
    authorization: Optional[str] = Header(default=None),
    # An explicit alias is required: with ``convert_underscores=False`` and no
    # alias, FastAPI looks for a header literally named ``x_api_key`` and the
    # documented ``X-Api-Key`` header is never read.
    x_api_key: Optional[str] = Header(default=None, alias="X-Api-Key"),
) -> Optional[Principal]:
    """Gate OpenAI-style endpoints with an API key.

    Backward compatible: if no client keys are configured, this is a no-op.

    Supported headers:
      - Authorization: Bearer <key>
      - X-Api-Key: <key>

    Returns:
        ``None`` when client auth is disabled, otherwise a ``Principal``
        of kind ``"client"`` describing the accepted key.

    Raises:
        HTTPException: 401 when client keys are configured and the request
            carries no key or a key that is not in the configured list.
    """
    cfg = getattr(request.app.state, "cfg", None)
    auth_cfg = getattr(cfg, "auth", None)
    keys: Sequence[str] = list(getattr(auth_cfg, "client_api_keys", None) or [])
    if not keys:
        return None  # auth disabled

    # A bearer token, when present, takes precedence over X-Api-Key.
    token = _extract_bearer(authorization) or x_api_key

    authorized = False
    if token:
        # Compare as bytes: compare_digest rejects non-ASCII str, and header
        # values may contain such characters.
        presented = token.encode("utf-8")
        # Constant-time comparison against every key (no early exit) so the
        # response time does not reveal how much of a key matched.
        for key in keys:
            authorized |= hmac.compare_digest(presented, key.encode("utf-8"))

    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or invalid API key.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return Principal(kind="client", key_id=_key_id(token))
def require_node_auth(
    request: Request,
    authorization: Optional[str] = Header(default=None),
    # An explicit alias is required: with ``convert_underscores=False`` and no
    # alias, FastAPI looks for a header literally named
    # ``x_rolemesh_node_key`` and the documented ``X-RoleMesh-Node-Key``
    # header is never read.
    x_rolemesh_node_key: Optional[str] = Header(
        default=None, alias="X-RoleMesh-Node-Key"
    ),
) -> Optional[Principal]:
    """Gate node registration/heartbeat endpoints with a node key.

    Backward compatible: if no node keys are configured, this is a no-op.

    Supported headers:
      - Authorization: Bearer <node_key>
      - X-RoleMesh-Node-Key: <node_key>

    Returns:
        ``None`` when node auth is disabled, otherwise a ``Principal``
        of kind ``"node"`` describing the accepted key.

    Raises:
        HTTPException: 401 when node keys are configured and the request
            carries no key or a key that is not in the configured list.
    """
    cfg = getattr(request.app.state, "cfg", None)
    auth_cfg = getattr(cfg, "auth", None)
    keys: Sequence[str] = list(getattr(auth_cfg, "node_api_keys", None) or [])
    if not keys:
        return None  # auth disabled

    # A bearer token, when present, takes precedence over the node-key header.
    token = _extract_bearer(authorization) or x_rolemesh_node_key

    authorized = False
    if token:
        # Compare as bytes: compare_digest rejects non-ASCII str, and header
        # values may contain such characters.
        presented = token.encode("utf-8")
        # Constant-time comparison against every key (no early exit) so the
        # response time does not reveal how much of a key matched.
        for key in keys:
            authorized |= hmac.compare_digest(presented, key.encode("utf-8"))

    if not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or invalid node key.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return Principal(kind="node", key_id=_key_id(token))

View File

@ -12,6 +12,17 @@ class GatewayConfig(BaseModel):
port: int = 8000
class AuthConfig(BaseModel):
    """API-key settings for the gateway.

    The two key lists are independent of each other; both default to empty.
    """

    # Keys accepted from clients calling the OpenAI-compatible endpoints.
    # An empty list disables client auth (scaffold/back-compat mode).
    client_api_keys: List[str] = Field(default_factory=list)

    # Keys accepted from node agents calling the register/heartbeat endpoints.
    # An empty list disables node auth (scaffold/back-compat mode).
    node_api_keys: List[str] = Field(default_factory=list)
ModelType = Literal["proxy", "discovered"]
@ -39,6 +50,7 @@ class Config(BaseModel):
version: int = 1
default_model: str
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
auth: AuthConfig = Field(default_factory=AuthConfig)
models: Dict[str, ModelEntry] = Field(default_factory=dict)

View File

@ -33,5 +33,8 @@ class NodeAgentConfig(BaseModel):
dispatcher_roles: List[str] = Field(default_factory=list)
heartbeat_interval_sec: float = 5.0
# Optional auth key presented to the dispatcher for /v1/nodes/* endpoints
dispatcher_node_key: Optional[str] = None
# llama-server binary name/path
llama_server_bin: str = "llama-server"

View File

@ -127,7 +127,10 @@ async def _heartbeat_loop(app: FastAPI) -> None:
"metrics": [m.__dict__ | {"device": m.device.__dict__} for m in inv],
}
url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat"
await http.post(url, json=payload)
headers = {}
if cfg.dispatcher_node_key:
headers["X-RoleMesh-Node-Key"] = cfg.dispatcher_node_key
await http.post(url, json=payload, headers=headers)
except Exception:
pass
await asyncio.sleep(cfg.heartbeat_interval_sec)