Files
Borealis-Github-Replica/Data/Agent/Roles/role_WireGuardTunnel.py

600 lines
23 KiB
Python

# ======================================================
# Data\Agent\Roles\role_WireGuardTunnel.py
# Description: WireGuard client lifecycle for outbound reverse VPN (Windows) with host-only /32 routing.
#
# API Endpoints (if applicable): None
# ======================================================
"""WireGuard client role (Windows) for reverse VPN tunnels.
This role prepares the WireGuard client config, manages a single active
session, enforces idle teardown, and logs lifecycle events to
Agent/Logs/VPN_Tunnel/tunnel.log. It binds to Engine Socket.IO events
(`vpn_tunnel_start`, `vpn_tunnel_stop`, `vpn_tunnel_activity`) to start/stop
the client session with the issued config/token.
"""
from __future__ import annotations

import asyncio
import base64
import functools
import ipaddress
import json
import os
import re
import shutil
import subprocess
import threading
import time
from pathlib import Path
from typing import Any, Dict, Optional

try:
    import winreg  # type: ignore
except Exception:  # pragma: no cover - non-Windows guard
    winreg = None

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519

try:
    from signature_utils import verify_and_store_script_signature
except Exception:  # pragma: no cover - fallback for runtime path issues
    import sys
    from pathlib import Path as _Path

    base_dir = _Path(__file__).resolve().parents[1]
    if str(base_dir) not in sys.path:
        sys.path.insert(0, str(base_dir))
    from signature_utils import verify_and_store_script_signature
# Role identity; presumably read by the agent's role loader (not shown here) — confirm.
ROLE_NAME = "WireGuardTunnel"
# Contexts this role is declared for; only "system" is listed.
ROLE_CONTEXTS = ["system"]
# Tunnel name: used for the "<name>.conf" config file and the
# "WireGuardTunnel$<name>" Windows service id.
TUNNEL_NAME = "Borealis"
# Desired adapter name; the adapter is renamed via netsh only when this
# differs from TUNNEL_NAME.
TUNNEL_DISPLAY_NAME = "Borealis"
# Interface address written into the idle config (interface only, no [Peer]),
# which keeps the tunnel service installed while no session is active.
TUNNEL_IDLE_ADDRESS = "169.254.255.254/32"
def _log_path() -> Path:
    """Return the tunnel log file path, creating its directory if needed.

    The directory is ``<parents[2] of this file>/Logs/VPN_Tunnel``.
    """
    log_dir = Path(__file__).resolve().parents[2].joinpath("Logs", "VPN_Tunnel")
    log_dir.mkdir(parents=True, exist_ok=True)
    return log_dir.joinpath("tunnel.log")
def _write_log(message: str) -> None:
ts = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
try:
_log_path().open("a", encoding="utf-8").write(f"[{ts}] [wg-client] {message}\n")
except Exception:
pass
def _encode_key(raw: bytes) -> str:
return base64.b64encode(raw).decode("ascii").strip()
def _read_key_text(path: Path) -> str:
    """Return the stripped text of *path*, or "" when missing/unreadable."""
    try:
        return path.read_text(encoding="utf-8").strip()
    except (OSError, ValueError):
        return ""


def _load_private_key(path: Path) -> Optional[x25519.X25519PrivateKey]:
    """Parse a base64 X25519 private key file; return None when unusable."""
    text = _read_key_text(path)
    if not text:
        return None
    try:
        return x25519.X25519PrivateKey.from_private_bytes(base64.b64decode(text, validate=True))
    except ValueError:
        # Bad base64 (binascii.Error) and wrong key length are both ValueError.
        return None


