Updates to WireGuard

This commit is contained in:
2026-01-13 22:22:31 -07:00
parent 22733f2c04
commit 7dda62d9ee
2 changed files with 199 additions and 5 deletions

View File

@@ -17,6 +17,7 @@ the client session with the issued config/token.
from __future__ import annotations
import base64
import ipaddress
import json
import os
import subprocess
@@ -25,6 +26,7 @@ import time
import re
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import urlsplit
try:
import winreg # type: ignore
@@ -48,6 +50,7 @@ ROLE_NAME = "WireGuardTunnel"
ROLE_CONTEXTS = ["system"]
TUNNEL_NAME = "Borealis"
TUNNEL_DISPLAY_NAME = "Borealis"
SERVICE_DISPLAY_NAME = "Borealis - WireGuard - Agent"
TUNNEL_IDLE_ADDRESS = "169.254.255.254/32"
@@ -140,6 +143,7 @@ class WireGuardClient:
self.temp_root.mkdir(parents=True, exist_ok=True)
self.service_name = TUNNEL_NAME
self.display_name = TUNNEL_DISPLAY_NAME
self.service_display_name = SERVICE_DISPLAY_NAME
self.conf_path = self._wireguard_config_path()
self.session: Optional[SessionConfig] = None
self.idle_deadline: Optional[float] = None
@@ -295,6 +299,22 @@ class WireGuardClient:
]
self._run(args)
def _ensure_service_display_name(self) -> None:
if not self.service_display_name:
return
if not self._service_exists():
return
args = [
"sc.exe",
"config",
self._service_id(),
"DisplayName=",
self.service_display_name,
]
code, _, err = self._run(args)
if code != 0 and err:
_write_log(f"WireGuard service display name update failed: {err}")
def _ensure_idle_service(self) -> None:
if self._service_exists():
return
@@ -305,6 +325,7 @@ class WireGuardClient:
return
if self._install_service():
self._ensure_adapter_name()
self._ensure_service_display_name()
def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None:
payload = dict(token or {})
@@ -425,6 +446,7 @@ class WireGuardClient:
self._restart_service()
self._ensure_adapter_name()
self._ensure_service_display_name()
self.session = session
self.idle_deadline = time.time() + max(60, session.idle_seconds)
@@ -449,6 +471,7 @@ class WireGuardClient:
if wrote_idle:
self._restart_service()
self._ensure_adapter_name()
self._ensure_service_display_name()
_write_log(f"WireGuard client session stopped (reason={reason}).")
elif not ignore_missing:
_write_log("Failed to write idle WireGuard config.")
@@ -492,6 +515,98 @@ def _coerce_int(value: Any, default: int) -> int:
return default
def _parse_endpoint_host(value: Optional[str]) -> Optional[str]:
if not value:
return None
text = str(value).strip()
if not text:
return None
if text.startswith("["):
match = re.match(r"^\[([^\]]+)\]", text)
if match:
return match.group(1)
host, sep, _ = text.rpartition(":")
if sep and host:
return host
return text
def _parse_endpoint_port(value: Optional[str]) -> Optional[int]:
if not value:
return None
text = str(value).strip()
if not text:
return None
if text.startswith("["):
match = re.match(r"^\[[^\]]+\]:(\d+)$", text)
if match:
try:
return int(match.group(1))
except Exception:
return None
return None
_, sep, port = text.rpartition(":")
if sep and port.isdigit():
try:
return int(port)
except Exception:
return None
return None
def _format_endpoint(host: str, port: Optional[int]) -> Optional[str]:
if not host:
return None
text = str(host).strip()
if not text:
return None
if ":" in text and not text.startswith("["):
text = f"[{text}]"
if port is None:
return text
return f"{text}:{port}"
def _parse_server_url_host(server_url: Optional[str]) -> Optional[str]:
if not server_url:
return None
try:
parsed = urlsplit(str(server_url).strip())
if parsed.hostname:
return parsed.hostname.strip()
except Exception:
pass
text = str(server_url).strip()
if not text:
return None
if "://" in text:
text = text.split("://", 1)[1]
text = text.split("/", 1)[0].strip()
if text.startswith("[") and "]" in text:
text = text[1:text.index("]")]
if ":" in text:
host_part, port_part = text.rsplit(":", 1)
if port_part.isdigit():
text = host_part
text = text.strip()
return text or None
def _is_loopback_host(host: Optional[str]) -> bool:
if not host:
return True
text = str(host).strip().lower()
if not text:
return True
if text == "localhost":
return True
try:
ip = ipaddress.ip_address(text)
return ip.is_loopback or ip.is_unspecified
except Exception:
return False
class Role:
def __init__(self, ctx) -> None:
self.ctx = ctx
@@ -499,6 +614,7 @@ class Role:
hooks = getattr(ctx, "hooks", {}) or {}
self._log_hook = hooks.get("log_agent")
self._http_client_factory = hooks.get("http_client")
self._get_server_url = hooks.get("get_server_url")
try:
self.client.stop_session(reason="agent_startup", ignore_missing=True)
except Exception:
@@ -522,6 +638,33 @@ class Role:
return None
return None
def _resolve_endpoint(self, endpoint: Optional[str], token: Dict[str, Any]) -> Optional[str]:
server_url = None
if callable(self._get_server_url):
try:
server_url = self._get_server_url()
except Exception:
server_url = None
server_host = _parse_server_url_host(server_url)
if not server_host:
return endpoint
endpoint_host = _parse_endpoint_host(endpoint)
if endpoint_host and not _is_loopback_host(endpoint_host):
return endpoint
endpoint_port = _parse_endpoint_port(endpoint)
if endpoint_port is None:
endpoint_port = _coerce_int(token.get("port"), 0)
if endpoint_port <= 0:
endpoint_port = None
resolved = _format_endpoint(server_host, endpoint_port)
if resolved and endpoint and resolved != endpoint:
self._log(f"WireGuard endpoint override: {endpoint} -> {resolved}")
return resolved or endpoint
def _build_session(self, payload: Any) -> Optional[SessionConfig]:
if not isinstance(payload, dict):
self._log("WireGuard start payload missing/invalid.", error=True)
@@ -539,6 +682,7 @@ class Role:
virtual_ip = payload.get("virtual_ip") or payload.get("client_virtual_ip")
endpoint = payload.get("endpoint") or payload.get("server_endpoint")
endpoint = self._resolve_endpoint(endpoint, token)
server_public_key = payload.get("server_public_key") or payload.get("public_key")
engine_virtual_ip = payload.get("engine_virtual_ip") or payload.get("engine_ip")