Updates to WireGuard

This commit is contained in:
2026-01-13 22:22:31 -07:00
parent 22733f2c04
commit 7dda62d9ee
2 changed files with 199 additions and 5 deletions

View File

@@ -17,6 +17,7 @@ the client session with the issued config/token.
from __future__ import annotations from __future__ import annotations
import base64 import base64
import ipaddress
import json import json
import os import os
import subprocess import subprocess
@@ -25,6 +26,7 @@ import time
import re import re
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from urllib.parse import urlsplit
try: try:
import winreg # type: ignore import winreg # type: ignore
@@ -48,6 +50,7 @@ ROLE_NAME = "WireGuardTunnel"
ROLE_CONTEXTS = ["system"] ROLE_CONTEXTS = ["system"]
TUNNEL_NAME = "Borealis" TUNNEL_NAME = "Borealis"
TUNNEL_DISPLAY_NAME = "Borealis" TUNNEL_DISPLAY_NAME = "Borealis"
SERVICE_DISPLAY_NAME = "Borealis - WireGuard - Agent"
TUNNEL_IDLE_ADDRESS = "169.254.255.254/32" TUNNEL_IDLE_ADDRESS = "169.254.255.254/32"
@@ -140,6 +143,7 @@ class WireGuardClient:
self.temp_root.mkdir(parents=True, exist_ok=True) self.temp_root.mkdir(parents=True, exist_ok=True)
self.service_name = TUNNEL_NAME self.service_name = TUNNEL_NAME
self.display_name = TUNNEL_DISPLAY_NAME self.display_name = TUNNEL_DISPLAY_NAME
self.service_display_name = SERVICE_DISPLAY_NAME
self.conf_path = self._wireguard_config_path() self.conf_path = self._wireguard_config_path()
self.session: Optional[SessionConfig] = None self.session: Optional[SessionConfig] = None
self.idle_deadline: Optional[float] = None self.idle_deadline: Optional[float] = None
@@ -295,6 +299,22 @@ class WireGuardClient:
] ]
self._run(args) self._run(args)
def _ensure_service_display_name(self) -> None:
if not self.service_display_name:
return
if not self._service_exists():
return
args = [
"sc.exe",
"config",
self._service_id(),
"DisplayName=",
self.service_display_name,
]
code, _, err = self._run(args)
if code != 0 and err:
_write_log(f"WireGuard service display name update failed: {err}")
def _ensure_idle_service(self) -> None: def _ensure_idle_service(self) -> None:
if self._service_exists(): if self._service_exists():
return return
@@ -305,6 +325,7 @@ class WireGuardClient:
return return
if self._install_service(): if self._install_service():
self._ensure_adapter_name() self._ensure_adapter_name()
self._ensure_service_display_name()
def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None: def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None:
payload = dict(token or {}) payload = dict(token or {})
@@ -425,6 +446,7 @@ class WireGuardClient:
self._restart_service() self._restart_service()
self._ensure_adapter_name() self._ensure_adapter_name()
self._ensure_service_display_name()
self.session = session self.session = session
self.idle_deadline = time.time() + max(60, session.idle_seconds) self.idle_deadline = time.time() + max(60, session.idle_seconds)
@@ -449,6 +471,7 @@ class WireGuardClient:
if wrote_idle: if wrote_idle:
self._restart_service() self._restart_service()
self._ensure_adapter_name() self._ensure_adapter_name()
self._ensure_service_display_name()
_write_log(f"WireGuard client session stopped (reason={reason}).") _write_log(f"WireGuard client session stopped (reason={reason}).")
elif not ignore_missing: elif not ignore_missing:
_write_log("Failed to write idle WireGuard config.") _write_log("Failed to write idle WireGuard config.")
@@ -492,6 +515,98 @@ def _coerce_int(value: Any, default: int) -> int:
return default return default
def _parse_endpoint_host(value: Optional[str]) -> Optional[str]:
if not value:
return None
text = str(value).strip()
if not text:
return None
if text.startswith("["):
match = re.match(r"^\[([^\]]+)\]", text)
if match:
return match.group(1)
host, sep, _ = text.rpartition(":")
if sep and host:
return host
return text
def _parse_endpoint_port(value: Optional[str]) -> Optional[int]:
if not value:
return None
text = str(value).strip()
if not text:
return None
if text.startswith("["):
match = re.match(r"^\[[^\]]+\]:(\d+)$", text)
if match:
try:
return int(match.group(1))
except Exception:
return None
return None
_, sep, port = text.rpartition(":")
if sep and port.isdigit():
try:
return int(port)
except Exception:
return None
return None
def _format_endpoint(host: str, port: Optional[int]) -> Optional[str]:
if not host:
return None
text = str(host).strip()
if not text:
return None
if ":" in text and not text.startswith("["):
text = f"[{text}]"
if port is None:
return text
return f"{text}:{port}"
def _parse_server_url_host(server_url: Optional[str]) -> Optional[str]:
if not server_url:
return None
try:
parsed = urlsplit(str(server_url).strip())
if parsed.hostname:
return parsed.hostname.strip()
except Exception:
pass
text = str(server_url).strip()
if not text:
return None
if "://" in text:
text = text.split("://", 1)[1]
text = text.split("/", 1)[0].strip()
if text.startswith("[") and "]" in text:
text = text[1:text.index("]")]
if ":" in text:
host_part, port_part = text.rsplit(":", 1)
if port_part.isdigit():
text = host_part
text = text.strip()
return text or None
def _is_loopback_host(host: Optional[str]) -> bool:
if not host:
return True
text = str(host).strip().lower()
if not text:
return True
if text == "localhost":
return True
try:
ip = ipaddress.ip_address(text)
return ip.is_loopback or ip.is_unspecified
except Exception:
return False
class Role: class Role:
def __init__(self, ctx) -> None: def __init__(self, ctx) -> None:
self.ctx = ctx self.ctx = ctx
@@ -499,6 +614,7 @@ class Role:
hooks = getattr(ctx, "hooks", {}) or {} hooks = getattr(ctx, "hooks", {}) or {}
self._log_hook = hooks.get("log_agent") self._log_hook = hooks.get("log_agent")
self._http_client_factory = hooks.get("http_client") self._http_client_factory = hooks.get("http_client")
self._get_server_url = hooks.get("get_server_url")
try: try:
self.client.stop_session(reason="agent_startup", ignore_missing=True) self.client.stop_session(reason="agent_startup", ignore_missing=True)
except Exception: except Exception:
@@ -522,6 +638,33 @@ class Role:
return None return None
return None return None
def _resolve_endpoint(self, endpoint: Optional[str], token: Dict[str, Any]) -> Optional[str]:
server_url = None
if callable(self._get_server_url):
try:
server_url = self._get_server_url()
except Exception:
server_url = None
server_host = _parse_server_url_host(server_url)
if not server_host:
return endpoint
endpoint_host = _parse_endpoint_host(endpoint)
if endpoint_host and not _is_loopback_host(endpoint_host):
return endpoint
endpoint_port = _parse_endpoint_port(endpoint)
if endpoint_port is None:
endpoint_port = _coerce_int(token.get("port"), 0)
if endpoint_port <= 0:
endpoint_port = None
resolved = _format_endpoint(server_host, endpoint_port)
if resolved and endpoint and resolved != endpoint:
self._log(f"WireGuard endpoint override: {endpoint} -> {resolved}")
return resolved or endpoint
def _build_session(self, payload: Any) -> Optional[SessionConfig]: def _build_session(self, payload: Any) -> Optional[SessionConfig]:
if not isinstance(payload, dict): if not isinstance(payload, dict):
self._log("WireGuard start payload missing/invalid.", error=True) self._log("WireGuard start payload missing/invalid.", error=True)
@@ -539,6 +682,7 @@ class Role:
virtual_ip = payload.get("virtual_ip") or payload.get("client_virtual_ip") virtual_ip = payload.get("virtual_ip") or payload.get("client_virtual_ip")
endpoint = payload.get("endpoint") or payload.get("server_endpoint") endpoint = payload.get("endpoint") or payload.get("server_endpoint")
endpoint = self._resolve_endpoint(endpoint, token)
server_public_key = payload.get("server_public_key") or payload.get("public_key") server_public_key = payload.get("server_public_key") or payload.get("public_key")
engine_virtual_ip = payload.get("engine_virtual_ip") or payload.get("engine_ip") engine_virtual_ip = payload.get("engine_virtual_ip") or payload.get("engine_ip")

