Files
Borealis-Github-Replica/Data/Engine/services/VPN/vpn_tunnel_service.py
2026-01-11 20:53:09 -07:00

735 lines
29 KiB
Python

# ======================================================
# Data\Engine\services\VPN\vpn_tunnel_service.py
# Description: WireGuard tunnel orchestration (single tunnel per agent, token issuance, idle handling).
#
# API Endpoints (if applicable): None
# ======================================================
"""WireGuard tunnel orchestration helpers for the Engine runtime."""
from __future__ import annotations
import base64
import ipaddress
import json
import threading
import time
import uuid
from datetime import datetime, timezone
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from .wireguard_server import WireGuardServerManager
@dataclass
class VpnSession:
    """In-memory record of one agent's WireGuard tunnel.

    Field order is part of the constructor interface and must not change.
    ``token`` and ``client_private_key`` are excluded from the generated
    ``__repr__`` so that formatting a session (logs, tracebacks, debuggers)
    never prints credential material; they remain ordinary attributes and
    still take part in equality comparison.
    """

    tunnel_id: str  # hex UUID identifying this tunnel
    agent_id: str  # agent that owns the tunnel (one tunnel per agent)
    virtual_ip: str  # peer address in "<ip>/32" form
    token: Dict[str, Any] = field(repr=False)  # orchestration token (possibly signed)
    client_public_key: str  # base64 X25519 public key handed to the listener
    client_private_key: str = field(repr=False)  # base64 X25519 private key sent to the agent
    allowed_ports: Tuple[int, ...]  # ports the peer may reach through the tunnel
    created_at: float  # epoch seconds
    expires_at: float  # epoch seconds at which the current token expires
    last_activity: float  # epoch seconds of the most recent activity bump
    operator_ids: set[str] = field(default_factory=set)  # operators attached to the session
    firewall_rules: List[str] = field(default_factory=list)  # rule names to remove on teardown
    activity_id: Optional[int] = None  # activity_history row id once the start was recorded
    hostname: Optional[str] = None  # cached device hostname for activity logging
    endpoint_host: Optional[str] = None  # operator-supplied endpoint host override
