Additional Changes to VPN Tunneling

This commit is contained in:
2026-01-11 19:02:53 -07:00
parent 6ceb59f717
commit df14a1e26a
18 changed files with 681 additions and 175 deletions

View File

@@ -117,7 +117,7 @@ def _rotate_daily(path: Path) -> None:
pass
_QUIET_SERVICE_LOGS = {"scheduled_jobs"}
_QUIET_SERVICE_LOGS = {"scheduled_jobs", "device_enrollment"}
def _make_service_logger(base: Path, logger: logging.Logger) -> Callable[[str, str, Optional[str]], None]:

View File

@@ -12,12 +12,14 @@
from __future__ import annotations
import os
from urllib.parse import urlsplit
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from flask import Blueprint, jsonify, request, session
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from ...VPN import VpnTunnelService
from ...VPN import WireGuardServerConfig, WireGuardServerManager, VpnTunnelService
if False: # pragma: no cover - import cycle hint for type checkers
from .. import EngineServiceAdapters
@@ -63,7 +65,22 @@ def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService:
if service is None:
manager = getattr(adapters.context, "wireguard_server_manager", None)
if manager is None:
raise RuntimeError("wireguard_manager_unavailable")
try:
manager = WireGuardServerManager(
WireGuardServerConfig(
port=adapters.context.wireguard_port,
engine_virtual_ip=adapters.context.wireguard_engine_virtual_ip,
peer_network=adapters.context.wireguard_peer_network,
private_key_path=Path(adapters.context.wireguard_server_private_key_path),
public_key_path=Path(adapters.context.wireguard_server_public_key_path),
acl_allowlist_windows=tuple(adapters.context.wireguard_acl_allowlist_windows),
log_path=Path(adapters.context.vpn_tunnel_log_path),
)
)
adapters.context.wireguard_server_manager = manager
except Exception as exc:
adapters.context.logger.error("Failed to initialize WireGuard server manager on demand.", exc_info=True)
raise RuntimeError("wireguard_manager_unavailable") from exc
service = VpnTunnelService(
context=adapters.context,
wireguard_manager=manager,
@@ -86,6 +103,20 @@ def _normalize_text(value: Any) -> str:
return ""
def _infer_endpoint_host(req) -> str:
forwarded = (req.headers.get("X-Forwarded-Host") or req.headers.get("X-Original-Host") or "").strip()
host = forwarded.split(",")[0].strip() if forwarded else (req.host or "").strip()
if not host:
return ""
try:
parsed = urlsplit(f"//{host}")
if parsed.hostname:
return parsed.hostname
except Exception:
return host
return host
def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
blueprint = Blueprint("vpn_tunnel", __name__)
logger = adapters.context.logger.getChild("vpn_tunnel.api")
@@ -107,10 +138,15 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
try:
tunnel_service = _get_tunnel_service(adapters)
payload = tunnel_service.connect(agent_id=agent_id, operator_id=operator_id)
except RuntimeError as exc:
endpoint_host = _infer_endpoint_host(request)
payload = tunnel_service.connect(
agent_id=agent_id,
operator_id=operator_id,
endpoint_host=endpoint_host,
)
except Exception as exc:
logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc)
return jsonify({"error": "connect_failed"}), 500
return jsonify({"error": "connect_failed", "detail": str(exc)}), 500
return jsonify(payload), 200

View File