def _generate_client_keys(root: Path) -> Dict[str, str]:
    """Load, or create, the persistent WireGuard client key pair under *root*.

    The private key (``client_private.key``) is the source of truth. When it
    holds a valid base64 X25519 key, the public key is derived from it rather
    than trusted from disk. ``client_public.key`` is rewritten if it is missing
    or does not match. A new key pair is generated and persisted only when no
    usable private key exists, so a lost public-key file no longer rotates the
    client identity.

    Returns:
        ``{"private": <base64>, "public": <base64>}``. Persisting is best
        effort: failures are logged and the in-memory keys are still returned.
    """
    root.mkdir(parents=True, exist_ok=True)
    priv_path = root / "client_private.key"
    pub_path = root / "client_public.key"

    key = _load_private_key(priv_path)
    is_new = key is None
    if key is None:
        key = x25519.X25519PrivateKey.generate()

    priv = _encode_key(
        key.private_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PrivateFormat.Raw,
            encryption_algorithm=serialization.NoEncryption(),
        )
    )
    pub = _encode_key(
        key.public_key().public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
    )

    stored_pub = _read_key_text(pub_path)
    if not is_new and stored_pub and stored_pub != pub:
        _write_log("Stored WireGuard client public key did not match private key; rewriting it.")
    try:
        if is_new:
            priv_path.write_text(priv, encoding="utf-8")
        if stored_pub != pub:
            pub_path.write_text(pub, encoding="utf-8")
    except OSError:
        _write_log("Failed to persist WireGuard client keys.")
    return {"private": priv, "public": pub}
class SessionConfig:
    """Settings for one WireGuard client session, built from a start request.

    Every constructor argument is keyword-only. ``token`` is the orchestration
    token validated before the session starts. ``allowed_ports`` and
    ``client_public_key`` are carried on the object but are not rendered into
    the WireGuard config by ``WireGuardClient``.
    """

    def __init__(
        self,
        *,
        token: Dict[str, Any],
        virtual_ip: str,
        allowed_ips: str,
        endpoint: str,
        server_public_key: str,
        allowed_ports: str,
        idle_seconds: int = 900,
        preshared_key: Optional[str] = None,
        client_private_key: Optional[str] = None,
        client_public_key: Optional[str] = None,
    ) -> None:
        # Orchestration token checked before the tunnel is brought up.
        self.token = token
        # Interface address and the peer's routing/connection settings.
        self.virtual_ip, self.allowed_ips = virtual_ip, allowed_ips
        self.endpoint, self.server_public_key = endpoint, server_public_key
        self.allowed_ports = allowed_ports
        # Inactivity window in seconds before the session is torn down.
        self.idle_seconds = idle_seconds
        # Optional key material; a missing client private key falls back to
        # the agent's persisted key when the config is rendered.
        self.preshared_key = preshared_key
        self.client_private_key, self.client_public_key = client_private_key, client_public_key
