mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-10-26 22:01:59 -06:00
Additional Auth Changes
This commit is contained in:
@@ -8,6 +8,7 @@ import jwt
|
||||
from flask import g, jsonify, request
|
||||
|
||||
from Modules.auth.dpop import DPoPValidator, DPoPVerificationError, DPoPReplayError
|
||||
from Modules.auth.rate_limit import SlidingWindowRateLimiter
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -25,11 +26,18 @@ class DeviceAuthError(Exception):
|
||||
status_code = 401
|
||||
error_code = "unauthorized"
|
||||
|
||||
def __init__(self, message: str = "unauthorized", *, status_code: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "unauthorized",
|
||||
*,
|
||||
status_code: Optional[int] = None,
|
||||
retry_after: Optional[float] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
if status_code is not None:
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.retry_after = retry_after
|
||||
|
||||
|
||||
class DeviceAuthManager:
|
||||
@@ -40,11 +48,13 @@ class DeviceAuthManager:
|
||||
jwt_service,
|
||||
dpop_validator: Optional[DPoPValidator],
|
||||
log: Callable[[str, str], None],
|
||||
rate_limiter: Optional[SlidingWindowRateLimiter] = None,
|
||||
) -> None:
|
||||
self._db_conn_factory = db_conn_factory
|
||||
self._jwt_service = jwt_service
|
||||
self._dpop_validator = dpop_validator
|
||||
self._log = log
|
||||
self._rate_limiter = rate_limiter
|
||||
|
||||
def authenticate(self) -> DeviceAuthContext:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
@@ -67,6 +77,15 @@ class DeviceAuthManager:
|
||||
if not guid or not fingerprint or token_version <= 0:
|
||||
raise DeviceAuthError("invalid_claims")
|
||||
|
||||
if self._rate_limiter:
|
||||
decision = self._rate_limiter.check(f"fp:{fingerprint}", 60, 60.0)
|
||||
if not decision.allowed:
|
||||
raise DeviceAuthError(
|
||||
"rate_limited",
|
||||
status_code=429,
|
||||
retry_after=decision.retry_after,
|
||||
)
|
||||
|
||||
conn = self._db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
@@ -138,6 +157,12 @@ def require_device_auth(manager: DeviceAuthManager):
|
||||
except DeviceAuthError as exc:
|
||||
response = jsonify({"error": exc.message})
|
||||
response.status_code = exc.status_code
|
||||
retry_after = getattr(exc, "retry_after", None)
|
||||
if retry_after:
|
||||
try:
|
||||
response.headers["Retry-After"] = str(max(1, int(retry_after)))
|
||||
except Exception:
|
||||
response.headers["Retry-After"] = "1"
|
||||
return response
|
||||
|
||||
g.device_auth = ctx
|
||||
|
||||
@@ -26,6 +26,7 @@ def register(
|
||||
ip_rate_limiter: SlidingWindowRateLimiter,
|
||||
fp_rate_limiter: SlidingWindowRateLimiter,
|
||||
nonce_cache: NonceCache,
|
||||
script_signer,
|
||||
) -> None:
|
||||
blueprint = Blueprint("enrollment", __name__)
|
||||
|
||||
@@ -42,6 +43,14 @@ def register(
|
||||
addr = request.remote_addr or "unknown"
|
||||
return addr.strip()
|
||||
|
||||
def _signing_key_b64() -> str:
|
||||
if not script_signer:
|
||||
return ""
|
||||
try:
|
||||
return script_signer.public_base64_spki()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
def _rate_limited(key: str, limiter: SlidingWindowRateLimiter, limit: int, window_s: float):
|
||||
decision = limiter.check(key, limit, window_s)
|
||||
if not decision.allowed:
|
||||
@@ -312,6 +321,7 @@ def register(
|
||||
"server_nonce": server_nonce_b64,
|
||||
"poll_after_ms": 3000,
|
||||
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
||||
"signing_key": _signing_key_b64(),
|
||||
}
|
||||
log("server", f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}")
|
||||
return jsonify(response)
|
||||
@@ -466,6 +476,7 @@ def register(
|
||||
"refresh_token": refresh_info["token"],
|
||||
"token_type": "Bearer",
|
||||
"server_certificate": _load_tls_bundle(tls_bundle_path),
|
||||
"signing_key": _signing_key_b64(),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
|
||||
import eventlet
|
||||
@@ -25,7 +25,9 @@ def start_prune_job(
|
||||
|
||||
|
||||
def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None]) -> None:
|
||||
now_iso = datetime.now(tz=timezone.utc).isoformat()
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
now_iso = now.isoformat()
|
||||
stale_before = (now - timedelta(hours=24)).isoformat()
|
||||
conn = db_conn_factory()
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
@@ -55,7 +57,7 @@ def _run_once(db_conn_factory: Callable[[], any], log: Callable[[str, str], None
|
||||
OR created_at < ?
|
||||
)
|
||||
""",
|
||||
(now_iso, now_iso, now_iso),
|
||||
(now_iso, now_iso, stale_before),
|
||||
)
|
||||
approvals_marked = cur.rowcount or 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user