View File

@@ -19,8 +19,8 @@ import base64
import ipaddress import ipaddress
import logging import logging
import os import os
import re
import subprocess import subprocess
import tempfile
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from logging.handlers import TimedRotatingFileHandler from logging.handlers import TimedRotatingFileHandler
@@ -30,6 +30,7 @@ from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Uni
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519 from cryptography.hazmat.primitives.asymmetric import x25519
from ... import config as engine_config
def _build_logger(log_path: Path) -> logging.Logger: def _build_logger(log_path: Path) -> logging.Logger:
logger = logging.getLogger("borealis.engine.wireguard") logger = logging.getLogger("borealis.engine.wireguard")
@@ -72,9 +73,18 @@ class WireGuardServerManager:
self._ensure_cert_dir() self._ensure_cert_dir()
self.server_private_key, self.server_public_key = self._ensure_server_keys() self.server_private_key, self.server_public_key = self._ensure_server_keys()
self._service_name = "borealis-wg" self._service_name = "borealis-wg"
self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine" self._service_display_name = "Borealis - WireGuard - Engine"
self._config_dir = self._resolve_config_dir()
self._wireguard_exe = self._resolve_wireguard_exe() self._wireguard_exe = self._resolve_wireguard_exe()
def _resolve_config_dir(self) -> Path:
config_dir = engine_config.PROJECT_ROOT / "Engine" / "WireGuard"
try:
config_dir.mkdir(parents=True, exist_ok=True)
except Exception:
self.logger.error("Failed to ensure WireGuard config dir at %s", config_dir, exc_info=True)
return config_dir
def _resolve_wireguard_exe(self) -> str: def _resolve_wireguard_exe(self) -> str:
candidates = [ candidates = [
str(Path(os.environ.get("ProgramFiles", "C:\\Program Files")) / "WireGuard" / "wireguard.exe"), str(Path(os.environ.get("ProgramFiles", "C:\\Program Files")) / "WireGuard" / "wireguard.exe"),
@@ -138,6 +148,45 @@ class WireGuardServerManager:
except Exception as exc: except Exception as exc:
return 1, "", str(exc) return 1, "", str(exc)
def _service_id(self) -> str:
return f"WireGuardTunnel${self._service_name}"
def _query_service_state(self) -> Optional[str]:
code, out, err = self._run_command(["sc.exe", "query", self._service_id()])
if code != 0:
return None
text = out or err
for line in text.splitlines():
if "STATE" not in line:
continue
match = re.search(r"STATE\s*:\s*\d+\s+(\w+)", line)
if match:
return match.group(1).upper()
return None
def _ensure_service_display_name(self) -> None:
if not self._service_display_name:
return
args = ["sc.exe", "config", self._service_id(), "DisplayName=", self._service_display_name]
code, out, err = self._run_command(args)
if code != 0 and err:
self.logger.warning("Failed to set WireGuard service display name: %s", err)
def _ensure_service_running(self) -> None:
service_id = self._service_id()
for _ in range(6):
state = self._query_service_state()
if state == "RUNNING":
return
if state == "STOPPED":
code, out, err = self._run_command(["sc.exe", "start", service_id])
if code != 0:
self.logger.error("Failed to start WireGuard tunnel service %s err=%s", service_id, err)
break
time.sleep(1)
state = self._query_service_state()
raise RuntimeError(f"WireGuard tunnel service {service_id} failed to start (state={state})")
def _normalise_allowed_ports( def _normalise_allowed_ports(
self, self,
candidate: Optional[Iterable[int]], candidate: Optional[Iterable[int]],
@@ -242,7 +291,6 @@ class WireGuardServerManager:
f"PrivateKey = {self.server_private_key}", f"PrivateKey = {self.server_private_key}",
f"ListenPort = {self.config.port}", f"ListenPort = {self.config.port}",
f"Address = {iface}", f"Address = {iface}",
"SaveConfig = false",
"", "",
] ]
@@ -317,11 +365,11 @@ class WireGuardServerManager:
"""Render a temporary WireGuard config and start the service.""" """Render a temporary WireGuard config and start the service."""
try: try:
self._temp_dir.mkdir(parents=True, exist_ok=True) self._config_dir.mkdir(parents=True, exist_ok=True)
except Exception: except Exception:
self.logger.warning("Failed to create temp dir for WireGuard config", exc_info=True) self.logger.warning("Failed to create temp dir for WireGuard config", exc_info=True)
config_path = self._temp_dir / "borealis-wg.conf" config_path = self._config_dir / f"{self._service_name}.conf"
rendered = self.render_server_config(peers) rendered = self.render_server_config(peers)
config_path.write_text(rendered, encoding="utf-8") config_path.write_text(rendered, encoding="utf-8")
self.logger.info("Rendered WireGuard config to %s", config_path) self.logger.info("Rendered WireGuard config to %s", config_path)
@@ -335,6 +383,8 @@ class WireGuardServerManager:
self.logger.error("Failed to install WireGuard tunnel service code=%s err=%s", code, err) self.logger.error("Failed to install WireGuard tunnel service code=%s err=%s", code, err)
raise RuntimeError(f"WireGuard installtunnelservice failed: {err}") raise RuntimeError(f"WireGuard installtunnelservice failed: {err}")
self.logger.info("WireGuard listener installed (service=%s)", config_path.stem) self.logger.info("WireGuard listener installed (service=%s)", config_path.stem)
self._ensure_service_display_name()
self._ensure_service_running()
def stop_listener(self, *, ignore_missing: bool = False) -> None: def stop_listener(self, *, ignore_missing: bool = False) -> None:
"""Stop and remove the WireGuard tunnel service.""" """Stop and remove the WireGuard tunnel service."""