Update for auth step 1
This commit is contained in:
parent
b2e4518b3a
commit
d7042b4a2b
|
|
@ -5,6 +5,15 @@ gateway:
|
|||
host: 0.0.0.0
|
||||
port: 8000
|
||||
|
||||
|
||||
auth:
|
||||
# Set these to enable auth. If empty, auth is disabled (scaffold/back-compat).
|
||||
client_api_keys:
|
||||
- "change-me-client-key-1"
|
||||
node_api_keys:
|
||||
- "change-me-node-key-1"
|
||||
|
||||
|
||||
# Models may be:
|
||||
# - type: proxy (static URL to an OpenAI-compatible upstream)
|
||||
# - type: discovered (resolved from registered nodes by role)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ listen_port: 8091
|
|||
|
||||
# Set to the dispatcher gateway URL if you want auto-registration/heartbeat.
|
||||
dispatcher_base_url: "http://127.0.0.1:8080"
|
||||
# Optional auth key presented to dispatcher for /v1/nodes/* endpoints
|
||||
dispatcher_node_key: "change-me-node-key-1"
|
||||
dispatcher_roles: ["planner", "coder"]
|
||||
heartbeat_interval_sec: 5
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ default_model: writer
|
|||
gateway:
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
auth:
|
||||
client_api_keys: ["..."]
|
||||
node_api_keys: ["..."]
|
||||
models:
|
||||
<alias>:
|
||||
type: proxy | discovered
|
||||
|
|
@ -61,4 +64,10 @@ Nodes register to `POST /v1/nodes/register`:
|
|||
}
|
||||
```
|
||||
|
||||
Security is intentionally omitted in this scaffold — add API keys or mTLS if the gateway is exposed beyond localhost.
|
||||
If `auth.client_api_keys` is set (non-empty), callers of `/v1/models` and `/v1/chat/completions` must provide an API key.
|
||||
|
||||
If `auth.node_api_keys` is set (non-empty), node agents calling `/v1/nodes/register` and `/v1/nodes/heartbeat` must provide a node key.
|
||||
|
||||
Supported headers:
|
||||
- Clients: `Authorization: Bearer <key>` or `X-Api-Key: <key>`
|
||||
- Nodes: `Authorization: Bearer <node_key>` or `X-RoleMesh-Node-Key: <node_key>`
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ This scaffold supports two patterns.
|
|||
```bash
|
||||
curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'X-RoleMesh-Node-Key: <node-key>' \
|
||||
-d '{"node_id":"gpu-box-1","base_url":"http://10.0.0.12:8012","roles":["writer"]}'
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -5,12 +5,13 @@ import time
|
|||
import uuid
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi import APIRouter, Request, Depends
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel
|
||||
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
|
||||
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError
|
||||
from rolemesh_gateway.auth import require_client_auth, require_node_auth
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -30,7 +31,7 @@ def _openai_error(message: str, code: str = "upstream_error", status_code: int =
|
|||
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def list_models(request: Request) -> Dict[str, Any]:
|
||||
async def list_models(request: Request, _=Depends(require_client_auth)) -> Dict[str, Any]:
|
||||
cfg: Config = request.app.state.cfg
|
||||
registry: Registry = request.app.state.registry
|
||||
|
||||
|
|
@ -54,7 +55,7 @@ async def list_models(request: Request) -> Dict[str, Any]:
|
|||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request) -> Any:
|
||||
async def chat_completions(request: Request, _=Depends(require_client_auth)) -> Any:
|
||||
cfg: Config = request.app.state.cfg
|
||||
upstream: UpstreamClient = request.app.state.upstream
|
||||
registry: Registry = request.app.state.registry
|
||||
|
|
@ -133,7 +134,7 @@ async def ready(request: Request) -> JSONResponse:
|
|||
|
||||
|
||||
@router.post("/v1/nodes/register")
|
||||
async def register_node(request: Request) -> Dict[str, Any]:
|
||||
async def register_node(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
|
||||
"""
|
||||
Allow a remote machine to register which roles it serves.
|
||||
|
||||
|
|
@ -148,7 +149,7 @@ async def register_node(request: Request) -> Dict[str, Any]:
|
|||
|
||||
|
||||
@router.post("/v1/nodes/heartbeat")
|
||||
async def nodes_heartbeat(request: Request) -> Dict[str, Any]:
|
||||
async def nodes_heartbeat(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
|
||||
"""
|
||||
Allow a node agent to push status/metrics updates.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from fastapi import Header, HTTPException, Request, status
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Principal:
    """Immutable record describing the caller that passed an auth check."""

    # Caller category: "client" or "node".
    kind: str
    # Opaque identifier for the presented key (e.g., its last 6 characters).
    key_id: str
|
||||
|
||||
|
||||
def _extract_bearer(authorization: Optional[str]) -> Optional[str]:
|
||||
if not authorization:
|
||||
return None
|
||||
parts = authorization.strip().split()
|
||||
if len(parts) == 2 and parts[0].lower() == "bearer":
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
def _key_id(token: str) -> str:
|
||||
return token[-6:] if len(token) >= 6 else token
|
||||
|
||||
|
||||
def require_client_auth(
    request: Request,
    authorization: Optional[str] = Header(default=None),
    # An explicit alias is required: with ``convert_underscores=False`` and no
    # alias FastAPI would look for a header literally named ``x_api_key`` and
    # the documented ``X-Api-Key`` header would never be read.
    x_api_key: Optional[str] = Header(default=None, alias="X-Api-Key"),
) -> Optional[Principal]:
    """Gate OpenAI-style endpoints with an API key.

    Backward compatible: if no client keys are configured, this is a no-op.

    Supported headers:
    - Authorization: Bearer <key>
    - X-Api-Key: <key>

    A well-formed ``Authorization: Bearer`` token takes precedence; the
    ``X-Api-Key`` value is consulted only when no bearer token is present.

    Returns:
        ``None`` when client auth is disabled, otherwise a ``Principal`` of
        kind ``"client"`` identifying the accepted key.

    Raises:
        HTTPException: 401 when a key is required but missing or not one of
            the configured client keys.
    """
    import hmac  # function-scope import: not part of the module's import header

    cfg = getattr(request.app.state, "cfg", None)
    keys: Sequence[str] = []
    if cfg is not None and getattr(cfg, "auth", None) is not None:
        keys = list(cfg.auth.client_api_keys or [])

    if not keys:
        return None  # auth disabled

    token = _extract_bearer(authorization) or x_api_key

    # Compare as bytes in constant time so response timing does not reveal how
    # much of a key matched; scan every key rather than stopping at the first hit.
    presented = token.encode("utf-8") if token else b""
    matched = False
    for key in keys:
        matched |= hmac.compare_digest(presented, key.encode("utf-8"))

    if not token or not matched:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or invalid API key.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return Principal(kind="client", key_id=_key_id(token))
|
||||
|
||||
|
||||
def require_node_auth(
    request: Request,
    authorization: Optional[str] = Header(default=None),
    # An explicit alias is required: with ``convert_underscores=False`` and no
    # alias FastAPI would look for a header literally named
    # ``x_rolemesh_node_key`` and the documented ``X-RoleMesh-Node-Key`` header
    # would never be read.
    x_rolemesh_node_key: Optional[str] = Header(
        default=None, alias="X-RoleMesh-Node-Key"
    ),
) -> Optional[Principal]:
    """Gate node registration/heartbeat endpoints with a node key.

    Backward compatible: if no node keys are configured, this is a no-op.

    Supported headers:
    - Authorization: Bearer <node_key>
    - X-RoleMesh-Node-Key: <node_key>

    A well-formed ``Authorization: Bearer`` token takes precedence; the
    ``X-RoleMesh-Node-Key`` value is consulted only when no bearer token is
    present.

    Returns:
        ``None`` when node auth is disabled, otherwise a ``Principal`` of kind
        ``"node"`` identifying the accepted key.

    Raises:
        HTTPException: 401 when a key is required but missing or not one of
            the configured node keys.
    """
    import hmac  # function-scope import: not part of the module's import header

    cfg = getattr(request.app.state, "cfg", None)
    keys: Sequence[str] = []
    if cfg is not None and getattr(cfg, "auth", None) is not None:
        keys = list(cfg.auth.node_api_keys or [])

    if not keys:
        return None  # auth disabled

    token = _extract_bearer(authorization) or x_rolemesh_node_key

    # Compare as bytes in constant time so response timing does not reveal how
    # much of a key matched; scan every key rather than stopping at the first hit.
    presented = token.encode("utf-8") if token else b""
    matched = False
    for key in keys:
        matched |= hmac.compare_digest(presented, key.encode("utf-8"))

    if not token or not matched:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or invalid node key.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return Principal(kind="node", key_id=_key_id(token))
|
||||
|
|
@ -12,6 +12,17 @@ class GatewayConfig(BaseModel):
|
|||
port: int = 8000
|
||||
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
    # Keys accepted from clients of the OpenAI-compatible endpoints.
    # An empty list turns client auth off (scaffold/back-compat mode).
    client_api_keys: List[str] = Field(default_factory=list)

    # Keys accepted from node agents on the register/heartbeat endpoints.
    # An empty list turns node auth off (scaffold/back-compat mode).
    node_api_keys: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
ModelType = Literal["proxy", "discovered"]
|
||||
|
||||
|
||||
|
|
@ -39,6 +50,7 @@ class Config(BaseModel):
|
|||
version: int = 1
|
||||
default_model: str
|
||||
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
||||
auth: AuthConfig = Field(default_factory=AuthConfig)
|
||||
models: Dict[str, ModelEntry] = Field(default_factory=dict)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -33,5 +33,8 @@ class NodeAgentConfig(BaseModel):
|
|||
dispatcher_roles: List[str] = Field(default_factory=list)
|
||||
heartbeat_interval_sec: float = 5.0
|
||||
|
||||
# Optional auth key presented to the dispatcher for /v1/nodes/* endpoints
|
||||
dispatcher_node_key: Optional[str] = None
|
||||
|
||||
# llama-server binary name/path
|
||||
llama_server_bin: str = "llama-server"
|
||||
|
|
|
|||
|
|
@ -127,7 +127,10 @@ async def _heartbeat_loop(app: FastAPI) -> None:
|
|||
"metrics": [m.__dict__ | {"device": m.device.__dict__} for m in inv],
|
||||
}
|
||||
url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat"
|
||||
await http.post(url, json=payload)
|
||||
headers = {}
|
||||
if cfg.dispatcher_node_key:
|
||||
headers["X-RoleMesh-Node-Key"] = cfg.dispatcher_node_key
|
||||
await http.post(url, json=payload, headers=headers)
|
||||
except Exception:
|
||||
pass
|
||||
await asyncio.sleep(cfg.heartbeat_interval_sec)
|
||||
|
|
|
|||
Loading…
Reference in New Issue