@@ -73,6 +73,9 @@ def register(
except Exception:
return ""
def _enrollment_log(message: str, context_hint: Optional[str] = None) -> None:
log("device_enrollment", message, context_hint)
def _rate_limited(
key: str,
limiter: SlidingWindowRateLimiter,
@@ -82,8 +85,7 @@ def register(
):
decision = limiter.check(key, limit, window_s)
if not decision.allowed:
log(
"server",
_enrollment_log(
f"enrollment rate limited key={key} limit={limit}/{window_s}s retry_after={decision.retry_after:.2f}",
context_hint,
)
@@ -93,6 +95,9 @@ def register(
return response
return None
def _poll_log(message: str, context_hint: Optional[str] = None) -> None:
_enrollment_log(message, context_hint)
def _load_install_code(cur: sqlite3.Cursor, code_value: str) -> Optional[Dict[str, Any]]:
cur.execute(
"""
@@ -347,8 +352,7 @@ def register(
agent_pubkey_b64 = payload.get("agent_pubkey")
client_nonce_b64 = payload.get("client_nonce")
log(
"server",
_enrollment_log(
"enrollment request received "
f"ip={remote} hostname={hostname or '<missing>'} code_mask={_mask_code(enrollment_code)} "
f"pubkey_len={len(agent_pubkey_b64 or '')} nonce_len={len(client_nonce_b64 or '')}",
@@ -356,35 +360,35 @@ def register(
)
if not hostname:
log("server", f"enrollment rejected missing_hostname ip={remote}", context_hint)
_enrollment_log(f"enrollment rejected missing_hostname ip={remote}", context_hint)
return jsonify({"error": "hostname_required"}), 400
if not enrollment_code:
log("server", f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected missing_code ip={remote} host={hostname}", context_hint)
return jsonify({"error": "enrollment_code_required"}), 400
if not isinstance(agent_pubkey_b64, str):
log("server", f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected missing_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "agent_pubkey_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected missing_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400
try:
agent_pubkey_der = crypto_keys.spki_der_from_base64(agent_pubkey_b64)
except Exception:
log("server", f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected invalid_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400
if len(agent_pubkey_der) < 10:
log("server", f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected short_pubkey ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_agent_pubkey"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected invalid_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
if len(client_nonce_bytes) < 16:
log("server", f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
_enrollment_log(f"enrollment rejected short_nonce ip={remote} host={hostname}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
fingerprint = crypto_keys.fingerprint_from_spki_der(agent_pubkey_der)
@@ -398,8 +402,7 @@ def register(
install_code = _load_install_code(cur, enrollment_code)
site_id = install_code.get("site_id") if install_code else None
if site_id is None:
log(
"server",
_enrollment_log(
"enrollment request rejected missing_site_binding "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
@@ -407,8 +410,7 @@ def register(
return jsonify({"error": "invalid_enrollment_code"}), 400
cur.execute("SELECT 1 FROM sites WHERE id = ?", (site_id,))
if cur.fetchone() is None:
log(
"server",
_enrollment_log(
"enrollment request rejected missing_site_owner "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
@@ -416,8 +418,7 @@ def register(
return jsonify({"error": "invalid_enrollment_code"}), 400
valid_code, reuse_guid = _install_code_valid(install_code, fingerprint, cur)
if not valid_code:
log(
"server",
_enrollment_log(
"enrollment request invalid_code "
f"host={hostname} fingerprint={fingerprint[:12]} code_mask={_mask_code(enrollment_code)}",
context_hint,
@@ -509,8 +510,7 @@ def register(
"server_certificate": _load_tls_bundle(tls_bundle_path),
"signing_key": _signing_key_b64(),
}
log(
"server",
_enrollment_log(
f"enrollment request queued fingerprint={fingerprint[:12]} host={hostname} ip={remote}",
context_hint,
)
@@ -524,8 +524,7 @@ def register(
proof_sig_b64 = payload.get("proof_sig")
context_hint = _canonical_context(request.headers.get(AGENT_CONTEXT_HEADER))
log(
"server",
_poll_log(
"enrollment poll received "
f"ref={approval_reference} client_nonce_len={len(client_nonce_b64 or '')}"
f" proof_sig_len={len(proof_sig_b64 or '')}",
@@ -533,25 +532,25 @@ def register(
)
if not isinstance(approval_reference, str) or not approval_reference:
log("server", "enrollment poll rejected missing_reference", context_hint)
_poll_log("enrollment poll rejected missing_reference", context_hint)
return jsonify({"error": "approval_reference_required"}), 400
if not isinstance(client_nonce_b64, str):
log("server", f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll rejected missing_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "client_nonce_required"}), 400
if not isinstance(proof_sig_b64, str):
log("server", f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll rejected missing_sig ref={approval_reference}", context_hint)
return jsonify({"error": "proof_sig_required"}), 400
try:
client_nonce_bytes = base64.b64decode(client_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll invalid_client_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_client_nonce"}), 400
try:
proof_sig = base64.b64decode(proof_sig_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll invalid_sig ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof_sig"}), 400
conn = db_conn_factory()
@@ -569,7 +568,7 @@ def register(
)
row = cur.fetchone()
if not row:
log("server", f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll unknown_reference ref={approval_reference}", context_hint)
return jsonify({"status": "unknown"}), 404
(
@@ -589,13 +588,13 @@ def register(
) = row
if client_nonce_stored != client_nonce_b64:
log("server", f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll nonce_mismatch ref={approval_reference}", context_hint)
return jsonify({"error": "nonce_mismatch"}), 400
try:
server_nonce_bytes = base64.b64decode(server_nonce_b64, validate=True)
except Exception:
log("server", f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll invalid_server_nonce ref={approval_reference}", context_hint)
return jsonify({"error": "server_nonce_invalid"}), 400
message = server_nonce_bytes + approval_reference.encode("utf-8") + client_nonce_bytes
@@ -603,52 +602,47 @@ def register(
try:
public_key = serialization.load_der_public_key(agent_pubkey_der)
except Exception:
log("server", f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll pubkey_load_failed ref={approval_reference}", context_hint)
public_key = None
if public_key is None:
log("server", f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll invalid_pubkey ref={approval_reference}", context_hint)
return jsonify({"error": "agent_pubkey_invalid"}), 400
try:
public_key.verify(proof_sig, message)
except Exception:
log("server", f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
_poll_log(f"enrollment poll invalid_proof ref={approval_reference}", context_hint)
return jsonify({"error": "invalid_proof"}), 400
if status == "pending":
log(
"server",
_poll_log(
f"enrollment poll pending ref={approval_reference} host={hostname_claimed}"
f" fingerprint={fingerprint[:12]}",
context_hint,
)
return jsonify({"status": "pending", "poll_after_ms": 5000})
if status == "denied":
log(
"server",
_poll_log(
f"enrollment poll denied ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "denied", "reason": "operator_denied"})
if status == "expired":
log(
"server",
_poll_log(
f"enrollment poll expired ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "expired"})
if status == "completed":
log(
"server",
_poll_log(
f"enrollment poll already_completed ref={approval_reference} host={hostname_claimed}",
context_hint,
)
return jsonify({"status": "approved", "detail": "finalized"})
if status != "approved":
log(
"server",
_poll_log(
f"enrollment poll unexpected_status={status} ref={approval_reference}",
context_hint,
)
@@ -656,8 +650,7 @@ def register(
nonce_key = f"{approval_reference}:{base64.b64encode(proof_sig).decode('ascii')}"
if not nonce_cache.consume(nonce_key):
log(
"server",
_poll_log(
f"enrollment poll replay_detected ref={approval_reference} fingerprint={fingerprint[:12]}",
context_hint,
)
@@ -769,8 +762,7 @@ def register(
finally:
conn.close()
log(
"server",
_poll_log(
f"enrollment finalized guid={effective_guid} fingerprint={fingerprint[:12]} host={hostname_claimed}",
context_hint,
)