Files
Borealis-Github-Replica/Data/Agent/Roles/role_WireGuardTunnel.py
2026-01-11 20:53:09 -07:00

414 lines
16 KiB
Python

# ======================================================
# Data\Agent\Roles\role_WireGuardTunnel.py
# Description: WireGuard client lifecycle for outbound reverse VPN (Windows) with host-only /32 routing.
#
# API Endpoints (if applicable): None
# ======================================================
"""WireGuard client role (Windows) for reverse VPN tunnels.
This role prepares the WireGuard client config, manages a single active
session, enforces idle teardown, and logs lifecycle events to
Agent/Logs/VPN_Tunnel/tunnel.log. It binds to Engine Socket.IO events
(`vpn_tunnel_start`, `vpn_tunnel_stop`, `vpn_tunnel_activity`) to start/stop
the client session with the issued config/token.
"""
from __future__ import annotations

import base64
import ipaddress
import json
import os
import subprocess
import threading
import time
from pathlib import Path
from typing import Any, Dict, Optional

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519

try:
    from signature_utils import verify_and_store_script_signature
except Exception:  # pragma: no cover - fallback for runtime path issues
    import sys
    from pathlib import Path as _Path

    base_dir = _Path(__file__).resolve().parents[1]
    if str(base_dir) not in sys.path:
        sys.path.insert(0, str(base_dir))
    from signature_utils import verify_and_store_script_signature
# Role metadata constants.
# NOTE(review): presumably read by the agent's role loader to name this role
# and to pick the context(s) it runs in — confirm against the loader.
ROLE_NAME = "WireGuardTunnel"
ROLE_CONTEXTS = ["system"]
def _log_path() -> Path:
    """Return the tunnel log file path, creating its directory if needed.

    The directory is ``Logs/VPN_Tunnel`` under the folder two levels above
    this module's own directory.
    """
    log_dir = Path(__file__).resolve().parents[2].joinpath("Logs", "VPN_Tunnel")
    log_dir.mkdir(parents=True, exist_ok=True)
    return log_dir.joinpath("tunnel.log")
def _write_log(message: str) -> None:
    """Append one timestamped ``[wg-client]`` line to the tunnel log.

    Logging is best-effort: any failure to resolve or write the log file is
    swallowed so that a logging problem can never break tunnel handling.

    Args:
        message: Text to record; a local-time timestamp and the
            ``[wg-client]`` tag are prepended.
    """
    ts = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
    try:
        # Use a context manager so the handle is flushed and closed right
        # away instead of lingering until garbage collection.
        with _log_path().open("a", encoding="utf-8") as handle:
            handle.write(f"[{ts}] [wg-client] {message}\n")
    except Exception:
        # Deliberately ignored: see the docstring.
        pass
def _encode_key(raw: bytes) -> str:
return base64.b64encode(raw).decode("ascii").strip()
def _generate_client_keys(root: Path) -> Dict[str, str]:
root.mkdir(parents=True, exist_ok=True)
priv_path = root / "client_private.key"
pub_path = root / "client_public.key"
if priv_path.is_file() and pub_path.is_file():
try:
private_key = priv_path.read_text(encoding="utf-8").strip()
public_key = pub_path.read_text(encoding="utf-8").strip()
if private_key and public_key:
return {"private": private_key, "public": public_key}
except Exception:
pass
key = x25519.X25519PrivateKey.generate()
priv = _encode_key(
key.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption(),
)
)
pub = _encode_key(
key.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
)
try:
priv_path.write_text(priv, encoding="utf-8")
pub_path.write_text(pub, encoding="utf-8")
except Exception:
_write_log("Failed to persist WireGuard client keys.")
return {"private": priv, "public": pub}
class SessionConfig:
    """Plain value holder for the parameters of one WireGuard tunnel session.

    All constructor arguments are keyword-only and are stored unchanged on
    the instance under the same names.
    """

    def __init__(
        self,
        *,
        token: Dict[str, Any],
        virtual_ip: str,
        allowed_ips: str,
        endpoint: str,
        server_public_key: str,
        allowed_ports: str,
        idle_seconds: int = 900,
        preshared_key: Optional[str] = None,
        client_private_key: Optional[str] = None,
        client_public_key: Optional[str] = None,
    ) -> None:
        # Token issued for this session.
        self.token: Dict[str, Any] = token

        # Interface address and peer settings.
        self.virtual_ip: str = virtual_ip
        self.allowed_ips: str = allowed_ips
        self.endpoint: str = endpoint
        self.server_public_key: str = server_public_key

        # Session policy.
        self.allowed_ports: str = allowed_ports
        self.idle_seconds: int = idle_seconds

        # Optional key material supplied with the session.
        self.preshared_key: Optional[str] = preshared_key
        self.client_private_key: Optional[str] = client_private_key
        self.client_public_key: Optional[str] = client_public_key