class WireGuardClient:
    """Own the Borealis WireGuard tunnel service and its single client session.

    The tunnel service ``WireGuardTunnel$<TUNNEL_NAME>`` stays installed with an
    "idle" config: an interface with no peer.
    - Starting a session writes the issued peer config and restarts the service.
    - Stopping a session, explicitly or after the idle timeout, writes the idle
      config back and restarts the service again.

    At most one session is active. ``_session_lock`` serializes start and stop,
    which may be called from event handlers and from the idle-monitor thread.
    """

    # Win32 error codes that sc.exe surfaces as its exit code.
    _ERROR_SERVICE_ALREADY_RUNNING = 1056  # start while old instance still running/stopping
    _ERROR_SERVICE_NOT_ACTIVE = 1062  # stop sent to a service that is not running

    def __init__(self) -> None:
        base = Path(__file__).resolve().parents[2]
        self.cert_root = base / "Borealis" / "Certificates" / "VPN_Client"
        self.temp_root = base / "Borealis" / "Temp"
        self.temp_root.mkdir(parents=True, exist_ok=True)
        self.service_name = TUNNEL_NAME
        self.display_name = TUNNEL_DISPLAY_NAME
        self.conf_path = self._wireguard_config_path()
        self.session: Optional[SessionConfig] = None
        self.idle_deadline: Optional[float] = None
        self._idle_thread: Optional[threading.Thread] = None
        # Stop signal for the *current* idle monitor. It is replaced whenever a
        # monitor starts, so an old monitor thread is never revived by a reset.
        self._stop_event = threading.Event()
        self._session_lock = threading.Lock()
        self._client_keys = _generate_client_keys(self.cert_root)
        self._wg_exe = self._resolve_wireguard_exe()
        self._last_install_already_present = False
        try:
            self._ensure_idle_service()
        except Exception as exc:
            # Best effort; start_session installs the service if still missing.
            _write_log(f"Failed to prepare idle WireGuard tunnel service: {exc}")

    def _resolve_wireguard_exe(self) -> str:
        """Locate ``wireguard.exe``: Program Files, then the working directory, then PATH.

        Falls back to the bare name. ``_ensure_idle_service`` then sees that
        it is not an existing file and skips the install.
        """
        program_files = Path(os.environ.get("ProgramFiles", "C:\\Program Files"))
        installed = program_files / "WireGuard" / "wireguard.exe"
        if installed.is_file():
            return str(installed)
        if Path("wireguard.exe").is_file():
            return "wireguard.exe"
        return shutil.which("wireguard.exe") or "wireguard.exe"

    def _service_id(self) -> str:
        """Windows service name WireGuard uses for this tunnel."""
        return f"WireGuardTunnel${self.service_name}"

    def _service_reg_path(self) -> str:
        """HKLM registry path of the tunnel service key."""
        return f"SYSTEM\\CurrentControlSet\\Services\\{self._service_id()}"

    def _service_reg_exists(self) -> bool:
        """Check the registry for the service key.

        An access-denied error is treated as "present".
        """
        if winreg is None:
            return False
        try:
            # Context manager closes the registry handle (previously leaked).
            with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, self._service_reg_path()):
                return True
        except FileNotFoundError:
            return False
        except PermissionError:
            _write_log("WireGuard service registry check denied; treating as present.")
            return True
        except Exception as exc:
            _write_log(f"WireGuard service registry check failed: {exc}")
            return False

    def _service_image_path(self) -> Optional[str]:
        """Return the service's ``ImagePath`` registry value, or None if unavailable."""
        if winreg is None:
            return None
        try:
            with winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, self._service_reg_path()) as key:
                value, _ = winreg.QueryValueEx(key, "ImagePath")
            return str(value) if value else None
        except Exception:
            return None

    def _service_config_path(self) -> Optional[Path]:
        """Extract the config file the installed service was registered with.

        Reads the ``/tunnelservice <path>`` argument from ImagePath, quoted or
        unquoted. This may differ from ``self.conf_path``.
        """
        image_path = self._service_image_path()
        if not image_path:
            return None
        for pattern in (r'(?i)/tunnelservice\s+"([^"]+)"', r"(?i)/tunnelservice\s+(\S+)"):
            match = re.search(pattern, image_path)
            if match:
                return Path(match.group(1))
        return None

    def _wireguard_config_path(self) -> Path:
        """Pick the config path.

        The first existing ``<name>.conf`` among the candidate directories wins.
        Otherwise, use the first candidate directory that can be created.
        """
        settings_dir = self.temp_root.parent / "Settings" / "WireGuard"
        candidates = [
            settings_dir,
            Path(os.environ.get("ProgramFiles", "C:\\Program Files")) / "WireGuard" / "Data" / "Configurations",
            Path(os.environ.get("ProgramData", "C:\\ProgramData")) / "Borealis" / "WireGuard" / "Configurations",
            self.temp_root,
        ]
        for config_dir in candidates:
            candidate = config_dir / f"{self.service_name}.conf"
            if candidate.is_file():
                return candidate
        for config_dir in candidates:
            try:
                config_dir.mkdir(parents=True, exist_ok=True)
                return config_dir / f"{self.service_name}.conf"
            except Exception:
                continue
        return self.temp_root / f"{self.service_name}.conf"

    def _write_config(self, text: str) -> bool:
        """Write *text* to the primary config path; True on success."""
        return self._write_config_to(self.conf_path, text)

    def _write_config_to(self, path: Path, text: str) -> bool:
        """Write *text* (ASCII) to *path*, logging and returning False on failure."""
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(text, encoding="ascii")
            return True
        except Exception as exc:
            _write_log(f"Failed to write WireGuard config at {path}: {exc}")
            return False

    def _render_idle_config(self) -> str:
        """Interface-only config that keeps the service up without any peer."""
        private_key = self._client_keys["private"]
        return "\n".join(
            [
                "[Interface]",
                f"PrivateKey = {private_key}",
                f"Address = {TUNNEL_IDLE_ADDRESS}",
                "ListenPort = 0",
            ]
        )

    def _service_exists(self) -> bool:
        """True if ``sc query`` finds the service, else fall back to the registry."""
        code, _, _ = self._run(["sc.exe", "query", self._service_id()])
        if code == 0:
            return True
        return self._service_reg_exists()

    def _install_service(self) -> bool:
        """Install the tunnel service from ``self.conf_path``.

        Returns True on success, or when the tool reports the service is
        already installed and running. That case also sets
        ``_last_install_already_present``.
        """
        code, out, err = self._run([self._wg_exe, "/installtunnelservice", str(self.conf_path)])
        self._last_install_already_present = False
        if code == 0:
            return True
        # The tool's diagnostics may land on either stream.
        detail = f"{err}\n{out}".lower()
        if "already installed and running" in detail:
            self._last_install_already_present = True
            _write_log("WireGuard tunnel service already installed; skipping install.")
            return True
        if "access is denied" in detail:
            _write_log("Failed to install WireGuard tunnel service: access denied; ensure agent runs elevated.")
            return False
        _write_log(f"Failed to install WireGuard tunnel service: code={code} err={err or out}")
        return False

    def _restart_service(self, *, settle_seconds: float = 1.0, start_attempts: int = 10) -> bool:
        """Stop then start the tunnel service so it re-reads its config.

        ``sc stop`` only requests a stop. A start issued while the old instance
        is still stopping fails with ERROR_SERVICE_ALREADY_RUNNING (1056). That
        failure is retried up to *start_attempts* times, *settle_seconds* apart.
        sc.exe prints its diagnostics on stdout, so stdout is logged when
        stderr is empty.

        Returns:
            True when the final ``sc start`` succeeded.
        """
        service_id = self._service_id()
        stop_code, stop_out, stop_err = self._run(["sc.exe", "stop", service_id])
        stop_detail = stop_err or stop_out
        if stop_code not in (0, self._ERROR_SERVICE_NOT_ACTIVE) and stop_detail:
            _write_log(f"WireGuard stop service returned code={stop_code} err={stop_detail}")
        start_code, start_out, start_err = 1, "", ""
        for _ in range(max(1, start_attempts)):
            time.sleep(max(0.0, settle_seconds))
            start_code, start_out, start_err = self._run(["sc.exe", "start", service_id])
            if start_code != self._ERROR_SERVICE_ALREADY_RUNNING:
                break
        start_detail = start_err or start_out
        if start_code != 0 and start_detail:
            _write_log(f"WireGuard start service returned code={start_code} err={start_detail}")
        return start_code == 0

    def _ensure_adapter_name(self) -> None:
        """Rename the adapter to ``display_name`` when it differs from the tunnel name."""
        if self.service_name == self.display_name:
            return
        args = [
            "netsh.exe",
            "interface",
            "set",
            "interface",
            f'name="{self.service_name}"',
            f'newname="{self.display_name}"',
        ]
        self._run(args)

    def _ensure_idle_service(self) -> None:
        """Install the service with the idle config if it is missing and wireguard.exe exists."""
        if self._service_exists():
            return
        if not Path(self._wg_exe).is_file():
            return
        idle_config = self._render_idle_config()
        if not self._write_config(idle_config):
            return
        if self._install_service():
            self._ensure_adapter_name()

    def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None:
        """Validate required fields, expiry, port and (optional) signature of *token*.

        Unsigned tokens are accepted only when they carry no signature metadata
        and *signing_client* has no stored server signing key.

        Raises:
            ValueError: describing the first failed check.
        """
        payload = dict(token or {})
        signature = payload.pop("signature", None)
        signing_key = payload.pop("signing_key", None)
        sig_alg = payload.pop("sig_alg", None)
        required = ("agent_id", "tunnel_id", "expires_at", "port")
        # Check the copied payload so a None token yields ValueError, not TypeError.
        missing = [field for field in required if payload.get(field) in ("", None)]
        if missing:
            raise ValueError(f"Missing token fields: {', '.join(missing)}")
        try:
            exp = float(payload["expires_at"])
        except Exception:
            raise ValueError("Invalid token expiry")
        if exp <= time.time():
            raise ValueError("Token expired")
        try:
            port = int(payload["port"])
        except Exception:
            raise ValueError("Invalid token port")
        if port < 1 or port > 65535:
            raise ValueError("Invalid token port")
        if not signature:
            if sig_alg or signing_key:
                raise ValueError("Token signature missing")
            stored_key = None
            if signing_client is not None and hasattr(signing_client, "load_server_signing_key"):
                try:
                    stored_key = signing_client.load_server_signing_key()
                except Exception:
                    stored_key = None
            # Once a server signing key is pinned, unsigned tokens are refused.
            if isinstance(stored_key, str) and stored_key.strip():
                raise ValueError("Token signature missing")
            return
        if sig_alg and str(sig_alg).lower() not in ("ed25519", "eddsa"):
            raise ValueError("Unsupported token signature algorithm")
        # The signature covers the canonical JSON of the token minus signature fields.
        payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
        if not verify_and_store_script_signature(signing_client, payload_bytes, str(signature), signing_key):
            raise ValueError("Token signature invalid")

    def _run(self, args: list[str], timeout: Optional[float] = 120.0) -> tuple[int, str, str]:
        """Run *args* without a shell and return ``(returncode, stdout, stderr)``.

        Output is decoded leniently, so undecodable bytes become U+FFFD. A
        launch failure or a run exceeding *timeout* seconds is reported as
        ``(1, "", <error text>)`` instead of raising.
        """
        try:
            proc = subprocess.run(
                args,
                capture_output=True,
                text=True,
                errors="replace",
                check=False,
                timeout=timeout,
            )
        except Exception as exc:  # pragma: no cover - runtime guard
            return 1, "", str(exc)
        return proc.returncode, (proc.stdout or "").strip(), (proc.stderr or "").strip()

    @staticmethod
    def _config_value(field: str, value: Any) -> str:
        """Return ``str(value)`` for a config line, rejecting non-printable text.

        A newline or other control character would let a value start a new
        config line, e.g. a second ``AllowedIPs = 0.0.0.0/0``.

        Raises:
            ValueError: if the text contains non-printable characters.
        """
        text = str(value)
        if not text.isprintable():
            raise ValueError(f"WireGuard config field {field} contains control characters")
        return text

    def _render_config(self, session: SessionConfig) -> str:
        """Render the full client config for *session*.

        Uses the persisted client private key when the session has none.

        Raises:
            ValueError: if any rendered value could inject extra config lines.
        """
        field = self._config_value
        private_key = session.client_private_key or self._client_keys["private"]
        lines = [
            "[Interface]",
            f"PrivateKey = {field('PrivateKey', private_key)}",
            f"Address = {field('Address', session.virtual_ip)}",
            "",
            "[Peer]",
            f"PublicKey = {field('PublicKey', session.server_public_key)}",
            f"AllowedIPs = {field('AllowedIPs', session.allowed_ips)}",
            f"Endpoint = {field('Endpoint', session.endpoint)}",
            "PersistentKeepalive = 20",
        ]
        if session.preshared_key:
            lines.append(f"PresharedKey = {field('PresharedKey', session.preshared_key)}")
        return "\n".join(lines)

    def _start_idle_monitor(self) -> None:
        """Start a daemon thread that stops the session once ``idle_deadline`` passes.

        Each monitor gets its own Event, and any previous monitor's event is
        set first. Resetting a shared event could otherwise keep a stale
        monitor running alongside the new one.
        """
        self._stop_event.set()
        stop_event = threading.Event()
        self._stop_event = stop_event

        def _loop() -> None:
            # wait() returns True as soon as the event is set, so stop is prompt.
            while not stop_event.wait(5):
                deadline = self.idle_deadline
                if deadline and time.time() >= deadline:
                    _write_log("Idle timeout reached; stopping WireGuard session.")
                    self.stop_session(reason="idle_timeout")
                    return

        t = threading.Thread(target=_loop, daemon=True)
        t.start()
        self._idle_thread = t

    def start_session(self, session: SessionConfig, *, signing_client: Optional[Any] = None) -> None:
        """Validate *session* and bring the tunnel up with its config.

        Refuses, logging the reason, when:
        - a session is already active;
        - the token fails validation;
        - a config value is unsafe;
        - the config cannot be written;
        - the service cannot be installed.

        Otherwise the session is recorded and an idle timer of
        ``max(60, session.idle_seconds)`` seconds is armed.
        """
        with self._session_lock:
            if self.session:
                _write_log("Rejecting start_session: existing session already active.")
                return
            try:
                self._validate_token(session.token, signing_client=signing_client)
            except Exception as exc:
                _write_log(f"Refusing to start WireGuard session: {exc}")
                return
            try:
                rendered = self._render_config(session)
            except ValueError as exc:
                _write_log(f"Refusing to start WireGuard session: {exc}")
                return
            if not self._write_config(rendered):
                _write_log("Failed to write WireGuard client config.")
                return
            _write_log(f"Rendered WireGuard client config to {self.conf_path}")
            # The installed service may read a different file; keep it in sync.
            service_config_path = self._service_config_path()
            if service_config_path and service_config_path != self.conf_path:
                if self._write_config_to(service_config_path, rendered):
                    _write_log(f"Rendered WireGuard client config to service path {service_config_path}")
            if not self._service_exists():
                if not self._install_service():
                    return
                service_present = self._service_exists()
                if not service_present and self._last_install_already_present:
                    _write_log("WireGuard tunnel service presence inferred from install response.")
                    service_present = True
                if not service_present:
                    _write_log("WireGuard tunnel service still missing after install attempt.")
                    return
            if not self._restart_service():
                # Keep the session so the idle/stop path restores the idle config.
                _write_log("WireGuard tunnel service restart did not succeed; session kept for teardown.")
            self._ensure_adapter_name()
            self.session = session
            self.idle_deadline = time.time() + max(60, session.idle_seconds)
            _write_log("WireGuard client session started; idle timer armed.")
            self._start_idle_monitor()

    def _clear_session_state(self) -> None:
        """Forget the active session and signal the idle monitor to exit."""
        self.session = None
        self.idle_deadline = None
        self._stop_event.set()

    def stop_session(self, reason: str = "stop", ignore_missing: bool = False) -> None:
        """Tear down the session by restoring the idle config and restarting the service.

        Session state is cleared even if the service is missing, the idle
        config cannot be written, or a step raises.

        Args:
            reason: Free-form reason recorded in the log.
            ignore_missing: Suppress the "service not found" and idle-write
                failure log lines (used by the startup preflight).
        """
        with self._session_lock:
            try:
                if not self._service_exists():
                    if not ignore_missing:
                        _write_log("WireGuard tunnel service not found when stopping session.")
                    return
                idle_config = self._render_idle_config()
                wrote_idle = self._write_config(idle_config)
                service_config_path = self._service_config_path()
                if service_config_path and service_config_path != self.conf_path:
                    wrote_idle = self._write_config_to(service_config_path, idle_config) or wrote_idle
                if wrote_idle:
                    self._restart_service()
                    self._ensure_adapter_name()
                    _write_log(f"WireGuard client session stopped (reason={reason}).")
                elif not ignore_missing:
                    _write_log("Failed to write idle WireGuard config.")
            finally:
                self._clear_session_state()

    def bump_activity(self) -> None:
        """Push the idle deadline out by the session's idle window, if a session is active."""
        session = self.session  # snapshot: stop_session may clear it concurrently
        if session and self.idle_deadline:
            self.idle_deadline = time.time() + max(60, session.idle_seconds)
            _write_log("WireGuard client activity bump; idle timer reset.")
# Module-level client shared by every Role instance (Role.__init__ binds it).
# Constructing it at import time has side effects: it loads or creates the
# client keys, resolves the config path and tries to install the idle tunnel
# service (see WireGuardClient.__init__).
client = WireGuardClient()
def _parse_allowed_ips(value: Any, fallback: Optional[str]) -> Optional[str]:
if isinstance(value, list):
if not value:
return fallback
return str(value[0])
if isinstance(value, str) and value.strip():
return value.strip()
return fallback
def _coerce_int(value: Any, default: int) -> int:
try:
return int(value)
except Exception:
return default
class Role:
    """Agent role that drives the shared WireGuard client from Socket.IO events.

    Handles ``vpn_tunnel_start``, ``vpn_tunnel_stop`` and
    ``vpn_tunnel_activity``. Start payloads are validated here, at the input
    boundary, before anything reaches the WireGuard config file.
    """

    def __init__(self, ctx) -> None:
        self.ctx = ctx
        self.client = client
        hooks = getattr(ctx, "hooks", {}) or {}
        self._log_hook = hooks.get("log_agent")
        self._http_client_factory = hooks.get("http_client")
        # Clear any session left over from a previous run (service may be absent).
        try:
            self.client.stop_session(reason="agent_startup", ignore_missing=True)
        except Exception:
            self._log("Failed to preflight WireGuard session cleanup.", error=True)

    def _log(self, message: str, *, error: bool = False) -> None:
        """Log through the agent hook when available (plus the error log), then locally."""
        if callable(self._log_hook):
            try:
                self._log_hook(message, fname="VPN_Tunnel/tunnel.log")
                if error:
                    self._log_hook(message, fname="agent.error.log")
            except Exception:
                pass
        _write_log(message)

    def _http_client(self) -> Optional[Any]:
        """Build the signing/HTTP client from the hook, or None if unavailable."""
        try:
            if callable(self._http_client_factory):
                return self._http_client_factory()
        except Exception:
            return None
        return None

    def _targets_other_agent(self, payload: Any, *keys: str) -> bool:
        """True when *payload* names (via the first truthy key) an agent other than ours."""
        if not isinstance(payload, dict):
            return False
        for key in keys:
            target = payload.get(key)
            if target:
                return str(target).strip() != str(self.ctx.agent_id).strip()
        return False

    @staticmethod
    def _is_host_route(value: Any) -> bool:
        """True when *value* is exactly one IPv4 host network written as ``a.b.c.d/32``.

        Parsed with ``ipaddress``, so embedded newlines, commas, extra
        routes and IPv6 ``/32`` networks are all rejected.
        """
        if not isinstance(value, str) or "/" not in value:
            return False
        try:
            network = ipaddress.ip_network(value, strict=True)
        except ValueError:
            return False
        return network.version == 4 and network.prefixlen == 32

    @staticmethod
    def _is_wg_key(value: Any) -> bool:
        """True when *value* is strict base64 text encoding exactly 32 bytes."""
        if not isinstance(value, str):
            return False
        try:
            return len(base64.b64decode(value, validate=True)) == 32
        except ValueError:
            return False

    @staticmethod
    def _is_endpoint(value: Any) -> bool:
        """True for a non-empty single token (e.g. ``host:port``) with no whitespace/control chars."""
        text = str(value)
        return bool(text) and text.isprintable() and not any(ch.isspace() for ch in text)

    def _build_session(self, payload: Any) -> Optional[SessionConfig]:
        """Turn a ``vpn_tunnel_start`` payload into a validated SessionConfig.

        Returns None, and logs why, when:
        - the payload is not a dict;
        - the token dict is missing;
        - an address, endpoint or key is missing or malformed.

        Requests aimed at another agent also return None, silently. Every
        value that is written into the WireGuard config is validated here, so
        a payload cannot inject extra config lines or widen routing beyond a
        single IPv4 /32.
        """
        if not isinstance(payload, dict):
            self._log("WireGuard start payload missing/invalid.", error=True)
            return None
        if self._targets_other_agent(payload, "agent_id", "agent_guid"):
            return None
        token = payload.get("token") or payload.get("orchestration_token")
        if not isinstance(token, dict):
            self._log("WireGuard start missing token payload.", error=True)
            return None
        virtual_ip = payload.get("virtual_ip") or payload.get("client_virtual_ip")
        endpoint = payload.get("endpoint") or payload.get("server_endpoint")
        server_public_key = payload.get("server_public_key") or payload.get("public_key")
        engine_virtual_ip = payload.get("engine_virtual_ip") or payload.get("engine_ip")
        allowed_ips = _parse_allowed_ips(payload.get("allowed_ips"), engine_virtual_ip)
        if not allowed_ips:
            self._log("WireGuard start missing allowed_ips/engine_virtual_ip.", error=True)
            return None
        allowed_ips = str(allowed_ips).strip()
        if not self._is_host_route(allowed_ips):
            self._log("WireGuard allowed_ips must be a single /32.", error=True)
            return None
        if not virtual_ip or not endpoint or not server_public_key:
            self._log("WireGuard start missing required fields.", error=True)
            return None
        virtual_ip = str(virtual_ip).strip()
        if not self._is_host_route(virtual_ip):
            self._log("WireGuard virtual_ip must be /32.", error=True)
            return None
        endpoint = str(endpoint).strip()
        if not self._is_endpoint(endpoint):
            self._log("WireGuard endpoint is malformed.", error=True)
            return None
        server_public_key = str(server_public_key).strip()
        if not self._is_wg_key(server_public_key):
            self._log("WireGuard server_public_key is not a valid WireGuard key.", error=True)
            return None
        # Optional key material: blank -> None (the client then uses defaults),
        # anything else must be a well-formed key. Key values are never logged.
        optional_keys: Dict[str, Optional[str]] = {}
        for name in ("preshared_key", "client_private_key", "client_public_key"):
            raw = payload.get(name)
            text = str(raw).strip() if raw else ""
            if text and not self._is_wg_key(text):
                self._log(f"WireGuard {name} is not a valid WireGuard key.", error=True)
                return None
            optional_keys[name] = text or None
        idle_seconds = _coerce_int(payload.get("idle_seconds"), 900)
        allowed_ports = payload.get("allowed_ports")
        if isinstance(allowed_ports, list):
            allowed_ports = ",".join(str(p) for p in allowed_ports)
        return SessionConfig(
            token=token,
            virtual_ip=virtual_ip,
            allowed_ips=allowed_ips,
            endpoint=endpoint,
            server_public_key=server_public_key,
            allowed_ports=str(allowed_ports or ""),
            idle_seconds=idle_seconds,
            preshared_key=optional_keys["preshared_key"],
            client_private_key=optional_keys["client_private_key"],
            client_public_key=optional_keys["client_public_key"],
        )

    def register_events(self) -> None:
        """Bind the VPN tunnel Socket.IO handlers on ``ctx.sio``."""
        sio = self.ctx.sio

        async def _run_blocking(func, *args, **kwargs):
            # Session start/stop shells out to sc.exe/wireguard.exe and sleeps.
            # Run it in a worker thread so the event loop (and Socket.IO
            # heartbeats) keep running. The client serializes calls with a lock.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))

        @sio.on("vpn_tunnel_start")
        async def _vpn_tunnel_start(payload):
            session = self._build_session(payload)
            if not session:
                return
            self._log("WireGuard start request received.")
            await _run_blocking(self.client.start_session, session, signing_client=self._http_client())

        @sio.on("vpn_tunnel_stop")
        async def _vpn_tunnel_stop(payload):
            if self._targets_other_agent(payload, "agent_id"):
                return
            reason = "server_stop"
            if isinstance(payload, dict):
                reason = payload.get("reason") or reason
            self._log(f"WireGuard stop requested (reason={reason}).")
            await _run_blocking(self.client.stop_session, reason=str(reason))

        @sio.on("vpn_tunnel_activity")
        async def _vpn_tunnel_activity(payload):
            # Ignore activity pings aimed at another agent, like the stop handler does.
            if self._targets_other_agent(payload, "agent_id"):
                return
            self.client.bump_activity()

    def stop_all(self) -> None:
        """Stop the WireGuard session on agent shutdown, logging any failure."""
        try:
            self.client.stop_session(reason="agent_shutdown")
        except Exception:
            self._log("Failed to stop WireGuard client during shutdown.", error=True)