Update for auth step 1

This commit is contained in:
welsberr 2026-02-07 15:04:27 -05:00
parent b2e4518b3a
commit d7042b4a2b
9 changed files with 134 additions and 7 deletions

View File

@ -5,6 +5,15 @@ gateway:
host: 0.0.0.0 host: 0.0.0.0
port: 8000 port: 8000
auth:
# Set these to enable auth. If empty, auth is disabled (scaffold/back-compat).
client_api_keys:
- "change-me-client-key-1"
node_api_keys:
- "change-me-node-key-1"
# Models may be: # Models may be:
# - type: proxy (static URL to an OpenAI-compatible upstream) # - type: proxy (static URL to an OpenAI-compatible upstream)
# - type: discovered (resolved from registered nodes by role) # - type: discovered (resolved from registered nodes by role)

View File

@ -4,6 +4,8 @@ listen_port: 8091
# Set to the dispatcher gateway URL if you want auto-registration/heartbeat. # Set to the dispatcher gateway URL if you want auto-registration/heartbeat.
dispatcher_base_url: "http://127.0.0.1:8080" dispatcher_base_url: "http://127.0.0.1:8080"
# Optional auth key presented to dispatcher for /v1/nodes/* endpoints
dispatcher_node_key: "change-me-node-key-1"
dispatcher_roles: ["planner", "coder"] dispatcher_roles: ["planner", "coder"]
heartbeat_interval_sec: 5 heartbeat_interval_sec: 5

View File

@ -11,6 +11,9 @@ default_model: writer
gateway: gateway:
host: 0.0.0.0 host: 0.0.0.0
port: 8000 port: 8000
auth:
client_api_keys: ["..."]
node_api_keys: ["..."]
models: models:
<alias>: <alias>:
type: proxy | discovered type: proxy | discovered
@ -61,4 +64,10 @@ Nodes register to `POST /v1/nodes/register`:
} }
``` ```
Security is intentionally omitted in this scaffold — add API keys or mTLS if the gateway is exposed beyond localhost. If `auth.client_api_keys` is set (non-empty), callers of `/v1/models` and `/v1/chat/completions` must provide an API key.
If `auth.node_api_keys` is set (non-empty), node agents calling `/v1/nodes/register` and `/v1/nodes/heartbeat` must provide a node key.
Supported headers:
- Clients: `Authorization: Bearer <key>` or `X-Api-Key: <key>`
- Nodes: `Authorization: Bearer <node_key>` or `X-RoleMesh-Node-Key: <node_key>`

View File

@ -26,6 +26,7 @@ This scaffold supports two patterns.
```bash ```bash
curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \ curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
-H 'X-RoleMesh-Node-Key: <node-key>' \
-d '{"node_id":"gpu-box-1","base_url":"http://10.0.0.12:8012","roles":["writer"]}' -d '{"node_id":"gpu-box-1","base_url":"http://10.0.0.12:8012","roles":["writer"]}'
``` ```

View File