class WireGuardClient:
    """Manage a single WireGuard tunnel service through ``wireguard.exe``.

    The client holds at most one active session. Starting a session validates
    the issued token, renders the client config to disk, installs the tunnel
    service and arms an idle deadline that a background thread enforces.
    Stopping a session uninstalls the service and clears that state.
    """

    def __init__(self) -> None:
        """Prepare paths, load or create the client keys, locate wireguard.exe."""
        base = Path(__file__).resolve().parents[2]
        self.cert_root = base / "Borealis" / "Certificates" / "VPN_Client"
        self.temp_root = base / "Borealis" / "Temp"
        self.temp_root.mkdir(parents=True, exist_ok=True)
        self.conf_path = self.temp_root / "borealis-wg-client.conf"
        # NOTE(review): identical to the config file's stem; wireguard.exe
        # appears to name the tunnel service after the config file — confirm
        # before renaming either one.
        self.service_name = "borealis-wg-client"
        self.session: Optional[SessionConfig] = None
        self.idle_deadline: Optional[float] = None
        self._idle_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._client_keys = _generate_client_keys(self.cert_root)
        self._wg_exe = self._resolve_wireguard_exe()

    def _resolve_wireguard_exe(self) -> str:
        """Return the first existing wireguard.exe candidate, else the bare name."""
        program_files = Path(os.environ.get("ProgramFiles", "C:\\Program Files"))
        candidates = [
            str(program_files / "WireGuard" / "wireguard.exe"),
            "wireguard.exe",
        ]
        for candidate in candidates:
            if Path(candidate).is_file():
                return candidate
        # Fall back to the bare executable name.
        return "wireguard.exe"

    def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None:
        """Validate the session token; return ``None`` when it is acceptable.

        Checks required fields, expiry, port range and, when a signature is
        present, the signature itself. An unsigned token is accepted only if
        it carries no signature metadata and ``signing_client`` reports no
        stored server signing key.

        Args:
            token: Token mapping issued for the session (``None`` is treated
                as an empty token).
            signing_client: Optional object used for signature verification
                and, if it exposes ``load_server_signing_key``, for checking
                whether a server key is already stored.

        Raises:
            ValueError: If the token is incomplete, expired, malformed,
                unsigned when a signature is required, or fails verification.
        """
        payload = dict(token or {})
        signature = payload.pop("signature", None)
        signing_key = payload.pop("signing_key", None)
        sig_alg = payload.pop("sig_alg", None)

        # Check the normalized copy so a None token is reported as missing
        # fields instead of raising TypeError on a membership test.
        required = ("agent_id", "tunnel_id", "expires_at", "port")
        missing = [field for field in required if payload.get(field) in ("", None)]
        if missing:
            raise ValueError(f"Missing token fields: {', '.join(missing)}")

        try:
            exp = float(payload["expires_at"])
        except Exception as exc:
            raise ValueError("Invalid token expiry") from exc
        if exp <= time.time():
            raise ValueError("Token expired")

        try:
            port = int(payload["port"])
        except Exception as exc:
            raise ValueError("Invalid token port") from exc
        if port < 1 or port > 65535:
            raise ValueError("Invalid token port")

        if not signature:
            # Signature metadata without a signature is treated as tampering.
            if sig_alg or signing_key:
                raise ValueError("Token signature missing")
            stored_key = None
            if signing_client is not None and hasattr(signing_client, "load_server_signing_key"):
                try:
                    stored_key = signing_client.load_server_signing_key()
                except Exception:
                    stored_key = None
            # Once a server signing key is stored, unsigned tokens are refused.
            if isinstance(stored_key, str) and stored_key.strip():
                raise ValueError("Token signature missing")
            return

        if sig_alg and str(sig_alg).lower() not in ("ed25519", "eddsa"):
            raise ValueError("Unsupported token signature algorithm")
        # The signed bytes are the token without its signature fields,
        # serialized as compact JSON with sorted keys.
        payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
        if not verify_and_store_script_signature(signing_client, payload_bytes, str(signature), signing_key):
            raise ValueError("Token signature invalid")

    def _run(self, args: list[str]) -> tuple[int, str, str]:
        """Run a command and return ``(returncode, stdout, stderr)``.

        A failure to launch the process is reported as return code 1 with the
        exception text in the stderr slot instead of being raised.
        """
        try:
            proc = subprocess.run(args, capture_output=True, text=True, check=False)
            return proc.returncode, proc.stdout.strip(), proc.stderr.strip()
        except Exception as exc:  # pragma: no cover - runtime guard
            return 1, "", str(exc)

    def _render_config(self, session: "SessionConfig") -> str:
        """Render the WireGuard client configuration text for *session*.

        The session's own private key is used when provided, otherwise the
        locally stored client key.

        Raises:
            ValueError: If any rendered value contains a line break, which
                would let it inject additional configuration directives.
        """
        private_key = session.client_private_key or self._client_keys["private"]
        rendered_fields = {
            "PrivateKey": private_key,
            "Address": session.virtual_ip,
            "PublicKey": session.server_public_key,
            "AllowedIPs": session.allowed_ips,
            "Endpoint": session.endpoint,
        }
        if session.preshared_key:
            rendered_fields["PresharedKey"] = session.preshared_key
        for field_name, field_value in rendered_fields.items():
            text = str(field_value)
            if "\n" in text or "\r" in text:
                # Name the field only; never echo key material into errors.
                raise ValueError(f"WireGuard config field {field_name} contains a line break")

        lines = [
            "[Interface]",
            f"PrivateKey = {private_key}",
            f"Address = {session.virtual_ip}",
            "",
            "[Peer]",
            f"PublicKey = {session.server_public_key}",
            f"AllowedIPs = {session.allowed_ips}",
            f"Endpoint = {session.endpoint}",
            "PersistentKeepalive = 20",
        ]
        if session.preshared_key:
            lines.append(f"PresharedKey = {session.preshared_key}")
        return "\n".join(lines)

    def _start_idle_monitor(self) -> None:
        """Start a daemon thread that stops the session once it goes idle."""
        # Retire any monitor from an earlier session, then give the new
        # monitor its own event. Reusing and clearing a shared event could
        # revive an old thread that was still sleeping when the session ended.
        self._stop_event.set()
        stop_event = threading.Event()
        self._stop_event = stop_event

        def _loop() -> None:
            while not stop_event.is_set():
                # Snapshot the deadline: it may be reset to None concurrently.
                deadline = self.idle_deadline
                if deadline and time.time() >= deadline:
                    _write_log("Idle timeout reached; stopping WireGuard session.")
                    self.stop_session(reason="idle_timeout")
                    return
                # Wake immediately when the session is stopped.
                stop_event.wait(5)

        monitor = threading.Thread(target=_loop, daemon=True)
        monitor.start()
        self._idle_thread = monitor

    def start_session(self, session: "SessionConfig", *, signing_client: Optional[Any] = None) -> None:
        """Validate, configure and install the tunnel for *session*.

        Every failure is written to the tunnel log and the method returns
        without raising; on success the session is recorded and the idle
        monitor is armed.

        Args:
            session: Parameters of the tunnel to start.
            signing_client: Passed through to token validation.
        """
        if self.session:
            _write_log("Rejecting start_session: existing session already active.")
            return
        try:
            self._validate_token(session.token, signing_client=signing_client)
        except Exception as exc:
            _write_log(f"Refusing to start WireGuard session: {exc}")
            return

        try:
            rendered = self._render_config(session)
            self.conf_path.write_text(rendered, encoding="utf-8")
        except (OSError, ValueError) as exc:
            # Keep the failure inside the role instead of raising into the
            # caller's event handler.
            _write_log(f"Failed to write WireGuard client config: {exc}")
            return
        _write_log(f"Rendered WireGuard client config to {self.conf_path}")

        # Pre-stop any orphaned tunnel service using the same name
        self.stop_session(reason="preflight", ignore_missing=True)
        code, _out, err = self._run([self._wg_exe, "/installtunnelservice", str(self.conf_path)])
        if code != 0:
            _write_log(f"Failed to install WireGuard client tunnel: code={code} err={err}")
            return

        self.session = session
        # The idle window is never shorter than 60 seconds.
        self.idle_deadline = time.time() + max(60, session.idle_seconds)
        _write_log("WireGuard client session started; idle timer armed.")
        self._start_idle_monitor()

    def stop_session(self, reason: str = "stop", ignore_missing: bool = False) -> None:
        """Uninstall the tunnel service and clear all session state.

        Args:
            reason: Label recorded in the log on a successful stop.
            ignore_missing: When true, a failed uninstall is not logged.
        """
        code, _out, err = self._run([self._wg_exe, "/uninstalltunnelservice", self.service_name])
        if code != 0:
            if not ignore_missing:
                _write_log(f"Failed to uninstall WireGuard client tunnel: code={code} err={err}")
        else:
            _write_log(f"WireGuard client session stopped (reason={reason}).")
        # State is cleared whether or not the uninstall succeeded.
        self.session = None
        self.idle_deadline = None
        self._stop_event.set()

    def bump_activity(self) -> None:
        """Push the idle deadline forward while a session is active."""
        if self.session and self.idle_deadline:
            self.idle_deadline = time.time() + max(60, self.session.idle_seconds)
            _write_log("WireGuard client activity bump; idle timer reset.")
