mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2026-02-04 06:50:31 -07:00
Updates to WireGuard
This commit is contained in:
@@ -17,6 +17,7 @@ the client session with the issued config/token.
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
@@ -25,6 +26,7 @@ import time
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
try:
|
||||
import winreg # type: ignore
|
||||
@@ -48,6 +50,7 @@ ROLE_NAME = "WireGuardTunnel"
|
||||
ROLE_CONTEXTS = ["system"]
|
||||
TUNNEL_NAME = "Borealis"
|
||||
TUNNEL_DISPLAY_NAME = "Borealis"
|
||||
SERVICE_DISPLAY_NAME = "Borealis - WireGuard - Agent"
|
||||
TUNNEL_IDLE_ADDRESS = "169.254.255.254/32"
|
||||
|
||||
|
||||
@@ -140,6 +143,7 @@ class WireGuardClient:
|
||||
self.temp_root.mkdir(parents=True, exist_ok=True)
|
||||
self.service_name = TUNNEL_NAME
|
||||
self.display_name = TUNNEL_DISPLAY_NAME
|
||||
self.service_display_name = SERVICE_DISPLAY_NAME
|
||||
self.conf_path = self._wireguard_config_path()
|
||||
self.session: Optional[SessionConfig] = None
|
||||
self.idle_deadline: Optional[float] = None
|
||||
@@ -295,6 +299,22 @@ class WireGuardClient:
|
||||
]
|
||||
self._run(args)
|
||||
|
||||
def _ensure_service_display_name(self) -> None:
|
||||
if not self.service_display_name:
|
||||
return
|
||||
if not self._service_exists():
|
||||
return
|
||||
args = [
|
||||
"sc.exe",
|
||||
"config",
|
||||
self._service_id(),
|
||||
"DisplayName=",
|
||||
self.service_display_name,
|
||||
]
|
||||
code, _, err = self._run(args)
|
||||
if code != 0 and err:
|
||||
_write_log(f"WireGuard service display name update failed: {err}")
|
||||
|
||||
def _ensure_idle_service(self) -> None:
|
||||
if self._service_exists():
|
||||
return
|
||||
@@ -305,6 +325,7 @@ class WireGuardClient:
|
||||
return
|
||||
if self._install_service():
|
||||
self._ensure_adapter_name()
|
||||
self._ensure_service_display_name()
|
||||
|
||||
def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None:
|
||||
payload = dict(token or {})
|
||||
@@ -425,6 +446,7 @@ class WireGuardClient:
|
||||
|
||||
self._restart_service()
|
||||
self._ensure_adapter_name()
|
||||
self._ensure_service_display_name()
|
||||
|
||||
self.session = session
|
||||
self.idle_deadline = time.time() + max(60, session.idle_seconds)
|
||||
@@ -449,6 +471,7 @@ class WireGuardClient:
|
||||
if wrote_idle:
|
||||
self._restart_service()
|
||||
self._ensure_adapter_name()
|
||||
self._ensure_service_display_name()
|
||||
_write_log(f"WireGuard client session stopped (reason={reason}).")
|
||||
elif not ignore_missing:
|
||||
_write_log("Failed to write idle WireGuard config.")
|
||||
@@ -492,6 +515,98 @@ def _coerce_int(value: Any, default: int) -> int:
|
||||
return default
|
||||
|
||||
|
||||
def _parse_endpoint_host(value: Optional[str]) -> Optional[str]:
|
||||
if not value:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return None
|
||||
if text.startswith("["):
|
||||
match = re.match(r"^\[([^\]]+)\]", text)
|
||||
if match:
|
||||
return match.group(1)
|
||||
host, sep, _ = text.rpartition(":")
|
||||
if sep and host:
|
||||
return host
|
||||
return text
|
||||
|
||||
|
||||
def _parse_endpoint_port(value: Optional[str]) -> Optional[int]:
|
||||
if not value:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return None
|
||||
if text.startswith("["):
|
||||
match = re.match(r"^\[[^\]]+\]:(\d+)$", text)
|
||||
if match:
|
||||
try:
|
||||
return int(match.group(1))
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
_, sep, port = text.rpartition(":")
|
||||
if sep and port.isdigit():
|
||||
try:
|
||||
return int(port)
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _format_endpoint(host: str, port: Optional[int]) -> Optional[str]:
|
||||
if not host:
|
||||
return None
|
||||
text = str(host).strip()
|
||||
if not text:
|
||||
return None
|
||||
if ":" in text and not text.startswith("["):
|
||||
text = f"[{text}]"
|
||||
if port is None:
|
||||
return text
|
||||
return f"{text}:{port}"
|
||||
|
||||
|
||||
def _parse_server_url_host(server_url: Optional[str]) -> Optional[str]:
|
||||
if not server_url:
|
||||
return None
|
||||
try:
|
||||
parsed = urlsplit(str(server_url).strip())
|
||||
if parsed.hostname:
|
||||
return parsed.hostname.strip()
|
||||
except Exception:
|
||||
pass
|
||||
text = str(server_url).strip()
|
||||
if not text:
|
||||
return None
|
||||
if "://" in text:
|
||||
text = text.split("://", 1)[1]
|
||||
text = text.split("/", 1)[0].strip()
|
||||
if text.startswith("[") and "]" in text:
|
||||
text = text[1:text.index("]")]
|
||||
if ":" in text:
|
||||
host_part, port_part = text.rsplit(":", 1)
|
||||
if port_part.isdigit():
|
||||
text = host_part
|
||||
text = text.strip()
|
||||
return text or None
|
||||
|
||||
|
||||
def _is_loopback_host(host: Optional[str]) -> bool:
|
||||
if not host:
|
||||
return True
|
||||
text = str(host).strip().lower()
|
||||
if not text:
|
||||
return True
|
||||
if text == "localhost":
|
||||
return True
|
||||
try:
|
||||
ip = ipaddress.ip_address(text)
|
||||
return ip.is_loopback or ip.is_unspecified
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class Role:
|
||||
def __init__(self, ctx) -> None:
|
||||
self.ctx = ctx
|
||||
@@ -499,6 +614,7 @@ class Role:
|
||||
hooks = getattr(ctx, "hooks", {}) or {}
|
||||
self._log_hook = hooks.get("log_agent")
|
||||
self._http_client_factory = hooks.get("http_client")
|
||||
self._get_server_url = hooks.get("get_server_url")
|
||||
try:
|
||||
self.client.stop_session(reason="agent_startup", ignore_missing=True)
|
||||
except Exception:
|
||||
@@ -522,6 +638,33 @@ class Role:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _resolve_endpoint(self, endpoint: Optional[str], token: Dict[str, Any]) -> Optional[str]:
|
||||
server_url = None
|
||||
if callable(self._get_server_url):
|
||||
try:
|
||||
server_url = self._get_server_url()
|
||||
except Exception:
|
||||
server_url = None
|
||||
|
||||
server_host = _parse_server_url_host(server_url)
|
||||
if not server_host:
|
||||
return endpoint
|
||||
|
||||
endpoint_host = _parse_endpoint_host(endpoint)
|
||||
if endpoint_host and not _is_loopback_host(endpoint_host):
|
||||
return endpoint
|
||||
|
||||
endpoint_port = _parse_endpoint_port(endpoint)
|
||||
if endpoint_port is None:
|
||||
endpoint_port = _coerce_int(token.get("port"), 0)
|
||||
if endpoint_port <= 0:
|
||||
endpoint_port = None
|
||||
|
||||
resolved = _format_endpoint(server_host, endpoint_port)
|
||||
if resolved and endpoint and resolved != endpoint:
|
||||
self._log(f"WireGuard endpoint override: {endpoint} -> {resolved}")
|
||||
return resolved or endpoint
|
||||
|
||||
def _build_session(self, payload: Any) -> Optional[SessionConfig]:
|
||||
if not isinstance(payload, dict):
|
||||
self._log("WireGuard start payload missing/invalid.", error=True)
|
||||
@@ -539,6 +682,7 @@ class Role:
|
||||
|
||||
virtual_ip = payload.get("virtual_ip") or payload.get("client_virtual_ip")
|
||||
endpoint = payload.get("endpoint") or payload.get("server_endpoint")
|
||||
endpoint = self._resolve_endpoint(endpoint, token)
|
||||
server_public_key = payload.get("server_public_key") or payload.get("public_key")
|
||||
engine_virtual_ip = payload.get("engine_virtual_ip") or payload.get("engine_ip")
|
||||
|
||||
|
||||
@@ -19,8 +19,8 @@ import base64
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
@@ -30,6 +30,7 @@ from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Uni
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import x25519
|
||||
|
||||
from ... import config as engine_config
|
||||
|
||||
def _build_logger(log_path: Path) -> logging.Logger:
|
||||
logger = logging.getLogger("borealis.engine.wireguard")
|
||||
@@ -72,9 +73,18 @@ class WireGuardServerManager:
|
||||
self._ensure_cert_dir()
|
||||
self.server_private_key, self.server_public_key = self._ensure_server_keys()
|
||||
self._service_name = "borealis-wg"
|
||||
self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine"
|
||||
self._service_display_name = "Borealis - WireGuard - Engine"
|
||||
self._config_dir = self._resolve_config_dir()
|
||||
self._wireguard_exe = self._resolve_wireguard_exe()
|
||||
|
||||
def _resolve_config_dir(self) -> Path:
|
||||
config_dir = engine_config.PROJECT_ROOT / "Engine" / "WireGuard"
|
||||
try:
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
self.logger.error("Failed to ensure WireGuard config dir at %s", config_dir, exc_info=True)
|
||||
return config_dir
|
||||
|
||||
def _resolve_wireguard_exe(self) -> str:
|
||||
candidates = [
|
||||
str(Path(os.environ.get("ProgramFiles", "C:\\Program Files")) / "WireGuard" / "wireguard.exe"),
|
||||
@@ -138,6 +148,45 @@ class WireGuardServerManager:
|
||||
except Exception as exc:
|
||||
return 1, "", str(exc)
|
||||
|
||||
def _service_id(self) -> str:
|
||||
return f"WireGuardTunnel${self._service_name}"
|
||||
|
||||
def _query_service_state(self) -> Optional[str]:
|
||||
code, out, err = self._run_command(["sc.exe", "query", self._service_id()])
|
||||
if code != 0:
|
||||
return None
|
||||
text = out or err
|
||||
for line in text.splitlines():
|
||||
if "STATE" not in line:
|
||||
continue
|
||||
match = re.search(r"STATE\s*:\s*\d+\s+(\w+)", line)
|
||||
if match:
|
||||
return match.group(1).upper()
|
||||
return None
|
||||
|
||||
def _ensure_service_display_name(self) -> None:
|
||||
if not self._service_display_name:
|
||||
return
|
||||
args = ["sc.exe", "config", self._service_id(), "DisplayName=", self._service_display_name]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0 and err:
|
||||
self.logger.warning("Failed to set WireGuard service display name: %s", err)
|
||||
|
||||
def _ensure_service_running(self) -> None:
|
||||
service_id = self._service_id()
|
||||
for _ in range(6):
|
||||
state = self._query_service_state()
|
||||
if state == "RUNNING":
|
||||
return
|
||||
if state == "STOPPED":
|
||||
code, out, err = self._run_command(["sc.exe", "start", service_id])
|
||||
if code != 0:
|
||||
self.logger.error("Failed to start WireGuard tunnel service %s err=%s", service_id, err)
|
||||
break
|
||||
time.sleep(1)
|
||||
state = self._query_service_state()
|
||||
raise RuntimeError(f"WireGuard tunnel service {service_id} failed to start (state={state})")
|
||||
|
||||
def _normalise_allowed_ports(
|
||||
self,
|
||||
candidate: Optional[Iterable[int]],
|
||||
@@ -242,7 +291,6 @@ class WireGuardServerManager:
|
||||
f"PrivateKey = {self.server_private_key}",
|
||||
f"ListenPort = {self.config.port}",
|
||||
f"Address = {iface}",
|
||||
"SaveConfig = false",
|
||||
"",
|
||||
]
|
||||
|
||||
@@ -317,11 +365,11 @@ class WireGuardServerManager:
|
||||
"""Render a temporary WireGuard config and start the service."""
|
||||
|
||||
try:
|
||||
self._temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._config_dir.mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
self.logger.warning("Failed to create temp dir for WireGuard config", exc_info=True)
|
||||
|
||||
config_path = self._temp_dir / "borealis-wg.conf"
|
||||
config_path = self._config_dir / f"{self._service_name}.conf"
|
||||
rendered = self.render_server_config(peers)
|
||||
config_path.write_text(rendered, encoding="utf-8")
|
||||
self.logger.info("Rendered WireGuard config to %s", config_path)
|
||||
@@ -335,6 +383,8 @@ class WireGuardServerManager:
|
||||
self.logger.error("Failed to install WireGuard tunnel service code=%s err=%s", code, err)
|
||||
raise RuntimeError(f"WireGuard installtunnelservice failed: {err}")
|
||||
self.logger.info("WireGuard listener installed (service=%s)", config_path.stem)
|
||||
self._ensure_service_display_name()
|
||||
self._ensure_service_running()
|
||||
|
||||
def stop_listener(self, *, ignore_missing: bool = False) -> None:
|
||||
"""Stop and remove the WireGuard tunnel service."""
|
||||
|
||||
Reference in New Issue
Block a user