From 7dda62d9ee9732668cdbd3d185d961c86b96c103 Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Tue, 13 Jan 2026 22:22:31 -0700 Subject: [PATCH] Updates to WireGuard --- Data/Agent/Roles/role_WireGuardTunnel.py | 144 +++++++++++++++++++ Data/Engine/services/VPN/wireguard_server.py | 60 +++++++- 2 files changed, 199 insertions(+), 5 deletions(-) diff --git a/Data/Agent/Roles/role_WireGuardTunnel.py b/Data/Agent/Roles/role_WireGuardTunnel.py index b8b4242a..76cfdf5d 100644 --- a/Data/Agent/Roles/role_WireGuardTunnel.py +++ b/Data/Agent/Roles/role_WireGuardTunnel.py @@ -17,6 +17,7 @@ the client session with the issued config/token. from __future__ import annotations import base64 +import ipaddress import json import os import subprocess @@ -25,6 +26,7 @@ import time import re from pathlib import Path from typing import Any, Dict, Optional +from urllib.parse import urlsplit try: import winreg # type: ignore @@ -48,6 +50,7 @@ ROLE_NAME = "WireGuardTunnel" ROLE_CONTEXTS = ["system"] TUNNEL_NAME = "Borealis" TUNNEL_DISPLAY_NAME = "Borealis" +SERVICE_DISPLAY_NAME = "Borealis - WireGuard - Agent" TUNNEL_IDLE_ADDRESS = "169.254.255.254/32" @@ -140,6 +143,7 @@ class WireGuardClient: self.temp_root.mkdir(parents=True, exist_ok=True) self.service_name = TUNNEL_NAME self.display_name = TUNNEL_DISPLAY_NAME + self.service_display_name = SERVICE_DISPLAY_NAME self.conf_path = self._wireguard_config_path() self.session: Optional[SessionConfig] = None self.idle_deadline: Optional[float] = None @@ -295,6 +299,22 @@ class WireGuardClient: ] self._run(args) + def _ensure_service_display_name(self) -> None: + if not self.service_display_name: + return + if not self._service_exists(): + return + args = [ + "sc.exe", + "config", + self._service_id(), + "DisplayName=", + self.service_display_name, + ] + code, _, err = self._run(args) + if code != 0 and err: + _write_log(f"WireGuard service display name update failed: {err}") + def _ensure_idle_service(self) -> None: if self._service_exists(): return @@ -305,6 +325,7 @@ class WireGuardClient: return if self._install_service(): self._ensure_adapter_name() + self._ensure_service_display_name() def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None: payload = dict(token or {}) @@ -425,6 +446,7 @@ class WireGuardClient: self._restart_service() self._ensure_adapter_name() + self._ensure_service_display_name() self.session = session self.idle_deadline = time.time() + max(60, session.idle_seconds) @@ -449,6 +471,7 @@ class WireGuardClient: if wrote_idle: self._restart_service() self._ensure_adapter_name() + self._ensure_service_display_name() _write_log(f"WireGuard client session stopped (reason={reason}).") elif not ignore_missing: _write_log("Failed to write idle WireGuard config.") @@ -492,6 +515,98 @@ def _coerce_int(value: Any, default: int) -> int: return default +def _parse_endpoint_host(value: Optional[str]) -> Optional[str]: + if not value: + return None + text = str(value).strip() + if not text: + return None + if text.startswith("["): + match = re.match(r"^\[([^\]]+)\]", text) + if match: + return match.group(1) + host, sep, _ = text.rpartition(":") + if sep and host: + return host + return text + + +def _parse_endpoint_port(value: Optional[str]) -> Optional[int]: + if not value: + return None + text = str(value).strip() + if not text: + return None + if text.startswith("["): + match = re.match(r"^\[[^\]]+\]:(\d+)$", text) + if match: + try: + return int(match.group(1)) + except Exception: + return None + return None + _, sep, port = text.rpartition(":") + if sep and port.isdigit(): + try: + return int(port) + except Exception: + return None + return None + + +def _format_endpoint(host: str, port: Optional[int]) -> Optional[str]: + if not host: + return None + text = str(host).strip() + if not text: + return None + if ":" in text and not text.startswith("["): + text = f"[{text}]" + if port is None: + return text + return f"{text}:{port}" + + +def _parse_server_url_host(server_url: Optional[str]) -> Optional[str]: + if not server_url: + return None + try: + parsed = urlsplit(str(server_url).strip()) + if parsed.hostname: + return parsed.hostname.strip() + except Exception: + pass + text = str(server_url).strip() + if not text: + return None + if "://" in text: + text = text.split("://", 1)[1] + text = text.split("/", 1)[0].strip() + if text.startswith("[") and "]" in text: + text = text[1:text.index("]")] + if ":" in text: + host_part, port_part = text.rsplit(":", 1) + if port_part.isdigit(): + text = host_part + text = text.strip() + return text or None + + +def _is_loopback_host(host: Optional[str]) -> bool: + if not host: + return True + text = str(host).strip().lower() + if not text: + return True + if text == "localhost": + return True + try: + ip = ipaddress.ip_address(text) + return ip.is_loopback or ip.is_unspecified + except Exception: + return False + + class Role: def __init__(self, ctx) -> None: self.ctx = ctx @@ -499,6 +614,7 @@ class Role: hooks = getattr(ctx, "hooks", {}) or {} self._log_hook = hooks.get("log_agent") self._http_client_factory = hooks.get("http_client") + self._get_server_url = hooks.get("get_server_url") try: self.client.stop_session(reason="agent_startup", ignore_missing=True) except Exception: @@ -522,6 +638,33 @@ class Role: return None return None + def _resolve_endpoint(self, endpoint: Optional[str], token: Dict[str, Any]) -> Optional[str]: + server_url = None + if callable(self._get_server_url): + try: + server_url = self._get_server_url() + except Exception: + server_url = None + + server_host = _parse_server_url_host(server_url) + if not server_host: + return endpoint + + endpoint_host = _parse_endpoint_host(endpoint) + if endpoint_host and not _is_loopback_host(endpoint_host): + return endpoint + + endpoint_port = _parse_endpoint_port(endpoint) + if endpoint_port is None: + endpoint_port = _coerce_int(token.get("port"), 0) + if endpoint_port <= 0: + endpoint_port = None + + resolved = _format_endpoint(server_host, endpoint_port) + if resolved and endpoint and resolved != endpoint: + self._log(f"WireGuard endpoint override: {endpoint} -> {resolved}") + return resolved or endpoint + def _build_session(self, payload: Any) -> Optional[SessionConfig]: if not isinstance(payload, dict): self._log("WireGuard start payload missing/invalid.", error=True) @@ -539,6 +682,7 @@ class Role: virtual_ip = payload.get("virtual_ip") or payload.get("client_virtual_ip") endpoint = payload.get("endpoint") or payload.get("server_endpoint") + endpoint = self._resolve_endpoint(endpoint, token) server_public_key = payload.get("server_public_key") or payload.get("public_key") engine_virtual_ip = payload.get("engine_virtual_ip") or payload.get("engine_ip") diff --git a/Data/Engine/services/VPN/wireguard_server.py b/Data/Engine/services/VPN/wireguard_server.py index 62211b74..e23f44c1 100644 --- a/Data/Engine/services/VPN/wireguard_server.py +++ b/Data/Engine/services/VPN/wireguard_server.py @@ -19,8 +19,8 @@ import base64 import ipaddress import logging import os +import re import subprocess -import tempfile import time from dataclasses import dataclass from logging.handlers import TimedRotatingFileHandler @@ -30,6 +30,7 @@ from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Uni from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import x25519 +from ... import config as engine_config def _build_logger(log_path: Path) -> logging.Logger: logger = logging.getLogger("borealis.engine.wireguard") @@ -72,9 +73,18 @@ class WireGuardServerManager: self._ensure_cert_dir() self.server_private_key, self.server_public_key = self._ensure_server_keys() self._service_name = "borealis-wg" - self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine" + self._service_display_name = "Borealis - WireGuard - Engine" + self._config_dir = self._resolve_config_dir() self._wireguard_exe = self._resolve_wireguard_exe() + def _resolve_config_dir(self) -> Path: + config_dir = engine_config.PROJECT_ROOT / "Engine" / "WireGuard" + try: + config_dir.mkdir(parents=True, exist_ok=True) + except Exception: + self.logger.error("Failed to ensure WireGuard config dir at %s", config_dir, exc_info=True) + return config_dir + def _resolve_wireguard_exe(self) -> str: candidates = [ str(Path(os.environ.get("ProgramFiles", "C:\\Program Files")) / "WireGuard" / "wireguard.exe"), @@ -138,6 +148,45 @@ class WireGuardServerManager: except Exception as exc: return 1, "", str(exc) + def _service_id(self) -> str: + return f"WireGuardTunnel${self._service_name}" + + def _query_service_state(self) -> Optional[str]: + code, out, err = self._run_command(["sc.exe", "query", self._service_id()]) + if code != 0: + return None + text = out or err + for line in text.splitlines(): + if "STATE" not in line: + continue + match = re.search(r"STATE\s*:\s*\d+\s+(\w+)", line) + if match: + return match.group(1).upper() + return None + + def _ensure_service_display_name(self) -> None: + if not self._service_display_name: + return + args = ["sc.exe", "config", self._service_id(), "DisplayName=", self._service_display_name] + code, out, err = self._run_command(args) + if code != 0 and err: + self.logger.warning("Failed to set WireGuard service display name: %s", err) + + def _ensure_service_running(self) -> None: + service_id = self._service_id() + for _ in range(6): + state = self._query_service_state() + if state == "RUNNING": + return + if state == "STOPPED": + code, out, err = self._run_command(["sc.exe", "start", service_id]) + if code != 0: + self.logger.error("Failed to start WireGuard tunnel service %s err=%s", service_id, err) + break + time.sleep(1) + state = self._query_service_state() + raise RuntimeError(f"WireGuard tunnel service {service_id} failed to start (state={state})") + def _normalise_allowed_ports( self, candidate: Optional[Iterable[int]], @@ -242,7 +291,6 @@ class WireGuardServerManager: f"PrivateKey = {self.server_private_key}", f"ListenPort = {self.config.port}", f"Address = {iface}", - "SaveConfig = false", "", ] @@ -317,11 +365,11 @@ class WireGuardServerManager: """Render a temporary WireGuard config and start the service.""" try: - self._temp_dir.mkdir(parents=True, exist_ok=True) + self._config_dir.mkdir(parents=True, exist_ok=True) except Exception: self.logger.warning("Failed to create temp dir for WireGuard config", exc_info=True) - config_path = self._temp_dir / "borealis-wg.conf" + config_path = self._config_dir / f"{self._service_name}.conf" rendered = self.render_server_config(peers) config_path.write_text(rendered, encoding="utf-8") self.logger.info("Rendered WireGuard config to %s", config_path) @@ -335,6 +383,8 @@ class WireGuardServerManager: self.logger.error("Failed to install WireGuard tunnel service code=%s err=%s", code, err) raise RuntimeError(f"WireGuard installtunnelservice failed: {err}") self.logger.info("WireGuard listener installed (service=%s)", config_path.stem) + self._ensure_service_display_name() + self._ensure_service_running() def stop_listener(self, *, ignore_missing: bool = False) -> None: """Stop and remove the WireGuard tunnel service."""