# Module-level client instance; every Role created from this module reuses it.
# Constructing it at import time creates the Temp directory and loads or
# generates the client key pair (see WireGuardClient.__init__).
client = WireGuardClient()
def _parse_allowed_ips(value: Any, fallback: Optional[str]) -> Optional[str]:
if isinstance(value, list):
if not value:
return fallback
return str(value[0])
if isinstance(value, str) and value.strip():
return value.strip()
return fallback
def _coerce_int(value: Any, default: int) -> int:
try:
return int(value)
except Exception:
return default
class Role:
    """Agent role that drives the shared WireGuard client from Socket.IO events.

    It turns ``vpn_tunnel_start`` payloads into a ``SessionConfig``, forwards
    stop and activity events to the client, and mirrors log messages to the
    agent's logging hook when one is provided.
    """

    def __init__(self, ctx) -> None:
        """Bind to the agent context and clear any leftover tunnel service.

        Args:
            ctx: Agent context; this class reads ``ctx.hooks``, ``ctx.sio``
                and ``ctx.agent_id``.
        """
        self.ctx = ctx
        self.client = client
        hooks = getattr(ctx, "hooks", {}) or {}
        self._log_hook = hooks.get("log_agent")
        self._http_client_factory = hooks.get("http_client")
        try:
            self.client.stop_session(reason="agent_startup", ignore_missing=True)
        except Exception:
            self._log("Failed to preflight WireGuard session cleanup.", error=True)

    def _log(self, message: str, *, error: bool = False) -> None:
        """Send *message* to the agent log hook (if any) and the tunnel log.

        When ``error`` is true the message is also sent to the hook with
        ``agent.error.log`` as the target. Hook failures are ignored.
        """
        if callable(self._log_hook):
            try:
                self._log_hook(message, fname="VPN_Tunnel/tunnel.log")
                if error:
                    self._log_hook(message, fname="agent.error.log")
            except Exception:
                # The hook is optional; a broken hook must not stop the role.
                pass
        _write_log(message)

    def _http_client(self) -> Optional[Any]:
        """Return a client from the ``http_client`` hook, or ``None`` on failure."""
        try:
            if callable(self._http_client_factory):
                return self._http_client_factory()
        except Exception:
            return None
        return None

    def _targets_other_agent(self, payload: Any) -> bool:
        """Return True when *payload* names an agent other than this one.

        Payloads that are not dicts, or that carry neither ``agent_id`` nor
        ``agent_guid``, are treated as addressed to this agent.
        """
        if not isinstance(payload, dict):
            return False
        target = payload.get("agent_id") or payload.get("agent_guid")
        if not target:
            return False
        return str(target).strip() != str(self.ctx.agent_id).strip()

    @staticmethod
    def _is_host_route(value: Any) -> bool:
        """Return True only for exactly one IPv4 address written with ``/32``.

        Parsing the whole value rejects inputs that merely contain the text
        ``/32``, such as ``"0.0.0.0/0 #/32"``.
        """
        text = str(value).strip()
        if not text.endswith("/32"):
            return False
        try:
            interface = ipaddress.IPv4Interface(text)
        except ValueError:
            return False
        return interface.network.prefixlen == 32

    def _build_session(self, payload: Any) -> Optional["SessionConfig"]:
        """Translate a ``vpn_tunnel_start`` payload into a ``SessionConfig``.

        Returns ``None`` when the payload is addressed to another agent
        (silently) or fails validation (after logging the reason).
        """
        if not isinstance(payload, dict):
            self._log("WireGuard start payload missing/invalid.", error=True)
            return None
        if self._targets_other_agent(payload):
            return None

        token = payload.get("token") or payload.get("orchestration_token")
        if not isinstance(token, dict):
            self._log("WireGuard start missing token payload.", error=True)
            return None

        # Several fields are accepted under two alternative key names.
        virtual_ip = payload.get("virtual_ip") or payload.get("client_virtual_ip")
        endpoint = payload.get("endpoint") or payload.get("server_endpoint")
        server_public_key = payload.get("server_public_key") or payload.get("public_key")
        engine_virtual_ip = payload.get("engine_virtual_ip") or payload.get("engine_ip")

        allowed_ips = _parse_allowed_ips(payload.get("allowed_ips"), engine_virtual_ip)
        if not allowed_ips:
            self._log("WireGuard start missing allowed_ips/engine_virtual_ip.", error=True)
            return None
        if not self._is_host_route(allowed_ips):
            self._log("WireGuard allowed_ips must be a single /32.", error=True)
            return None
        if not virtual_ip or not endpoint or not server_public_key:
            self._log("WireGuard start missing required fields.", error=True)
            return None
        if not self._is_host_route(virtual_ip):
            self._log("WireGuard virtual_ip must be /32.", error=True)
            return None

        idle_seconds = _coerce_int(payload.get("idle_seconds"), 900)
        allowed_ports = payload.get("allowed_ports")
        if isinstance(allowed_ports, list):
            allowed_ports = ",".join(str(p) for p in allowed_ports)
        allowed_ports = str(allowed_ports or "")

        return SessionConfig(
            token=token,
            virtual_ip=str(virtual_ip).strip(),
            allowed_ips=str(allowed_ips).strip(),
            endpoint=str(endpoint),
            server_public_key=str(server_public_key),
            allowed_ports=allowed_ports,
            idle_seconds=idle_seconds,
            preshared_key=payload.get("preshared_key"),
            client_private_key=payload.get("client_private_key"),
            client_public_key=payload.get("client_public_key"),
        )

    def register_events(self) -> None:
        """Register the ``vpn_tunnel_*`` handlers on ``ctx.sio``."""
        sio = self.ctx.sio

        @sio.on("vpn_tunnel_start")
        async def _vpn_tunnel_start(payload):
            session = self._build_session(payload)
            if not session:
                return
            self._log("WireGuard start request received.")
            self.client.start_session(session, signing_client=self._http_client())

        @sio.on("vpn_tunnel_stop")
        async def _vpn_tunnel_stop(payload):
            if self._targets_other_agent(payload):
                return
            reason = "server_stop"
            if isinstance(payload, dict):
                reason = payload.get("reason") or reason
            self._log(f"WireGuard stop requested (reason={reason}).")
            self.client.stop_session(reason=str(reason))

        @sio.on("vpn_tunnel_activity")
        async def _vpn_tunnel_activity(payload):
            # Activity reported for another agent must not extend this
            # agent's idle deadline.
            if self._targets_other_agent(payload):
                return
            self.client.bump_activity()

    def stop_all(self) -> None:
        """Stop the tunnel at agent shutdown; failures are logged, not raised."""
        try:
            self.client.stop_session(reason="agent_shutdown")
        except Exception:
            self._log("Failed to stop WireGuard client during shutdown.", error=True)