class VpnTunnelService:
    """Orchestrate WireGuard tunnels for the Engine: one session per agent.

    The service keeps an in-memory registry of sessions (indexed by agent id
    and by tunnel id) guarded by ``self._lock``, issues short-lived
    orchestration tokens, rebuilds the WireGuard listener whenever the peer
    set changes, and tears down sessions that stay idle for ``idle_seconds``.
    """

    def __init__(
        self,
        *,
        context: Any,
        wireguard_manager: WireGuardServerManager,
        db_conn_factory,
        socketio,
        service_log,
        signer: Optional[Any] = None,
        idle_seconds: int = 900,
    ) -> None:
        """Wire up collaborators, reset the listener and start the idle sweeper.

        Args:
            context: Engine context; this class reads ``logger``,
                ``wireguard_engine_virtual_ip``, ``wireguard_peer_network``,
                ``wireguard_port``, ``wireguard_acl_allowlist_windows`` and,
                optionally, ``emit_agent_event`` from it.
            wireguard_manager: Listener / firewall backend.
            db_conn_factory: Zero-argument callable returning a DB connection,
                or ``None`` to disable activity-history persistence.
            socketio: Socket.IO server used for event emission (may be falsy).
            service_log: Callable ``(scope, message, level=...)`` or a
                non-callable value to disable service logging.
            signer: Optional object exposing ``sign(bytes)``.
            idle_seconds: Idle timeout; values below 60 are raised to 60.
        """
        self.context = context
        self.wg = wireguard_manager
        self.db_conn_factory = db_conn_factory
        self.socketio = socketio
        self.service_log = service_log
        self.signer = signer
        self.logger = context.logger.getChild("vpn_tunnel")
        self.activity_logger = self.wg.logger.getChild("device_activity")
        self.idle_seconds = max(60, int(idle_seconds))
        # Plain (non-reentrant) lock: never call a method that takes it while holding it.
        self._lock = threading.Lock()
        self._sessions_by_agent: Dict[str, VpnSession] = {}
        self._sessions_by_tunnel: Dict[str, VpnSession] = {}
        self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
        self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
        self._cleanup_listener()
        self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
        self._idle_thread.start()

    # ------------------------------------------------------------------
    # Idle handling
    # ------------------------------------------------------------------
    def _idle_loop(self) -> None:
        """Background loop: every 10 seconds, disconnect sessions that went idle.

        Any exception raised by a sweep is logged and swallowed so that the
        daemon thread keeps running; otherwise a single failure would silently
        disable idle timeouts for the lifetime of the process.
        """
        while True:
            time.sleep(10)
            try:
                self._expire_idle_sessions(time.time())
            except Exception:
                self.logger.debug("VPN idle sweep failed; retrying on next tick.", exc_info=True)

    def _expire_idle_sessions(self, now: float) -> int:
        """Disconnect every session idle at ``now``; return how many were closed cleanly.

        A failure while disconnecting one session is logged and does not stop
        the remaining sessions from being processed.
        """
        with self._lock:
            expired = [
                session
                for session in self._sessions_by_agent.values()
                if session.last_activity + self.idle_seconds <= now
            ]
        closed = 0
        for session in expired:
            self._service_log_event(
                "vpn_tunnel_idle_timeout agent_id={0} tunnel_id={1} last_activity={2} last_activity_iso={3} idle_seconds={4}".format(
                    session.agent_id,
                    session.tunnel_id,
                    int(session.last_activity),
                    self._ts_to_iso(session.last_activity),
                    self.idle_seconds,
                )
            )
            try:
                if self.disconnect(session.agent_id, reason="idle_timeout"):
                    closed += 1
            except Exception:
                self.logger.debug(
                    "Idle disconnect failed for agent=%s", session.agent_id, exc_info=True
                )
                self._service_log_event(
                    "vpn_tunnel_idle_timeout_failed agent_id={0} tunnel_id={1}".format(
                        session.agent_id,
                        session.tunnel_id,
                    ),
                    level="WARNING",
                )
        return closed

    # ------------------------------------------------------------------
    # Session provisioning helpers
    # ------------------------------------------------------------------
    def _allocate_virtual_ip(self, agent_id: str) -> str:
        """Return the agent's existing "<ip>/32", or the first free host in the peer network.

        The engine's own address is never handed out.

        Raises:
            RuntimeError: ``vpn_ip_pool_exhausted`` when no host address is free.
        """
        existing = self._sessions_by_agent.get(agent_id)
        if existing:
            return existing.virtual_ip
        used = {s.virtual_ip for s in self._sessions_by_agent.values()}
        for host in self._peer_network.hosts():
            if host == self._engine_ip.ip:
                continue
            candidate = f"{host}/32"
            if candidate not in used:
                return candidate
        raise RuntimeError("vpn_ip_pool_exhausted")

    @staticmethod
    def _coerce_port(value: Any) -> Optional[int]:
        """Return ``value`` as a port in 1..65535, or ``None`` if it is not one.

        Booleans are rejected explicitly (``True`` would otherwise become
        port 1); floats are truncated like ``int()`` does.
        """
        if isinstance(value, bool) or not isinstance(value, (int, float, str)):
            return None
        try:
            port = int(value)
        except (TypeError, ValueError, OverflowError):
            return None
        return port if 1 <= port <= 65535 else None

    def _load_allowed_ports(self, agent_id: str) -> Tuple[int, ...]:
        """Resolve the port allowlist for ``agent_id``.

        The default is ``context.wireguard_acl_allowlist_windows``; for a
        device whose recorded operating system is not Windows the default is
        narrowed to the entries that are in {5900, 3478} (if any remain).
        A JSON list stored in ``device_vpn_config.allowed_ports`` overrides
        the default; invalid entries are skipped individually, duplicates are
        removed preserving order, and an empty result falls back to the default.
        Any database or JSON error yields the default.
        """
        default = tuple(self.context.wireguard_acl_allowlist_windows or ())
        conn = None  # bound before ``try`` so ``finally`` never sees an undefined name
        try:
            conn = self.db_conn_factory()
            cur = conn.cursor()
            try:
                cur.execute(
                    "SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
                    (agent_id,),
                )
                row = cur.fetchone()
                os_name = str(row[0]).lower() if row and row[0] else ""
            except Exception:
                # OS lookup is best-effort; treat the device as "unknown OS".
                os_name = ""
            if os_name and "windows" not in os_name:
                baseline = {5900, 3478}
                filtered = tuple(p for p in default if p in baseline)
                if filtered:
                    default = filtered
            cur.execute(
                "SELECT allowed_ports FROM device_vpn_config WHERE agent_id=?",
                (agent_id,),
            )
            row = cur.fetchone()
            if not row:
                return default
            raw = row[0] or ""
            parsed = json.loads(raw) if raw else []
            if not isinstance(parsed, (list, tuple)):
                # A bare string/number/object is not a port list; do not iterate it.
                return default
            ports = [
                port
                for port in (self._coerce_port(entry) for entry in parsed)
                if port is not None
            ]
            return tuple(dict.fromkeys(ports)) or default
        except Exception:
            self.logger.debug(
                "Failed to load allowed ports for agent=%s; using defaults.", agent_id, exc_info=True
            )
            return default
        finally:
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    self.logger.debug("Failed to close DB connection.", exc_info=True)

    def _generate_client_keys(self) -> Tuple[str, str]:
        """Generate an X25519 key pair; return ``(private, public)`` as base64 text."""
        # Imported lazily so the module can be loaded without `cryptography` installed.
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric import x25519

        key = x25519.X25519PrivateKey.generate()
        priv = base64.b64encode(
            key.private_bytes(
                encoding=serialization.Encoding.Raw,
                format=serialization.PrivateFormat.Raw,
                encryption_algorithm=serialization.NoEncryption(),
            )
        ).decode("ascii").strip()
        pub = base64.b64encode(
            key.public_key().public_bytes(
                encoding=serialization.Encoding.Raw,
                format=serialization.PublicFormat.Raw,
            )
        ).decode("ascii").strip()
        return priv, pub

    def _issue_token(self, agent_id: str, tunnel_id: str, expires_at: float) -> Dict[str, Any]:
        """Build the orchestration token, signing it when a signer is configured.

        The signature covers the canonical JSON (sorted keys, compact
        separators) of the unsigned payload. If signing fails the token is
        returned unsigned (or partially annotated) and the failure is logged.
        """
        payload = {
            "agent_id": agent_id,
            "tunnel_id": tunnel_id,
            "port": self.context.wireguard_port,
            "expires_at": expires_at,
            "issued_at": time.time(),
        }
        if not self.signer:
            return dict(payload)
        token = dict(payload)
        try:
            payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
            signature = self.signer.sign(payload_bytes)
            token["signature"] = base64.b64encode(signature).decode("ascii")
            if hasattr(self.signer, "public_base64_spki"):
                token["signing_key"] = self.signer.public_base64_spki()
            token["sig_alg"] = "ed25519"
        except Exception:
            self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
        return token

    def _ensure_token(self, session: VpnSession, *, now: Optional[float] = None) -> None:
        """Re-issue the session token when it expires within the next 30 seconds.

        A refreshed token is valid for 300 seconds from ``now`` (default: current time).
        """
        if not session:
            return
        current = now if now is not None else time.time()
        if session.expires_at > current + 30:
            return
        session.expires_at = current + 300
        session.token = self._issue_token(session.agent_id, session.tunnel_id, session.expires_at)

    # ------------------------------------------------------------------
    # Small formatting / logging helpers
    # ------------------------------------------------------------------
    def _normalize_endpoint_host(self, host: Optional[str]) -> Optional[str]:
        """Return ``host`` stripped of surrounding whitespace, or ``None`` if empty."""
        if not host:
            return None
        try:
            text = str(host).strip()
        except Exception:
            return None
        return text or None

    def _format_endpoint_host(self, host: str) -> str:
        """Wrap a host containing ``:`` in brackets so a port can be appended."""
        if ":" in host and not host.startswith("["):
            return f"[{host}]"
        return host

    def _endpoint(self, session: VpnSession) -> str:
        """Return "<host>:<port>" for the session, defaulting to the engine IP."""
        host = session.endpoint_host or str(self._engine_ip.ip)
        return f"{self._format_endpoint_host(host)}:{self.context.wireguard_port}"

    @staticmethod
    def _format_operators(session: VpnSession) -> str:
        """Return the session's non-empty operator ids, sorted and comma-joined ("-" if none)."""
        return ",".join(sorted(filter(None, session.operator_ids))) or "-"

    def _ts_to_iso(self, ts: float) -> str:
        """Format an epoch timestamp as a UTC ISO-8601 string ("" on failure)."""
        try:
            return datetime.fromtimestamp(ts, timezone.utc).isoformat()
        except Exception:
            return ""

    def _service_log_event(self, message: str, *, level: str = "INFO") -> None:
        """Write ``message`` to the service log; never raises."""
        if not callable(self.service_log):
            return
        try:
            self.service_log("VPN_Tunnel/tunnel", message, level=level)
        except Exception:
            self.logger.debug("Failed to write vpn_tunnel service log entry", exc_info=True)

    # ------------------------------------------------------------------
    # Listener management
    # ------------------------------------------------------------------
    def _cleanup_listener(self) -> None:
        """Stop any listener left over from a previous run (best-effort)."""
        try:
            self.wg.stop_listener(ignore_missing=True)
            self._service_log_event("vpn_listener_cleanup reason=startup")
        except Exception:
            self.logger.debug("Failed to clean up WireGuard listener on startup.", exc_info=True)
            self._service_log_event("vpn_listener_cleanup_failed reason=startup", level="WARNING")

    def _refresh_listener(self) -> None:
        """Rebuild the listener from the current sessions, or stop it when there are none.

        Must be called WITHOUT ``self._lock`` held by the caller's critical
        section being relied upon: it does not take the lock itself, so it
        iterates a snapshot of the registry to stay safe against concurrent
        connects/disconnects resizing the dict mid-iteration.
        """
        peers: List[Mapping[str, object]] = []
        for session in list(self._sessions_by_agent.values()):
            peer = dict(
                self.wg.build_peer_profile(
                    session.agent_id,
                    session.virtual_ip,
                    allowed_ports=session.allowed_ports,
                )
            )
            peer["public_key"] = session.client_public_key
            peers.append(peer)
        if not peers:
            self._service_log_event("vpn_listener_stop reason=no_peers")
            self.wg.stop_listener()
            return
        agent_list = ",".join(str(peer.get("agent_id", "")) for peer in peers if peer.get("agent_id"))
        self._service_log_event("vpn_listener_start peers={0} agents={1}".format(len(peers), agent_list))
        self.wg.start_listener(peers)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def connect(
        self,
        *,
        agent_id: str,
        operator_id: Optional[str],
        endpoint_host: Optional[str] = None,
    ) -> Mapping[str, Any]:
        """Create the agent's tunnel, or attach the operator to the existing one.

        Returns:
            The session payload (including the token) that is also emitted to
            the agent for a newly created tunnel.

        Raises:
            RuntimeError: ``vpn_ip_pool_exhausted`` when no address is free.
            Exception: Whatever the WireGuard manager raises while validating
                the token, starting the listener or applying firewall rules;
                in the latter two cases the new session is rolled back first.
        """
        now = time.time()
        normalized_host = self._normalize_endpoint_host(endpoint_host)
        operator_text = operator_id or "-"
        self._service_log_event(
            "vpn_tunnel_connect_request agent_id={0} operator={1} endpoint_host={2}".format(
                agent_id or "-",
                operator_text,
                normalized_host or "-",
            )
        )
        with self._lock:
            existing = self._sessions_by_agent.get(agent_id)
            if existing:
                # Reuse path: one tunnel per agent, shared by all operators.
                if operator_id:
                    existing.operator_ids.add(operator_id)
                if normalized_host and not existing.endpoint_host:
                    existing.endpoint_host = normalized_host
                existing.last_activity = now
                previous_expiry = existing.expires_at
                self._ensure_token(existing, now=now)
                refreshed = existing.expires_at != previous_expiry
                self._service_log_event(
                    "vpn_tunnel_session_reuse agent_id={0} tunnel_id={1} operators={2} token_refreshed={3}".format(
                        existing.agent_id,
                        existing.tunnel_id,
                        self._format_operators(existing),
                        str(refreshed).lower(),
                    )
                )
                return self._session_payload(existing)
            tunnel_id = uuid.uuid4().hex
            virtual_ip = self._allocate_virtual_ip(agent_id)
            allowed_ports = self._load_allowed_ports(agent_id)
            client_private, client_public = self._generate_client_keys()
            token = self._issue_token(agent_id, tunnel_id, now + 300)
            self.wg.require_orchestration_token(token)
            token_signed = "signature" in token
            session = VpnSession(
                tunnel_id=tunnel_id,
                agent_id=agent_id,
                virtual_ip=virtual_ip,
                token=token,
                client_public_key=client_public,
                client_private_key=client_private,
                allowed_ports=allowed_ports,
                created_at=now,
                expires_at=now + 300,
                last_activity=now,
                endpoint_host=normalized_host,
            )
            if operator_id:
                session.operator_ids.add(operator_id)
            self._sessions_by_agent[agent_id] = session
            self._sessions_by_tunnel[tunnel_id] = session
        # Listener / firewall work happens outside the lock (it is not reentrant
        # and the rollback path below needs to take it again).
        try:
            self._service_log_event(
                "vpn_tunnel_session_create agent_id={0} tunnel_id={1} virtual_ip={2} allowed_ports={3} token_signed={4} token_expires={5}".format(
                    session.agent_id,
                    session.tunnel_id,
                    session.virtual_ip,
                    ",".join(str(p) for p in allowed_ports),
                    str(bool(token_signed)).lower(),
                    int(session.expires_at),
                )
            )
            self._refresh_listener()
            peer = self.wg.build_peer_profile(
                agent_id,
                virtual_ip,
                allowed_ports=allowed_ports,
            )
            rule_names = self.wg.apply_firewall_rules(peer)
            session.firewall_rules = rule_names
            self._service_log_event(
                "vpn_tunnel_firewall_applied agent_id={0} tunnel_id={1} rules={2}".format(
                    session.agent_id,
                    session.tunnel_id,
                    len(rule_names),
                )
            )
        except Exception:
            self._service_log_event(
                "vpn_tunnel_connect_failed agent_id={0} tunnel_id={1}".format(agent_id, tunnel_id),
                level="ERROR",
            )
            # Roll back the registration, then rebuild the listener without this peer.
            with self._lock:
                self._sessions_by_agent.pop(agent_id, None)
                self._sessions_by_tunnel.pop(tunnel_id, None)
            try:
                self._refresh_listener()
            except Exception:
                self.logger.debug("Failed to refresh WireGuard listener after connect rollback.", exc_info=True)
            raise
        payload = self._session_payload(session)
        self._service_log_event(
            "vpn_tunnel_start agent_id={0} tunnel_id={1} virtual_ip={2} endpoint={3} allowed_ports={4} operators={5}".format(
                session.agent_id,
                session.tunnel_id,
                session.virtual_ip,
                payload.get("endpoint", ""),
                ",".join(str(p) for p in session.allowed_ports),
                self._format_operators(session),
            )
        )
        self._emit_start(payload)
        self._log_device_activity(session, event="start")
        return payload

    def status(self, agent_id: str) -> Optional[Mapping[str, Any]]:
        """Return the agent's session payload without the token, or ``None``."""
        with self._lock:
            session = self._sessions_by_agent.get(agent_id)
            if not session:
                return None
            return self._session_payload(session, include_token=False)

    def list_sessions(self) -> List[Mapping[str, Any]]:
        """Return a summary of every active session, ordered by agent id."""
        with self._lock:
            sessions = sorted(self._sessions_by_agent.values(), key=lambda s: s.agent_id)
        return [self._session_summary(session) for session in sessions]

    def session_payload(self, agent_id: str, *, include_token: bool = True) -> Optional[Mapping[str, Any]]:
        """Return the agent's session payload, refreshing the token if it is requested."""
        with self._lock:
            session = self._sessions_by_agent.get(agent_id)
            if not session:
                return None
            if include_token:
                self._ensure_token(session)
            return self._session_payload(session, include_token=include_token)

    def request_agent_start(self, agent_id: str) -> Optional[Mapping[str, Any]]:
        """Re-emit ``vpn_tunnel_start`` for an existing session; ``None`` if there is none."""
        payload = self.session_payload(agent_id, include_token=True)
        if not payload:
            self._service_log_event("vpn_tunnel_agent_start_missing agent_id={0}".format(agent_id or "-"))
            return None
        self._service_log_event(
            "vpn_tunnel_agent_start_emit agent_id={0} tunnel_id={1}".format(
                payload.get("agent_id", "-"),
                payload.get("tunnel_id", "-"),
            )
        )
        self._emit_start(payload)
        return payload

    def bump_activity(self, agent_id: str) -> None:
        """Mark the agent's session as active now and notify listeners.

        A service-log line is written only when the session had been quiet
        for at least 60 seconds, to avoid flooding the log.
        """
        with self._lock:
            session = self._sessions_by_agent.get(agent_id)
            if not session:
                return
            now = time.time()
            previous = session.last_activity
            session.last_activity = now
        idle_for = now - previous
        if idle_for >= 60:
            self._service_log_event(
                "vpn_tunnel_activity_bump agent_id={0} tunnel_id={1} idle_for={2}".format(
                    session.agent_id,
                    session.tunnel_id,
                    int(idle_for),
                )
            )
        try:
            if self.socketio:
                self.socketio.emit("vpn_tunnel_activity", {"agent_id": agent_id}, namespace="/")
        except Exception:
            self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True)

    def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool:
        """Tear down the agent's tunnel.

        Returns:
            ``True`` if a session existed and was removed, ``False`` otherwise.

        Raises:
            Exception: Propagated from the listener refresh. The session is
                already unregistered at that point, so the stop log line, the
                stop event and the activity-history update are still performed
                before the exception leaves this method.
        """
        with self._lock:
            session = self._sessions_by_agent.pop(agent_id, None)
            if session:
                self._sessions_by_tunnel.pop(session.tunnel_id, None)
        if not session:
            self._service_log_event(
                "vpn_tunnel_disconnect_missing agent_id={0} reason={1}".format(agent_id or "-", reason or "-")
            )
            return False
        try:
            self.wg.remove_firewall_rules(session.firewall_rules)
        except Exception:
            self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True)
        try:
            self._refresh_listener()
        finally:
            # The session is gone from the registry regardless of whether the
            # listener rebuild succeeded, so always tell the agent and close
            # the activity record.
            self._service_log_event(
                "vpn_tunnel_stop agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
                    session.agent_id,
                    session.tunnel_id,
                    reason,
                    self._format_operators(session),
                )
            )
            self._emit_stop(session, reason)
            self._log_device_activity(session, event="stop", reason=reason)
        return True

    def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool:
        """Tear down the session identified by ``tunnel_id``; see :meth:`disconnect`."""
        with self._lock:
            session = self._sessions_by_tunnel.get(tunnel_id)
        if not session:
            self._service_log_event(
                "vpn_tunnel_disconnect_missing tunnel_id={0} reason={1}".format(tunnel_id or "-", reason or "-")
            )
            return False
        return self.disconnect(session.agent_id, reason=reason)

    # ------------------------------------------------------------------
    # Event emission
    # ------------------------------------------------------------------
    def _emit_start(self, payload: Mapping[str, Any]) -> None:
        """Send ``vpn_tunnel_start`` to the agent, falling back to a broadcast.

        ``context.emit_agent_event`` is tried first when available; if it is
        missing, returns a falsy value, or raises, the payload is emitted on
        namespace "/" instead.
        """
        if not self.socketio:
            return
        agent_id = None
        if isinstance(payload, Mapping):
            agent_id = payload.get("agent_id")
        emit_agent = getattr(self.context, "emit_agent_event", None)
        if agent_id and callable(emit_agent):
            try:
                if emit_agent(agent_id, "vpn_tunnel_start", payload):
                    self._service_log_event(
                        "vpn_tunnel_start_emit agent_id={0} transport=direct".format(agent_id or "-")
                    )
                    return
            except Exception:
                self.logger.debug("emit_agent_event failed for vpn_tunnel_start", exc_info=True)
                self._service_log_event(
                    "vpn_tunnel_start_emit_failed agent_id={0} transport=direct".format(agent_id or "-"),
                    level="WARNING",
                )
        # NOTE(review): the broadcast payload carries client_private_key; confirm
        # that every client on namespace "/" is trusted to receive it.
        try:
            self._service_log_event(
                "vpn_tunnel_start_emit agent_id={0} transport=broadcast".format(agent_id or "-")
            )
            self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
        except Exception:
            self.logger.debug("vpn_tunnel_start emit failed", exc_info=True)
            self._service_log_event(
                "vpn_tunnel_start_emit_failed agent_id={0} transport=broadcast".format(agent_id or "-"),
                level="WARNING",
            )

    def _emit_stop(self, session: VpnSession, reason: str) -> None:
        """Send ``vpn_tunnel_stop`` to the agent, falling back to a broadcast."""
        if not self.socketio:
            return
        stop_payload = {"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason}
        emit_agent = getattr(self.context, "emit_agent_event", None)
        if callable(emit_agent):
            try:
                if emit_agent(session.agent_id, "vpn_tunnel_stop", stop_payload):
                    self._service_log_event(
                        "vpn_tunnel_stop_emit agent_id={0} tunnel_id={1} transport=direct".format(
                            session.agent_id,
                            session.tunnel_id,
                        )
                    )
                    return
            except Exception:
                self.logger.debug("emit_agent_event failed for vpn_tunnel_stop", exc_info=True)
                self._service_log_event(
                    "vpn_tunnel_stop_emit_failed agent_id={0} tunnel_id={1} transport=direct".format(
                        session.agent_id,
                        session.tunnel_id,
                    ),
                    level="WARNING",
                )
        try:
            self._service_log_event(
                "vpn_tunnel_stop_emit agent_id={0} tunnel_id={1} transport=broadcast".format(
                    session.agent_id,
                    session.tunnel_id,
                )
            )
            # Fresh dict for the broadcast so a direct-emit callee that kept or
            # mutated its argument cannot affect what is sent here.
            self.socketio.emit(
                "vpn_tunnel_stop",
                {"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
                namespace="/",
            )
        except Exception:
            self.logger.debug("vpn_tunnel_stop emit failed", exc_info=True)
            self._service_log_event(
                "vpn_tunnel_stop_emit_failed agent_id={0} tunnel_id={1} transport=broadcast".format(
                    session.agent_id,
                    session.tunnel_id,
                ),
                level="WARNING",
            )

    def _emit_activity_changed(self, hostname: str, activity_id: Optional[int], change: str) -> None:
        """Best-effort ``device_activity_changed`` notification; never raises."""
        if not self.socketio:
            return
        try:
            self.socketio.emit(
                "device_activity_changed",
                {
                    "hostname": hostname,
                    "activity_id": activity_id,
                    "change": change,
                    "source": "vpn_tunnel",
                },
            )
        except Exception:
            self.activity_logger.debug("device_activity_changed emit failed", exc_info=True)

    # ------------------------------------------------------------------
    # Activity history
    # ------------------------------------------------------------------
    def _log_device_activity(self, session: VpnSession, *, event: str, reason: Optional[str] = None) -> None:
        """Record a tunnel start/stop in ``activity_history`` and the activity log.

        ``event == "start"`` inserts a "Running" row and stores its id on the
        session; any other event updates that row (status "Completed" for
        "stop", "Closed" otherwise) when an id is known. Without a database
        factory, or when the device hostname cannot be resolved, only a log
        line is written. All failures are logged at debug level and swallowed.
        """
        operators = self._format_operators(session)
        if self.db_conn_factory is None:
            self.activity_logger.info(
                "device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s",
                event,
                session.agent_id,
                session.tunnel_id,
                operators,
                reason or "-",
            )
            return
        conn = None
        try:
            conn = self.db_conn_factory()
            cur = conn.cursor()
            hostname = session.hostname
            if not hostname:
                try:
                    cur.execute(
                        "SELECT hostname FROM devices WHERE agent_id = ? ORDER BY last_seen DESC LIMIT 1",
                        (session.agent_id,),
                    )
                    row = cur.fetchone()
                    if row and row[0]:
                        hostname = str(row[0]).strip()
                        session.hostname = hostname  # cache for the stop event
                except Exception:
                    hostname = None
            if not hostname:
                self.activity_logger.info(
                    "device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s hostname=unknown",
                    event,
                    session.agent_id,
                    session.tunnel_id,
                    operators,
                    reason or "-",
                )
                return
            now_ts = int(time.time())
            script_name = "Reverse VPN Tunnel (WireGuard)"
            if event == "start":
                cur.execute(
                    """
                    INSERT INTO activity_history(hostname, script_path, script_name, script_type, ran_at, status, stdout, stderr)
                    VALUES(?,?,?,?,?,?,?,?)
                    """,
                    (
                        hostname,
                        session.tunnel_id,
                        script_name,
                        "vpn_tunnel",
                        now_ts,
                        "Running",
                        "",
                        "",
                    ),
                )
                session.activity_id = cur.lastrowid
                conn.commit()
                self._emit_activity_changed(hostname, session.activity_id, "created")
                self.activity_logger.info(
                    "device_activity_start hostname=%s agent_id=%s tunnel_id=%s operator=%s activity_id=%s",
                    hostname,
                    session.agent_id,
                    session.tunnel_id,
                    operators,
                    session.activity_id or "-",
                )
                return
            if session.activity_id:
                status = "Completed" if event == "stop" else "Closed"
                cur.execute(
                    """
                    UPDATE activity_history
                    SET status=?,
                    stderr=COALESCE(stderr, '') || ?
                    WHERE id=?
                    """,
                    (
                        status,
                        f"\nreason: {reason}" if reason else "",
                        session.activity_id,
                    ),
                )
                conn.commit()
                self._emit_activity_changed(hostname, session.activity_id, "updated")
            self.activity_logger.info(
                "device_activity event=%s hostname=%s agent_id=%s tunnel_id=%s operator=%s reason=%s activity_id=%s",
                event,
                hostname,
                session.agent_id,
                session.tunnel_id,
                operators,
                reason or "-",
                session.activity_id or "-",
            )
        except Exception:
            self.activity_logger.debug(
                "device_activity logging failed for tunnel_id=%s",
                session.tunnel_id,
                exc_info=True,
            )
        finally:
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

    # ------------------------------------------------------------------
    # Payload builders
    # ------------------------------------------------------------------
    def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
        """Build the full connection payload for a session.

        NOTE(review): ``client_private_key`` is included even when
        ``include_token`` is false (the ``status()`` path) — confirm that
        every consumer of that payload is meant to see it.
        """
        payload: Dict[str, Any] = {
            "tunnel_id": session.tunnel_id,
            "agent_id": session.agent_id,
            "virtual_ip": session.virtual_ip,
            "engine_virtual_ip": str(self._engine_ip.ip),
            "allowed_ips": f"{self._engine_ip.ip}/32",
            "endpoint": self._endpoint(session),
            "server_public_key": self.wg.server_public_key,
            "client_public_key": session.client_public_key,
            "client_private_key": session.client_private_key,
            "idle_seconds": self.idle_seconds,
            "allowed_ports": list(session.allowed_ports),
            "connected_operators": len([o for o in session.operator_ids if o]),
        }
        if include_token:
            payload["token"] = session.token
        return payload

    def _session_summary(self, session: VpnSession) -> Mapping[str, Any]:
        """Build the key-free summary of a session used by :meth:`list_sessions`."""
        return {
            "tunnel_id": session.tunnel_id,
            "agent_id": session.agent_id,
            "virtual_ip": session.virtual_ip,
            "engine_virtual_ip": str(self._engine_ip.ip),
            "endpoint": self._endpoint(session),
            "allowed_ports": list(session.allowed_ports),
            "connected_operators": len([o for o in session.operator_ids if o]),
            "created_at": int(session.created_at),
            "created_at_iso": self._ts_to_iso(session.created_at),
            "last_activity": int(session.last_activity),
            "last_activity_iso": self._ts_to_iso(session.last_activity),
            "expires_at": int(session.expires_at),
            "expires_at_iso": self._ts_to_iso(session.expires_at),
            "idle_seconds": self.idle_seconds,
            "status": "up",
        }