@ -5,12 +5,13 @@ import time
import uuid import uuid
from typing import Any, Dict from typing import Any, Dict
from fastapi import APIRouter, Request from fastapi import APIRouter, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError from rolemesh_gateway.upstream import UpstreamClient, UpstreamError
from rolemesh_gateway.auth import require_client_auth, require_node_auth
router = APIRouter() router = APIRouter()
@ -30,7 +31,7 @@ def _openai_error(message: str, code: str = "upstream_error", status_code: int =
@router.get("/v1/models") @router.get("/v1/models")
async def list_models(request: Request) -> Dict[str, Any]: async def list_models(request: Request, _=Depends(require_client_auth)) -> Dict[str, Any]:
cfg: Config = request.app.state.cfg cfg: Config = request.app.state.cfg
registry: Registry = request.app.state.registry registry: Registry = request.app.state.registry
@ -54,7 +55,7 @@ async def list_models(request: Request) -> Dict[str, Any]:
@router.post("/v1/chat/completions") @router.post("/v1/chat/completions")
async def chat_completions(request: Request) -> Any: async def chat_completions(request: Request, _=Depends(require_client_auth)) -> Any:
cfg: Config = request.app.state.cfg cfg: Config = request.app.state.cfg
upstream: UpstreamClient = request.app.state.upstream upstream: UpstreamClient = request.app.state.upstream
registry: Registry = request.app.state.registry registry: Registry = request.app.state.registry
@ -133,7 +134,7 @@ async def ready(request: Request) -> JSONResponse:
@router.post("/v1/nodes/register") @router.post("/v1/nodes/register")
async def register_node(request: Request) -> Dict[str, Any]: async def register_node(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
""" """
Allow a remote machine to register which roles it serves. Allow a remote machine to register which roles it serves.
@ -148,7 +149,7 @@ async def register_node(request: Request) -> Dict[str, Any]:
@router.post("/v1/nodes/heartbeat") @router.post("/v1/nodes/heartbeat")
async def nodes_heartbeat(request: Request) -> Dict[str, Any]: async def nodes_heartbeat(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
""" """
Allow a node agent to push status/metrics updates. Allow a node agent to push status/metrics updates.

View File

@ -0,0 +1,87 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Sequence
from fastapi import Header, HTTPException, Request, status
@dataclass(frozen=True)
class Principal:
kind: str # "client" | "node"
key_id: str # opaque identifier (e.g., last 6 chars)
def _extract_bearer(authorization: Optional[str]) -> Optional[str]:
if not authorization:
return None
parts = authorization.strip().split()
if len(parts) == 2 and parts[0].lower() == "bearer":
return parts[1]
return None
def _key_id(token: str) -> str:
return token[-6:] if len(token) >= 6 else token
def require_client_auth(
request: Request,
authorization: Optional[str] = Header(default=None),
x_api_key: Optional[str] = Header(default=None, convert_underscores=False),
) -> Optional[Principal]:
"""Gate OpenAI-style endpoints with an API key.
Backward compatible: if no client keys are configured, this is a no-op.
Supported headers:
- Authorization: Bearer <key>
- X-Api-Key: <key>
"""
cfg = getattr(request.app.state, "cfg", None)
keys: Sequence[str] = []
if cfg is not None and hasattr(cfg, "auth") and cfg.auth is not None:
keys = list(cfg.auth.client_api_keys or [])
if not keys:
return None # auth disabled
token = _extract_bearer(authorization) or x_api_key
if not token or token not in keys:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing or invalid API key.",
headers={"WWW-Authenticate": "Bearer"},
)
return Principal(kind="client", key_id=_key_id(token))
def require_node_auth(
request: Request,
authorization: Optional[str] = Header(default=None),
x_rolemesh_node_key: Optional[str] = Header(default=None, convert_underscores=False),
) -> Optional[Principal]:
"""Gate node registration/heartbeat endpoints with a node key.
Backward compatible: if no node keys are configured, this is a no-op.
Supported headers:
- Authorization: Bearer <node_key>
- X-RoleMesh-Node-Key: <node_key>
"""
cfg = getattr(request.app.state, "cfg", None)
keys: Sequence[str] = []
if cfg is not None and hasattr(cfg, "auth") and cfg.auth is not None:
keys = list(cfg.auth.node_api_keys or [])
if not keys:
return None # auth disabled
token = _extract_bearer(authorization) or x_rolemesh_node_key
if not token or token not in keys:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing or invalid node key.",
headers={"WWW-Authenticate": "Bearer"},
)
return Principal(kind="node", key_id=_key_id(token))

View File

@ -12,6 +12,17 @@ class GatewayConfig(BaseModel):
port: int = 8000 port: int = 8000
class AuthConfig(BaseModel):
# API keys for clients calling OpenAI-compatible endpoints.
# If empty, client auth is disabled (scaffold/back-compat mode).
client_api_keys: List[str] = Field(default_factory=list)
# API keys for node agents calling register/heartbeat endpoints.
# If empty, node auth is disabled (scaffold/back-compat mode).
node_api_keys: List[str] = Field(default_factory=list)
ModelType = Literal["proxy", "discovered"] ModelType = Literal["proxy", "discovered"]
@ -39,6 +50,7 @@ class Config(BaseModel):
version: int = 1 version: int = 1
default_model: str default_model: str
gateway: GatewayConfig = Field(default_factory=GatewayConfig) gateway: GatewayConfig = Field(default_factory=GatewayConfig)
auth: AuthConfig = Field(default_factory=AuthConfig)
models: Dict[str, ModelEntry] = Field(default_factory=dict) models: Dict[str, ModelEntry] = Field(default_factory=dict)

View File

@ -33,5 +33,8 @@ class NodeAgentConfig(BaseModel):
dispatcher_roles: List[str] = Field(default_factory=list) dispatcher_roles: List[str] = Field(default_factory=list)
heartbeat_interval_sec: float = 5.0 heartbeat_interval_sec: float = 5.0
# Optional auth key presented to the dispatcher for /v1/nodes/* endpoints
dispatcher_node_key: Optional[str] = None
# llama-server binary name/path # llama-server binary name/path
llama_server_bin: str = "llama-server" llama_server_bin: str = "llama-server"

View File

@ -127,7 +127,10 @@ async def _heartbeat_loop(app: FastAPI) -> None:
"metrics": [m.__dict__ | {"device": m.device.__dict__} for m in inv], "metrics": [m.__dict__ | {"device": m.device.__dict__} for m in inv],
} }
url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat" url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat"
await http.post(url, json=payload) headers = {}
if cfg.dispatcher_node_key:
headers["X-RoleMesh-Node-Key"] = cfg.dispatcher_node_key
await http.post(url, json=payload, headers=headers)
except Exception: except Exception:
pass pass
await asyncio.sleep(cfg.heartbeat_interval_sec) await asyncio.sleep(cfg.heartbeat_interval_sec)