Update for auth step 1
This commit is contained in:
parent
b2e4518b3a
commit
d7042b4a2b
|
|
@ -5,6 +5,15 @@ gateway:
|
||||||
host: 0.0.0.0
|
host: 0.0.0.0
|
||||||
port: 8000
|
port: 8000
|
||||||
|
|
||||||
|
|
||||||
|
auth:
|
||||||
|
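# Set these to enable auth. Each list is checked independently: an empty list disables auth for that group (scaffold/back-compat). Replace the placeholder keys below with your own secrets before exposing the gateway.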
|
||||||
|
client_api_keys:
|
||||||
|
- "change-me-client-key-1"
|
||||||
|
node_api_keys:
|
||||||
|
- "change-me-node-key-1"
|
||||||
|
|
||||||
|
|
||||||
# Models may be:
|
# Models may be:
|
||||||
# - type: proxy (static URL to an OpenAI-compatible upstream)
|
# - type: proxy (static URL to an OpenAI-compatible upstream)
|
||||||
# - type: discovered (resolved from registered nodes by role)
|
# - type: discovered (resolved from registered nodes by role)
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ listen_port: 8091
|
||||||
|
|
||||||
# Set to the dispatcher gateway URL if you want auto-registration/heartbeat.
|
# Set to the dispatcher gateway URL if you want auto-registration/heartbeat.
|
||||||
dispatcher_base_url: "http://127.0.0.1:8080"
|
dispatcher_base_url: "http://127.0.0.1:8080"
|
||||||
|
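# Optional auth key presented to the dispatcher for /v1/nodes/* endpoints (sent as the X-RoleMesh-Node-Key header); it must match an entry in the gateway's auth.node_api_keys.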
|
||||||
|
dispatcher_node_key: "change-me-node-key-1"
|
||||||
dispatcher_roles: ["planner", "coder"]
|
dispatcher_roles: ["planner", "coder"]
|
||||||
heartbeat_interval_sec: 5
|
heartbeat_interval_sec: 5
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,9 @@ default_model: writer
|
||||||
gateway:
|
gateway:
|
||||||
host: 0.0.0.0
|
host: 0.0.0.0
|
||||||
port: 8000
|
port: 8000
|
||||||
|
auth:
|
||||||
|
client_api_keys: ["..."]
|
||||||
|
node_api_keys: ["..."]
|
||||||
models:
|
models:
|
||||||
<alias>:
|
<alias>:
|
||||||
type: proxy | discovered
|
type: proxy | discovered
|
||||||
|
|
@ -61,4 +64,10 @@ Nodes register to `POST /v1/nodes/register`:
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Security is intentionally omitted in this scaffold — add API keys or mTLS if the gateway is exposed beyond localhost.
|
If `auth.client_api_keys` is set (non-empty), callers of `/v1/models` and `/v1/chat/completions` must provide an API key.
|
||||||
|
|
||||||
|
If `auth.node_api_keys` is set (non-empty), node agents calling `/v1/nodes/register` and `/v1/nodes/heartbeat` must provide a node key.
|
||||||
|
|
||||||
|
Supported headers:
|
||||||
|
- Clients: `Authorization: Bearer <key>` or `X-Api-Key: <key>`
|
||||||
|
- Nodes: `Authorization: Bearer <node_key>` or `X-RoleMesh-Node-Key: <node_key>`
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ This scaffold supports two patterns.
|
||||||
```bash
|
```bash
|
||||||
curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \
|
curl -sS -X POST http://GATEWAY:8000/v1/nodes/register \
|
||||||
-H 'Content-Type: application/json' \
|
-H 'Content-Type: application/json' \
|
||||||
|
-H 'X-RoleMesh-Node-Key: <node-key>' \
|
||||||
-d '{"node_id":"gpu-box-1","base_url":"http://10.0.0.12:8012","roles":["writer"]}'
|
-d '{"node_id":"gpu-box-1","base_url":"http://10.0.0.12:8012","roles":["writer"]}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,13 @@ import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request, Depends
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel
|
from rolemesh_gateway.config import Config, DiscoveredModel, ProxyModel
|
||||||
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
|
from rolemesh_gateway.registry import NodeHeartbeat, NodeRegistration, Registry
|
||||||
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError
|
from rolemesh_gateway.upstream import UpstreamClient, UpstreamError
|
||||||
|
from rolemesh_gateway.auth import require_client_auth, require_node_auth
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
@ -30,7 +31,7 @@ def _openai_error(message: str, code: str = "upstream_error", status_code: int =
|
||||||
|
|
||||||
|
|
||||||
@router.get("/v1/models")
|
@router.get("/v1/models")
|
||||||
async def list_models(request: Request) -> Dict[str, Any]:
|
async def list_models(request: Request, _=Depends(require_client_auth)) -> Dict[str, Any]:
|
||||||
cfg: Config = request.app.state.cfg
|
cfg: Config = request.app.state.cfg
|
||||||
registry: Registry = request.app.state.registry
|
registry: Registry = request.app.state.registry
|
||||||
|
|
||||||
|
|
@ -54,7 +55,7 @@ async def list_models(request: Request) -> Dict[str, Any]:
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/chat/completions")
|
@router.post("/v1/chat/completions")
|
||||||
async def chat_completions(request: Request) -> Any:
|
async def chat_completions(request: Request, _=Depends(require_client_auth)) -> Any:
|
||||||
cfg: Config = request.app.state.cfg
|
cfg: Config = request.app.state.cfg
|
||||||
upstream: UpstreamClient = request.app.state.upstream
|
upstream: UpstreamClient = request.app.state.upstream
|
||||||
registry: Registry = request.app.state.registry
|
registry: Registry = request.app.state.registry
|
||||||
|
|
@ -133,7 +134,7 @@ async def ready(request: Request) -> JSONResponse:
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/nodes/register")
|
@router.post("/v1/nodes/register")
|
||||||
async def register_node(request: Request) -> Dict[str, Any]:
|
async def register_node(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Allow a remote machine to register which roles it serves.
|
Allow a remote machine to register which roles it serves.
|
||||||
|
|
||||||
|
|
@ -148,7 +149,7 @@ async def register_node(request: Request) -> Dict[str, Any]:
|
||||||
|
|
||||||
|
|
||||||
@router.post("/v1/nodes/heartbeat")
|
@router.post("/v1/nodes/heartbeat")
|
||||||
async def nodes_heartbeat(request: Request) -> Dict[str, Any]:
|
async def nodes_heartbeat(request: Request, _=Depends(require_node_auth)) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Allow a node agent to push status/metrics updates.
|
Allow a node agent to push status/metrics updates.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
from fastapi import Header, HTTPException, Request, status
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class Principal:
    """Immutable description of the caller accepted by an auth dependency."""

    # Caller category: "client" or "node".
    kind: str
    # Opaque identifier for the key the caller presented.
    key_id: str
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_bearer(authorization: Optional[str]) -> Optional[str]:
|
||||||
|
if not authorization:
|
||||||
|
return None
|
||||||
|
parts = authorization.strip().split()
|
||||||
|
if len(parts) == 2 and parts[0].lower() == "bearer":
|
||||||
|
return parts[1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _key_id(token: str) -> str:
|
||||||
|
return token[-6:] if len(token) >= 6 else token
|
||||||
|
|
||||||
|
|
||||||
|
def require_client_auth(
    request: Request,
    authorization: Optional[str] = Header(default=None),
    # An explicit alias is required: with convert_underscores=False FastAPI
    # looks for a header literally named "x_api_key", so the documented
    # "X-Api-Key" header was never read.
    x_api_key: Optional[str] = Header(default=None, alias="X-Api-Key"),
) -> Optional[Principal]:
    """Gate OpenAI-style endpoints with an API key.

    Backward compatible: if no client keys are configured, this is a no-op.

    Supported headers:
      - Authorization: Bearer <key>
      - X-Api-Key: <key>

    A Bearer token, when present, takes precedence over ``X-Api-Key``.

    Returns:
        ``None`` when client auth is disabled, otherwise a ``Principal`` of
        kind ``"client"`` describing the accepted key.

    Raises:
        HTTPException: 401 when client keys are configured and the request
            carries no key or a key that is not in the configured list.
    """
    # Function-scope import keeps this dependency self-contained.
    import secrets

    cfg = getattr(request.app.state, "cfg", None)
    keys: Sequence[str] = []
    if cfg is not None and hasattr(cfg, "auth") and cfg.auth is not None:
        keys = list(cfg.auth.client_api_keys or [])

    if not keys:
        return None  # auth disabled

    token = _extract_bearer(authorization) or x_api_key

    # Compare against every configured key in constant time and without
    # short-circuiting, so response timing does not reveal how much of a key
    # matched or which entry matched. Bytes are used because compare_digest
    # rejects non-ASCII str arguments.
    presented = (token or "").encode("utf-8")
    authorized = False
    for key in keys:
        authorized |= secrets.compare_digest(presented, key.encode("utf-8"))

    if not token or not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or invalid API key.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return Principal(kind="client", key_id=_key_id(token))
|
||||||
|
|
||||||
|
|
||||||
|
def require_node_auth(
    request: Request,
    authorization: Optional[str] = Header(default=None),
    # An explicit alias is required: with convert_underscores=False FastAPI
    # looks for a header literally named "x_rolemesh_node_key", so the
    # documented "X-RoleMesh-Node-Key" header was never read.
    x_rolemesh_node_key: Optional[str] = Header(
        default=None, alias="X-RoleMesh-Node-Key"
    ),
) -> Optional[Principal]:
    """Gate node registration/heartbeat endpoints with a node key.

    Backward compatible: if no node keys are configured, this is a no-op.

    Supported headers:
      - Authorization: Bearer <node_key>
      - X-RoleMesh-Node-Key: <node_key>

    A Bearer token, when present, takes precedence over
    ``X-RoleMesh-Node-Key``.

    Returns:
        ``None`` when node auth is disabled, otherwise a ``Principal`` of
        kind ``"node"`` describing the accepted key.

    Raises:
        HTTPException: 401 when node keys are configured and the request
            carries no key or a key that is not in the configured list.
    """
    # Function-scope import keeps this dependency self-contained.
    import secrets

    cfg = getattr(request.app.state, "cfg", None)
    keys: Sequence[str] = []
    if cfg is not None and hasattr(cfg, "auth") and cfg.auth is not None:
        keys = list(cfg.auth.node_api_keys or [])

    if not keys:
        return None  # auth disabled

    token = _extract_bearer(authorization) or x_rolemesh_node_key

    # Compare against every configured key in constant time and without
    # short-circuiting, so response timing does not reveal how much of a key
    # matched or which entry matched. Bytes are used because compare_digest
    # rejects non-ASCII str arguments.
    presented = (token or "").encode("utf-8")
    authorized = False
    for key in keys:
        authorized |= secrets.compare_digest(presented, key.encode("utf-8"))

    if not token or not authorized:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or invalid node key.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return Principal(kind="node", key_id=_key_id(token))
|
||||||
|
|
@ -12,6 +12,17 @@ class GatewayConfig(BaseModel):
|
||||||
port: int = 8000
|
port: int = 8000
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class AuthConfig(BaseModel):
    # Keys accepted from clients of the OpenAI-compatible endpoints.
    # Leaving the list empty keeps client auth disabled
    # (scaffold/back-compat mode).
    client_api_keys: List[str] = Field(default_factory=list)

    # Keys accepted from node agents on the register/heartbeat endpoints.
    # Leaving the list empty keeps node auth disabled
    # (scaffold/back-compat mode).
    node_api_keys: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
ModelType = Literal["proxy", "discovered"]
|
ModelType = Literal["proxy", "discovered"]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -39,6 +50,7 @@ class Config(BaseModel):
|
||||||
version: int = 1
|
version: int = 1
|
||||||
default_model: str
|
default_model: str
|
||||||
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
||||||
|
auth: AuthConfig = Field(default_factory=AuthConfig)
|
||||||
models: Dict[str, ModelEntry] = Field(default_factory=dict)
|
models: Dict[str, ModelEntry] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,5 +33,8 @@ class NodeAgentConfig(BaseModel):
|
||||||
dispatcher_roles: List[str] = Field(default_factory=list)
|
dispatcher_roles: List[str] = Field(default_factory=list)
|
||||||
heartbeat_interval_sec: float = 5.0
|
heartbeat_interval_sec: float = 5.0
|
||||||
|
|
||||||
|
# Optional auth key presented to the dispatcher for /v1/nodes/* endpoints
|
||||||
|
dispatcher_node_key: Optional[str] = None
|
||||||
|
|
||||||
# llama-server binary name/path
|
# llama-server binary name/path
|
||||||
llama_server_bin: str = "llama-server"
|
llama_server_bin: str = "llama-server"
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,10 @@ async def _heartbeat_loop(app: FastAPI) -> None:
|
||||||
"metrics": [m.__dict__ | {"device": m.device.__dict__} for m in inv],
|
"metrics": [m.__dict__ | {"device": m.device.__dict__} for m in inv],
|
||||||
}
|
}
|
||||||
url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat"
|
url = str(cfg.dispatcher_base_url).rstrip("/") + "/v1/nodes/heartbeat"
|
||||||
await http.post(url, json=payload)
|
headers = {}
|
||||||
|
if cfg.dispatcher_node_key:
|
||||||
|
headers["X-RoleMesh-Node-Key"] = cfg.dispatcher_node_key
|
||||||
|
await http.post(url, json=payload, headers=headers)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
await asyncio.sleep(cfg.heartbeat_interval_sec)
|
await asyncio.sleep(cfg.heartbeat_interval_sec)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue