# ====================================================== # Data\Engine\services\VPN\vpn_tunnel_service.py # Description: WireGuard tunnel orchestration (single tunnel per agent, token issuance, idle handling). # # API Endpoints (if applicable): None # ====================================================== """WireGuard tunnel orchestration helpers for the Engine runtime.""" from __future__ import annotations import base64 import ipaddress import json import threading import time import uuid from dataclasses import dataclass, field from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple from .wireguard_server import WireGuardServerManager @dataclass class VpnSession: tunnel_id: str agent_id: str virtual_ip: str token: Dict[str, Any] client_public_key: str client_private_key: str allowed_ports: Tuple[int, ...] created_at: float expires_at: float last_activity: float operator_ids: set[str] = field(default_factory=set) firewall_rules: List[str] = field(default_factory=list) activity_id: Optional[int] = None hostname: Optional[str] = None endpoint_host: Optional[str] = None class VpnTunnelService: def __init__( self, *, context: Any, wireguard_manager: WireGuardServerManager, db_conn_factory, socketio, service_log, signer: Optional[Any] = None, idle_seconds: int = 900, ) -> None: self.context = context self.wg = wireguard_manager self.db_conn_factory = db_conn_factory self.socketio = socketio self.service_log = service_log self.signer = signer self.logger = context.logger.getChild("vpn_tunnel") self.activity_logger = self.wg.logger.getChild("device_activity") self.idle_seconds = max(60, int(idle_seconds)) self._lock = threading.Lock() self._sessions_by_agent: Dict[str, VpnSession] = {} self._sessions_by_tunnel: Dict[str, VpnSession] = {} self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip) self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False) self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True) self._idle_thread.start() def _idle_loop(self) -> None: while True: time.sleep(10) now = time.time() expired: List[VpnSession] = [] with self._lock: for session in list(self._sessions_by_agent.values()): if session.last_activity + self.idle_seconds <= now: expired.append(session) for session in expired: self.disconnect(session.agent_id, reason="idle_timeout") def _allocate_virtual_ip(self, agent_id: str) -> str: existing = self._sessions_by_agent.get(agent_id) if existing: return existing.virtual_ip used = {s.virtual_ip for s in self._sessions_by_agent.values()} for host in self._peer_network.hosts(): if host == self._engine_ip.ip: continue candidate = f"{host}/32" if candidate not in used: return candidate raise RuntimeError("vpn_ip_pool_exhausted") def _load_allowed_ports(self, agent_id: str) -> Tuple[int, ...]: default = tuple(self.context.wireguard_acl_allowlist_windows or ()) try: conn = self.db_conn_factory() cur = conn.cursor() try: cur.execute( "SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1", (agent_id,), ) row = cur.fetchone() os_name = str(row[0]).lower() if row and row[0] else "" except Exception: os_name = "" if os_name and "windows" not in os_name: baseline = {5900, 3478} filtered = [p for p in default if p in baseline] if filtered: default = tuple(filtered) cur.execute( "SELECT allowed_ports FROM device_vpn_config WHERE agent_id=?", (agent_id,), ) row = cur.fetchone() if not row: return default raw = row[0] or "" ports = json.loads(raw) if raw else [] ports = [int(p) for p in ports if isinstance(p, (int, float, str))] ports = [p for p in ports if 1 <= p <= 65535] return tuple(dict.fromkeys(ports)) or default except Exception: return default finally: try: conn.close() except Exception: pass def _generate_client_keys(self) -> Tuple[str, str]: from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import x25519 key = x25519.X25519PrivateKey.generate() priv = base64.b64encode( key.private_bytes( encoding=serialization.Encoding.Raw, format=serialization.PrivateFormat.Raw, encryption_algorithm=serialization.NoEncryption(), ) ).decode("ascii").strip() pub = base64.b64encode( key.public_key().public_bytes( encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw, ) ).decode("ascii").strip() return priv, pub def _issue_token(self, agent_id: str, tunnel_id: str, expires_at: float) -> Dict[str, Any]: payload = { "agent_id": agent_id, "tunnel_id": tunnel_id, "port": self.context.wireguard_port, "expires_at": expires_at, "issued_at": time.time(), } if not self.signer: return dict(payload) token = dict(payload) try: payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") signature = self.signer.sign(payload_bytes) token["signature"] = base64.b64encode(signature).decode("ascii") if hasattr(self.signer, "public_base64_spki"): token["signing_key"] = self.signer.public_base64_spki() token["sig_alg"] = "ed25519" except Exception: self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True) return token def _ensure_token(self, session: VpnSession, *, now: Optional[float] = None) -> None: if not session: return current = now if now is not None else time.time() if session.expires_at > current + 30: return session.expires_at = current + 300 session.token = self._issue_token(session.agent_id, session.tunnel_id, session.expires_at) def _normalize_endpoint_host(self, host: Optional[str]) -> Optional[str]: if not host: return None try: text = str(host).strip() except Exception: return None return text or None def _format_endpoint_host(self, host: str) -> str: if ":" in host and not host.startswith("["): return f"[{host}]" return host def _service_log_event(self, message: str, *, level: str = "INFO") -> None: if not callable(self.service_log): return try: self.service_log("reverse_tunnel", message, level=level) except Exception: self.logger.debug("Failed to write reverse_tunnel service log entry", exc_info=True) def _refresh_listener(self) -> None: peers: List[Mapping[str, object]] = [] for session in self._sessions_by_agent.values(): peer = self.wg.build_peer_profile( session.agent_id, session.virtual_ip, allowed_ports=session.allowed_ports, ) peer = dict(peer) peer["public_key"] = session.client_public_key peers.append(peer) if not peers: self.wg.stop_listener() return self.wg.start_listener(peers) def connect( self, *, agent_id: str, operator_id: Optional[str], endpoint_host: Optional[str] = None, ) -> Mapping[str, Any]: now = time.time() normalized_host = self._normalize_endpoint_host(endpoint_host) with self._lock: existing = self._sessions_by_agent.get(agent_id) if existing: if operator_id: existing.operator_ids.add(operator_id) if normalized_host and not existing.endpoint_host: existing.endpoint_host = normalized_host existing.last_activity = now self._ensure_token(existing, now=now) return self._session_payload(existing) tunnel_id = uuid.uuid4().hex virtual_ip = self._allocate_virtual_ip(agent_id) allowed_ports = self._load_allowed_ports(agent_id) client_private, client_public = self._generate_client_keys() token = self._issue_token(agent_id, tunnel_id, now + 300) self.wg.require_orchestration_token(token) session = VpnSession( tunnel_id=tunnel_id, agent_id=agent_id, virtual_ip=virtual_ip, token=token, client_public_key=client_public, client_private_key=client_private, allowed_ports=allowed_ports, created_at=now, expires_at=now + 300, last_activity=now, endpoint_host=normalized_host, ) if operator_id: session.operator_ids.add(operator_id) self._sessions_by_agent[agent_id] = session self._sessions_by_tunnel[tunnel_id] = session try: self._refresh_listener() peer = self.wg.build_peer_profile( agent_id, virtual_ip, allowed_ports=allowed_ports, ) rule_names = self.wg.apply_firewall_rules(peer) session.firewall_rules = rule_names except Exception: with self._lock: self._sessions_by_agent.pop(agent_id, None) self._sessions_by_tunnel.pop(tunnel_id, None) try: self._refresh_listener() except Exception: self.logger.debug("Failed to refresh WireGuard listener after connect rollback.", exc_info=True) raise payload = self._session_payload(session) operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-" self._service_log_event( "vpn_tunnel_start agent_id={0} tunnel_id={1} virtual_ip={2} endpoint={3} allowed_ports={4} operators={5}".format( session.agent_id, session.tunnel_id, session.virtual_ip, payload.get("endpoint", ""), ",".join(str(p) for p in session.allowed_ports), operator_text, ) ) self._emit_start(payload) self._log_device_activity(session, event="start") return payload def status(self, agent_id: str) -> Optional[Mapping[str, Any]]: with self._lock: session = self._sessions_by_agent.get(agent_id) if not session: return None return self._session_payload(session, include_token=False) def session_payload(self, agent_id: str, *, include_token: bool = True) -> Optional[Mapping[str, Any]]: with self._lock: session = self._sessions_by_agent.get(agent_id) if not session: return None if include_token: self._ensure_token(session) return self._session_payload(session, include_token=include_token) def request_agent_start(self, agent_id: str) -> Optional[Mapping[str, Any]]: payload = self.session_payload(agent_id, include_token=True) if not payload: return None self._emit_start(payload) return payload def bump_activity(self, agent_id: str) -> None: with self._lock: session = self._sessions_by_agent.get(agent_id) if not session: return session.last_activity = time.time() try: if self.socketio: self.socketio.emit("vpn_tunnel_activity", {"agent_id": agent_id}, namespace="/") except Exception: self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True) def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool: with self._lock: session = self._sessions_by_agent.pop(agent_id, None) if not session: return False self._sessions_by_tunnel.pop(session.tunnel_id, None) try: self.wg.remove_firewall_rules(session.firewall_rules) except Exception: self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True) self._refresh_listener() operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-" self._service_log_event( "vpn_tunnel_stop agent_id={0} tunnel_id={1} reason={2} operators={3}".format( session.agent_id, session.tunnel_id, reason, operator_text, ) ) self._emit_stop(session, reason) self._log_device_activity(session, event="stop", reason=reason) return True def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool: with self._lock: session = self._sessions_by_tunnel.get(tunnel_id) if not session: return False return self.disconnect(session.agent_id, reason=reason) def _emit_start(self, payload: Mapping[str, Any]) -> None: if not self.socketio: return agent_id = None if isinstance(payload, Mapping): agent_id = payload.get("agent_id") emit_agent = getattr(self.context, "emit_agent_event", None) if agent_id and callable(emit_agent): try: if emit_agent(agent_id, "vpn_tunnel_start", payload): return except Exception: self.logger.debug("emit_agent_event failed for vpn_tunnel_start", exc_info=True) try: self.socketio.emit("vpn_tunnel_start", payload, namespace="/") except Exception: self.logger.debug("vpn_tunnel_start emit failed", exc_info=True) def _emit_stop(self, session: VpnSession, reason: str) -> None: if not self.socketio: return emit_agent = getattr(self.context, "emit_agent_event", None) if callable(emit_agent): try: if emit_agent( session.agent_id, "vpn_tunnel_stop", {"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason}, ): return except Exception: self.logger.debug("emit_agent_event failed for vpn_tunnel_stop", exc_info=True) try: self.socketio.emit( "vpn_tunnel_stop", {"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason}, namespace="/", ) except Exception: self.logger.debug("vpn_tunnel_stop emit failed", exc_info=True) def _log_device_activity(self, session: VpnSession, *, event: str, reason: Optional[str] = None) -> None: if self.db_conn_factory is None: self.activity_logger.info( "device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s", event, session.agent_id, session.tunnel_id, ",".join(sorted(filter(None, session.operator_ids))) or "-", reason or "-", ) return conn = None try: conn = self.db_conn_factory() cur = conn.cursor() hostname = session.hostname if not hostname: try: cur.execute( "SELECT hostname FROM devices WHERE agent_id = ? ORDER BY last_seen DESC LIMIT 1", (session.agent_id,), ) row = cur.fetchone() if row and row[0]: hostname = str(row[0]).strip() session.hostname = hostname except Exception: hostname = None if not hostname: self.activity_logger.info( "device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s hostname=unknown", event, session.agent_id, session.tunnel_id, ",".join(sorted(filter(None, session.operator_ids))) or "-", reason or "-", ) return now_ts = int(time.time()) script_name = "Reverse VPN Tunnel (WireGuard)" if event == "start": cur.execute( """ INSERT INTO activity_history(hostname, script_path, script_name, script_type, ran_at, status, stdout, stderr) VALUES(?,?,?,?,?,?,?,?) """, ( hostname, session.tunnel_id, script_name, "vpn_tunnel", now_ts, "Running", "", "", ), ) session.activity_id = cur.lastrowid conn.commit() if self.socketio: try: self.socketio.emit( "device_activity_changed", { "hostname": hostname, "activity_id": session.activity_id, "change": "created", "source": "vpn_tunnel", }, ) except Exception: pass self.activity_logger.info( "device_activity_start hostname=%s agent_id=%s tunnel_id=%s operator=%s activity_id=%s", hostname, session.agent_id, session.tunnel_id, ",".join(sorted(filter(None, session.operator_ids))) or "-", session.activity_id or "-", ) return if session.activity_id: status = "Completed" if event == "stop" else "Closed" cur.execute( """ UPDATE activity_history SET status=?, stderr=COALESCE(stderr, '') || ? WHERE id=? """, ( status, f"\nreason: {reason}" if reason else "", session.activity_id, ), ) conn.commit() if self.socketio: try: self.socketio.emit( "device_activity_changed", { "hostname": hostname, "activity_id": session.activity_id, "change": "updated", "source": "vpn_tunnel", }, ) except Exception: pass self.activity_logger.info( "device_activity event=%s hostname=%s agent_id=%s tunnel_id=%s operator=%s reason=%s activity_id=%s", event, hostname, session.agent_id, session.tunnel_id, ",".join(sorted(filter(None, session.operator_ids))) or "-", reason or "-", session.activity_id or "-", ) except Exception: self.activity_logger.debug( "device_activity logging failed for tunnel_id=%s", session.tunnel_id, exc_info=True, ) finally: if conn is not None: try: conn.close() except Exception: pass def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]: endpoint_host = session.endpoint_host or str(self._engine_ip.ip) endpoint_host = self._format_endpoint_host(endpoint_host) payload: Dict[str, Any] = { "tunnel_id": session.tunnel_id, "agent_id": session.agent_id, "virtual_ip": session.virtual_ip, "engine_virtual_ip": str(self._engine_ip.ip), "allowed_ips": f"{self._engine_ip.ip}/32", "endpoint": f"{endpoint_host}:{self.context.wireguard_port}", "server_public_key": self.wg.server_public_key, "client_public_key": session.client_public_key, "client_private_key": session.client_private_key, "idle_seconds": self.idle_seconds, "allowed_ports": list(session.allowed_ports), "connected_operators": len([o for o in session.operator_ids if o]), } if include_token: payload["token"] = session.token return payload