# ====================================================== # Data\Agent\Roles\role_WireGuardTunnel.py # Description: WireGuard client lifecycle for outbound reverse VPN (Windows) with host-only /32 routing. # # API Endpoints (if applicable): None # ====================================================== """WireGuard client role (Windows) for reverse VPN tunnels. This role prepares the WireGuard client config, manages a single active session, enforces idle teardown, and logs lifecycle events to Agent/Logs/VPN_Tunnel/tunnel.log. It binds to Engine Socket.IO events (`vpn_tunnel_start`, `vpn_tunnel_stop`, `vpn_tunnel_activity`) to start/stop the client session with the issued config/token. """ from __future__ import annotations import base64 import json import os import subprocess import threading import time import re from pathlib import Path from typing import Any, Dict, Optional try: import winreg # type: ignore except Exception: # pragma: no cover - non-Windows guard winreg = None from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import x25519 try: from signature_utils import verify_and_store_script_signature except Exception: # pragma: no cover - fallback for runtime path issues import sys from pathlib import Path as _Path base_dir = _Path(__file__).resolve().parents[1] if str(base_dir) not in sys.path: sys.path.insert(0, str(base_dir)) from signature_utils import verify_and_store_script_signature ROLE_NAME = "WireGuardTunnel" ROLE_CONTEXTS = ["system"] TUNNEL_NAME = "Borealis" TUNNEL_DISPLAY_NAME = "Borealis" TUNNEL_IDLE_ADDRESS = "169.254.255.254/32" def _log_path() -> Path: root = Path(__file__).resolve().parents[2] / "Logs" / "VPN_Tunnel" root.mkdir(parents=True, exist_ok=True) return root / "tunnel.log" def _write_log(message: str) -> None: ts = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime()) try: _log_path().open("a", encoding="utf-8").write(f"[{ts}] [wg-client] {message}\n") except Exception: pass def _encode_key(raw: bytes) -> str: return base64.b64encode(raw).decode("ascii").strip() def _generate_client_keys(root: Path) -> Dict[str, str]: root.mkdir(parents=True, exist_ok=True) priv_path = root / "client_private.key" pub_path = root / "client_public.key" if priv_path.is_file() and pub_path.is_file(): try: private_key = priv_path.read_text(encoding="utf-8").strip() public_key = pub_path.read_text(encoding="utf-8").strip() if private_key and public_key: return {"private": private_key, "public": public_key} except Exception: pass key = x25519.X25519PrivateKey.generate() priv = _encode_key( key.private_bytes( encoding=serialization.Encoding.Raw, format=serialization.PrivateFormat.Raw, encryption_algorithm=serialization.NoEncryption(), ) ) pub = _encode_key( key.public_key().public_bytes( encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw, ) ) try: priv_path.write_text(priv, encoding="utf-8") pub_path.write_text(pub, encoding="utf-8") except Exception: _write_log("Failed to persist WireGuard client keys.") return {"private": priv, "public": pub} class SessionConfig: def __init__( self, *, token: Dict[str, Any], virtual_ip: str, allowed_ips: str, endpoint: str, server_public_key: str, allowed_ports: str, idle_seconds: int = 900, preshared_key: Optional[str] = None, client_private_key: Optional[str] = None, client_public_key: Optional[str] = None, ) -> None: self.token = token self.virtual_ip = virtual_ip self.allowed_ips = allowed_ips self.endpoint = endpoint self.server_public_key = server_public_key self.allowed_ports = allowed_ports self.idle_seconds = idle_seconds self.preshared_key = preshared_key self.client_private_key = client_private_key self.client_public_key = client_public_key class WireGuardClient: def __init__(self) -> None: base = Path(__file__).resolve().parents[2] self.cert_root = base / "Borealis" / "Certificates" / "VPN_Client" self.temp_root = base / "Borealis" / "Temp" self.temp_root.mkdir(parents=True, exist_ok=True) self.service_name = TUNNEL_NAME self.display_name = TUNNEL_DISPLAY_NAME self.conf_path = self._wireguard_config_path() self.session: Optional[SessionConfig] = None self.idle_deadline: Optional[float] = None self._idle_thread: Optional[threading.Thread] = None self._stop_event = threading.Event() self._session_lock = threading.Lock() self._client_keys = _generate_client_keys(self.cert_root) self._wg_exe = self._resolve_wireguard_exe() self._last_install_already_present = False try: self._ensure_idle_service() except Exception: pass def _resolve_wireguard_exe(self) -> str: candidates = [ str(Path(os.environ.get("ProgramFiles", "C:\\Program Files")) / "WireGuard" / "wireguard.exe"), "wireguard.exe", ] for candidate in candidates: if Path(candidate).is_file(): return candidate return "wireguard.exe" def _service_id(self) -> str: return f"WireGuardTunnel${self.service_name}" def _service_reg_path(self) -> str: return f"SYSTEM\\CurrentControlSet\\Services\\{self._service_id()}" def _service_reg_exists(self) -> bool: if winreg is None: return False try: winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, self._service_reg_path()) return True except FileNotFoundError: return False except PermissionError: _write_log("WireGuard service registry check denied; treating as present.") return True except Exception as exc: _write_log(f"WireGuard service registry check failed: {exc}") return False def _service_image_path(self) -> Optional[str]: if winreg is None: return None try: key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, self._service_reg_path()) value, _ = winreg.QueryValueEx(key, "ImagePath") return str(value) if value else None except Exception: return None def _service_config_path(self) -> Optional[Path]: image_path = self._service_image_path() if not image_path: return None match = re.search(r'(?i)/tunnelservice\s+"([^"]+)"', image_path) if match: return Path(match.group(1)) match = re.search(r"(?i)/tunnelservice\s+(\S+)", image_path) if match: return Path(match.group(1)) return None def _wireguard_config_path(self) -> Path: settings_dir = self.temp_root.parent / "Settings" / "WireGuard" candidates = [ settings_dir, Path(os.environ.get("ProgramFiles", "C:\\Program Files")) / "WireGuard" / "Data" / "Configurations", Path(os.environ.get("ProgramData", "C:\\ProgramData")) / "Borealis" / "WireGuard" / "Configurations", self.temp_root, ] for config_dir in candidates: candidate = config_dir / f"{self.service_name}.conf" if candidate.is_file(): return candidate for config_dir in candidates: try: config_dir.mkdir(parents=True, exist_ok=True) return config_dir / f"{self.service_name}.conf" except Exception: continue return self.temp_root / f"{self.service_name}.conf" def _write_config(self, text: str) -> bool: return self._write_config_to(self.conf_path, text) def _write_config_to(self, path: Path, text: str) -> bool: try: path.parent.mkdir(parents=True, exist_ok=True) path.write_text(text, encoding="ascii") return True except Exception as exc: _write_log(f"Failed to write WireGuard config at {path}: {exc}") return False def _render_idle_config(self) -> str: private_key = self._client_keys["private"] return "\n".join( [ "[Interface]", f"PrivateKey = {private_key}", f"Address = {TUNNEL_IDLE_ADDRESS}", "ListenPort = 0", ] ) def _service_exists(self) -> bool: code, _, _ = self._run(["sc.exe", "query", self._service_id()]) if code == 0: return True return self._service_reg_exists() def _install_service(self) -> bool: code, out, err = self._run([self._wg_exe, "/installtunnelservice", str(self.conf_path)]) self._last_install_already_present = False if code != 0: if "already installed and running" in err.lower(): self._last_install_already_present = True _write_log("WireGuard tunnel service already installed; skipping install.") return True if "access is denied" in err.lower(): _write_log("Failed to install WireGuard tunnel service: access denied; ensure agent runs elevated.") return False _write_log(f"Failed to install WireGuard tunnel service: code={code} err={err}") return False return True def _restart_service(self) -> bool: service_id = self._service_id() stop_code, _, stop_err = self._run(["sc.exe", "stop", service_id]) if stop_code != 0 and stop_err: _write_log(f"WireGuard stop service returned code={stop_code} err={stop_err}") time.sleep(1) start_code, _, start_err = self._run(["sc.exe", "start", service_id]) if start_code != 0 and start_err: _write_log(f"WireGuard start service returned code={start_code} err={start_err}") return start_code == 0 def _ensure_adapter_name(self) -> None: if self.service_name == self.display_name: return args = [ "netsh.exe", "interface", "set", "interface", f'name="{self.service_name}"', f'newname="{self.display_name}"', ] self._run(args) def _ensure_idle_service(self) -> None: if self._service_exists(): return if not Path(self._wg_exe).is_file(): return idle_config = self._render_idle_config() if not self._write_config(idle_config): return if self._install_service(): self._ensure_adapter_name() def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None: payload = dict(token or {}) signature = payload.pop("signature", None) signing_key = payload.pop("signing_key", None) sig_alg = payload.pop("sig_alg", None) required = ("agent_id", "tunnel_id", "expires_at", "port") missing = [field for field in required if field not in token or token[field] in ("", None)] if missing: raise ValueError(f"Missing token fields: {', '.join(missing)}") try: exp = float(payload["expires_at"]) except Exception: raise ValueError("Invalid token expiry") if exp <= time.time(): raise ValueError("Token expired") try: port = int(payload["port"]) except Exception: raise ValueError("Invalid token port") if port < 1 or port > 65535: raise ValueError("Invalid token port") if not signature: if sig_alg or signing_key: raise ValueError("Token signature missing") stored_key = None if signing_client is not None and hasattr(signing_client, "load_server_signing_key"): try: stored_key = signing_client.load_server_signing_key() except Exception: stored_key = None if isinstance(stored_key, str) and stored_key.strip(): raise ValueError("Token signature missing") return if signature: if sig_alg and str(sig_alg).lower() not in ("ed25519", "eddsa"): raise ValueError("Unsupported token signature algorithm") payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") if not verify_and_store_script_signature(signing_client, payload_bytes, str(signature), signing_key): raise ValueError("Token signature invalid") def _run(self, args: list[str]) -> tuple[int, str, str]: try: proc = subprocess.run(args, capture_output=True, text=True, check=False) return proc.returncode, proc.stdout.strip(), proc.stderr.strip() except Exception as exc: # pragma: no cover - runtime guard return 1, "", str(exc) def _render_config(self, session: SessionConfig) -> str: private_key = session.client_private_key or self._client_keys["private"] lines = [ "[Interface]", f"PrivateKey = {private_key}", f"Address = {session.virtual_ip}", "", "[Peer]", f"PublicKey = {session.server_public_key}", f"AllowedIPs = {session.allowed_ips}", f"Endpoint = {session.endpoint}", "PersistentKeepalive = 20", ] if session.preshared_key: lines.append(f"PresharedKey = {session.preshared_key}") return "\n".join(lines) def _start_idle_monitor(self) -> None: self._stop_event.clear() def _loop() -> None: while not self._stop_event.is_set(): if self.idle_deadline and time.time() >= self.idle_deadline: _write_log("Idle timeout reached; stopping WireGuard session.") self.stop_session(reason="idle_timeout") return time.sleep(5) t = threading.Thread(target=_loop, daemon=True) t.start() self._idle_thread = t def start_session(self, session: SessionConfig, *, signing_client: Optional[Any] = None) -> None: with self._session_lock: if self.session: _write_log("Rejecting start_session: existing session already active.") return try: self._validate_token(session.token, signing_client=signing_client) except Exception as exc: _write_log(f"Refusing to start WireGuard session: {exc}") return rendered = self._render_config(session) if not self._write_config(rendered): _write_log("Failed to write WireGuard client config.") return _write_log(f"Rendered WireGuard client config to {self.conf_path}") service_config_path = self._service_config_path() if service_config_path and service_config_path != self.conf_path: if self._write_config_to(service_config_path, rendered): _write_log(f"Rendered WireGuard client config to service path {service_config_path}") if not self._service_exists(): if not self._install_service(): return service_present = self._service_exists() if not service_present and self._last_install_already_present: _write_log("WireGuard tunnel service presence inferred from install response.") service_present = True if not service_present: _write_log("WireGuard tunnel service still missing after install attempt.") return self._restart_service() self._ensure_adapter_name() self.session = session self.idle_deadline = time.time() + max(60, session.idle_seconds) _write_log("WireGuard client session started; idle timer armed.") self._start_idle_monitor() def stop_session(self, reason: str = "stop", ignore_missing: bool = False) -> None: with self._session_lock: if not self._service_exists(): if not ignore_missing: _write_log("WireGuard tunnel service not found when stopping session.") self.session = None self.idle_deadline = None self._stop_event.set() return idle_config = self._render_idle_config() wrote_idle = self._write_config(idle_config) service_config_path = self._service_config_path() if service_config_path and service_config_path != self.conf_path: wrote_idle = self._write_config_to(service_config_path, idle_config) or wrote_idle if wrote_idle: self._restart_service() self._ensure_adapter_name() _write_log(f"WireGuard client session stopped (reason={reason}).") elif not ignore_missing: _write_log("Failed to write idle WireGuard config.") self.session = None self.idle_deadline = None self._stop_event.set() def bump_activity(self) -> None: if self.session and self.idle_deadline: self.idle_deadline = time.time() + max(60, self.session.idle_seconds) _write_log("WireGuard client activity bump; idle timer reset.") client = WireGuardClient() def _parse_allowed_ips(value: Any, fallback: Optional[str]) -> Optional[str]: if isinstance(value, list): if not value: return fallback return str(value[0]) if isinstance(value, str) and value.strip(): return value.strip() return fallback def _coerce_int(value: Any, default: int) -> int: try: return int(value) except Exception: return default class Role: def __init__(self, ctx) -> None: self.ctx = ctx self.client = client hooks = getattr(ctx, "hooks", {}) or {} self._log_hook = hooks.get("log_agent") self._http_client_factory = hooks.get("http_client") try: self.client.stop_session(reason="agent_startup", ignore_missing=True) except Exception: self._log("Failed to preflight WireGuard session cleanup.", error=True) def _log(self, message: str, *, error: bool = False) -> None: if callable(self._log_hook): try: self._log_hook(message, fname="VPN_Tunnel/tunnel.log") if error: self._log_hook(message, fname="agent.error.log") except Exception: pass _write_log(message) def _http_client(self) -> Optional[Any]: try: if callable(self._http_client_factory): return self._http_client_factory() except Exception: return None return None def _build_session(self, payload: Any) -> Optional[SessionConfig]: if not isinstance(payload, dict): self._log("WireGuard start payload missing/invalid.", error=True) return None payload_agent_id = payload.get("agent_id") or payload.get("agent_guid") if payload_agent_id: if str(payload_agent_id).strip() != str(self.ctx.agent_id).strip(): return None token = payload.get("token") or payload.get("orchestration_token") if not isinstance(token, dict): self._log("WireGuard start missing token payload.", error=True) return None virtual_ip = payload.get("virtual_ip") or payload.get("client_virtual_ip") endpoint = payload.get("endpoint") or payload.get("server_endpoint") server_public_key = payload.get("server_public_key") or payload.get("public_key") engine_virtual_ip = payload.get("engine_virtual_ip") or payload.get("engine_ip") allowed_ips = _parse_allowed_ips(payload.get("allowed_ips"), engine_virtual_ip) if not allowed_ips: self._log("WireGuard start missing allowed_ips/engine_virtual_ip.", error=True) return None if "," in allowed_ips or allowed_ips.endswith("/0") or "/32" not in allowed_ips: self._log("WireGuard allowed_ips must be a single /32.", error=True) return None if not virtual_ip or not endpoint or not server_public_key: self._log("WireGuard start missing required fields.", error=True) return None if "/32" not in str(virtual_ip): self._log("WireGuard virtual_ip must be /32.", error=True) return None idle_seconds = _coerce_int(payload.get("idle_seconds"), 900) allowed_ports = payload.get("allowed_ports") if isinstance(allowed_ports, list): allowed_ports = ",".join(str(p) for p in allowed_ports) allowed_ports = str(allowed_ports or "") return SessionConfig( token=token, virtual_ip=str(virtual_ip), allowed_ips=str(allowed_ips), endpoint=str(endpoint), server_public_key=str(server_public_key), allowed_ports=allowed_ports, idle_seconds=idle_seconds, preshared_key=payload.get("preshared_key"), client_private_key=payload.get("client_private_key"), client_public_key=payload.get("client_public_key"), ) def register_events(self) -> None: sio = self.ctx.sio @sio.on("vpn_tunnel_start") async def _vpn_tunnel_start(payload): session = self._build_session(payload) if not session: return self._log("WireGuard start request received.") self.client.start_session(session, signing_client=self._http_client()) @sio.on("vpn_tunnel_stop") async def _vpn_tunnel_stop(payload): reason = "server_stop" if isinstance(payload, dict): target_agent = payload.get("agent_id") if target_agent and str(target_agent).strip() != str(self.ctx.agent_id).strip(): return reason = payload.get("reason") or reason self._log(f"WireGuard stop requested (reason={reason}).") self.client.stop_session(reason=str(reason)) @sio.on("vpn_tunnel_activity") async def _vpn_tunnel_activity(payload): self.client.bump_activity() def stop_all(self) -> None: try: self.client.stop_session(reason="agent_shutdown") except Exception: self._log("Failed to stop WireGuard client during shutdown.", error=True)