More VPN Tunnel Changes

This commit is contained in:
2026-01-11 20:53:09 -07:00
parent df14a1e26a
commit 3809fd25fb
13 changed files with 593 additions and 51 deletions

View File

@@ -15,6 +15,7 @@ import json
import threading
import time
import uuid
from datetime import datetime, timezone
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
@@ -66,6 +67,7 @@ class VpnTunnelService:
self._sessions_by_tunnel: Dict[str, VpnSession] = {}
self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
self._cleanup_listener()
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
self._idle_thread.start()
@@ -79,6 +81,15 @@ class VpnTunnelService:
if session.last_activity + self.idle_seconds <= now:
expired.append(session)
for session in expired:
self._service_log_event(
"vpn_tunnel_idle_timeout agent_id={0} tunnel_id={1} last_activity={2} last_activity_iso={3} idle_seconds={4}".format(
session.agent_id,
session.tunnel_id,
int(session.last_activity),
self._ts_to_iso(session.last_activity),
self.idle_seconds,
)
)
self.disconnect(session.agent_id, reason="idle_timeout")
def _allocate_virtual_ip(self, agent_id: str) -> str:
@@ -200,13 +211,27 @@ class VpnTunnelService:
return f"[{host}]"
return host
def _ts_to_iso(self, ts: float) -> str:
try:
return datetime.fromtimestamp(ts, timezone.utc).isoformat()
except Exception:
return ""
def _service_log_event(self, message: str, *, level: str = "INFO") -> None:
if not callable(self.service_log):
return
try:
self.service_log("reverse_tunnel", message, level=level)
self.service_log("VPN_Tunnel/tunnel", message, level=level)
except Exception:
self.logger.debug("Failed to write reverse_tunnel service log entry", exc_info=True)
self.logger.debug("Failed to write vpn_tunnel service log entry", exc_info=True)
def _cleanup_listener(self) -> None:
try:
self.wg.stop_listener(ignore_missing=True)
self._service_log_event("vpn_listener_cleanup reason=startup")
except Exception:
self.logger.debug("Failed to clean up WireGuard listener on startup.", exc_info=True)
self._service_log_event("vpn_listener_cleanup_failed reason=startup", level="WARNING")
def _refresh_listener(self) -> None:
peers: List[Mapping[str, object]] = []
@@ -220,8 +245,11 @@ class VpnTunnelService:
peer["public_key"] = session.client_public_key
peers.append(peer)
if not peers:
self._service_log_event("vpn_listener_stop reason=no_peers")
self.wg.stop_listener()
return
agent_list = ",".join(str(peer.get("agent_id", "")) for peer in peers if peer.get("agent_id"))
self._service_log_event("vpn_listener_start peers={0} agents={1}".format(len(peers), agent_list))
self.wg.start_listener(peers)
def connect(
@@ -233,6 +261,14 @@ class VpnTunnelService:
) -> Mapping[str, Any]:
now = time.time()
normalized_host = self._normalize_endpoint_host(endpoint_host)
operator_text = operator_id or "-"
self._service_log_event(
"vpn_tunnel_connect_request agent_id={0} operator={1} endpoint_host={2}".format(
agent_id or "-",
operator_text,
normalized_host or "-",
)
)
with self._lock:
existing = self._sessions_by_agent.get(agent_id)
if existing:
@@ -241,7 +277,18 @@ class VpnTunnelService:
if normalized_host and not existing.endpoint_host:
existing.endpoint_host = normalized_host
existing.last_activity = now
previous_expiry = existing.expires_at
self._ensure_token(existing, now=now)
refreshed = existing.expires_at != previous_expiry
operator_list = ",".join(sorted(filter(None, existing.operator_ids))) or "-"
self._service_log_event(
"vpn_tunnel_session_reuse agent_id={0} tunnel_id={1} operators={2} token_refreshed={3}".format(
existing.agent_id,
existing.tunnel_id,
operator_list,
str(refreshed).lower(),
)
)
return self._session_payload(existing)
tunnel_id = uuid.uuid4().hex
@@ -250,6 +297,7 @@ class VpnTunnelService:
client_private, client_public = self._generate_client_keys()
token = self._issue_token(agent_id, tunnel_id, now + 300)
self.wg.require_orchestration_token(token)
token_signed = "signature" in token
session = VpnSession(
tunnel_id=tunnel_id,
@@ -270,6 +318,16 @@ class VpnTunnelService:
self._sessions_by_tunnel[tunnel_id] = session
try:
self._service_log_event(
"vpn_tunnel_session_create agent_id={0} tunnel_id={1} virtual_ip={2} allowed_ports={3} token_signed={4} token_expires={5}".format(
session.agent_id,
session.tunnel_id,
session.virtual_ip,
",".join(str(p) for p in allowed_ports),
str(bool(token_signed)).lower(),
int(session.expires_at),
)
)
self._refresh_listener()
peer = self.wg.build_peer_profile(
@@ -279,7 +337,18 @@ class VpnTunnelService:
)
rule_names = self.wg.apply_firewall_rules(peer)
session.firewall_rules = rule_names
self._service_log_event(
"vpn_tunnel_firewall_applied agent_id={0} tunnel_id={1} rules={2}".format(
session.agent_id,
session.tunnel_id,
len(rule_names),
)
)
except Exception:
self._service_log_event(
"vpn_tunnel_connect_failed agent_id={0} tunnel_id={1}".format(agent_id, tunnel_id),
level="ERROR",
)
with self._lock:
self._sessions_by_agent.pop(agent_id, None)
self._sessions_by_tunnel.pop(tunnel_id, None)
@@ -312,6 +381,11 @@ class VpnTunnelService:
return None
return self._session_payload(session, include_token=False)
def list_sessions(self) -> List[Mapping[str, Any]]:
    """Return one summary mapping per tracked session, ordered by agent id."""
    # Copy the session references while holding the lock; ordering and
    # summary construction happen on the private copy afterwards.
    with self._lock:
        snapshot = list(self._sessions_by_agent.values())
    snapshot.sort(key=lambda entry: entry.agent_id)
    summaries = []
    for entry in snapshot:
        summaries.append(self._session_summary(entry))
    return summaries
def session_payload(self, agent_id: str, *, include_token: bool = True) -> Optional[Mapping[str, Any]]:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
@@ -324,7 +398,14 @@ class VpnTunnelService:
def request_agent_start(self, agent_id: str) -> Optional[Mapping[str, Any]]:
    """Re-send the tunnel start event for an agent's existing session.

    Looks up the agent's session payload (token included). When one is
    found, the emit is recorded in the service log, the payload is handed
    to ``_emit_start`` and returned. When none is found, a "missing" event
    is logged and ``None`` is returned.
    """
    payload = self.session_payload(agent_id, include_token=True)
    if payload:
        logged_agent = payload.get("agent_id", "-")
        logged_tunnel = payload.get("tunnel_id", "-")
        emit_message = "vpn_tunnel_agent_start_emit agent_id={0} tunnel_id={1}".format(
            logged_agent,
            logged_tunnel,
        )
        self._service_log_event(emit_message)
        self._emit_start(payload)
        return payload
    missing_message = "vpn_tunnel_agent_start_missing agent_id={0}".format(agent_id or "-")
    self._service_log_event(missing_message)
    return None
@@ -333,7 +414,18 @@ class VpnTunnelService:
session = self._sessions_by_agent.get(agent_id)
if not session:
return
session.last_activity = time.time()
now = time.time()
previous = session.last_activity
session.last_activity = now
idle_for = now - previous
if idle_for >= 60:
self._service_log_event(
"vpn_tunnel_activity_bump agent_id={0} tunnel_id={1} idle_for={2}".format(
session.agent_id,
session.tunnel_id,
int(idle_for),
)
)
try:
if self.socketio:
self.socketio.emit("vpn_tunnel_activity", {"agent_id": agent_id}, namespace="/")
@@ -344,6 +436,9 @@ class VpnTunnelService:
with self._lock:
session = self._sessions_by_agent.pop(agent_id, None)
if not session:
self._service_log_event(
"vpn_tunnel_disconnect_missing agent_id={0} reason={1}".format(agent_id or "-", reason or "-")
)
return False
self._sessions_by_tunnel.pop(session.tunnel_id, None)
@@ -370,6 +465,9 @@ class VpnTunnelService:
with self._lock:
session = self._sessions_by_tunnel.get(tunnel_id)
if not session:
self._service_log_event(
"vpn_tunnel_disconnect_missing tunnel_id={0} reason={1}".format(tunnel_id or "-", reason or "-")
)
return False
return self.disconnect(session.agent_id, reason=reason)
@@ -383,13 +481,27 @@ class VpnTunnelService:
if agent_id and callable(emit_agent):
try:
if emit_agent(agent_id, "vpn_tunnel_start", payload):
self._service_log_event(
"vpn_tunnel_start_emit agent_id={0} transport=direct".format(agent_id or "-")
)
return
except Exception:
self.logger.debug("emit_agent_event failed for vpn_tunnel_start", exc_info=True)
self._service_log_event(
"vpn_tunnel_start_emit_failed agent_id={0} transport=direct".format(agent_id or "-"),
level="WARNING",
)
try:
self._service_log_event(
"vpn_tunnel_start_emit agent_id={0} transport=broadcast".format(agent_id or "-")
)
self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
except Exception:
self.logger.debug("vpn_tunnel_start emit failed", exc_info=True)
self._service_log_event(
"vpn_tunnel_start_emit_failed agent_id={0} transport=broadcast".format(agent_id or "-"),
level="WARNING",
)
def _emit_stop(self, session: VpnSession, reason: str) -> None:
if not self.socketio:
@@ -402,10 +514,29 @@ class VpnTunnelService:
"vpn_tunnel_stop",
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
):
self._service_log_event(
"vpn_tunnel_stop_emit agent_id={0} tunnel_id={1} transport=direct".format(
session.agent_id,
session.tunnel_id,
)
)
return
except Exception:
self.logger.debug("emit_agent_event failed for vpn_tunnel_stop", exc_info=True)
self._service_log_event(
"vpn_tunnel_stop_emit_failed agent_id={0} tunnel_id={1} transport=direct".format(
session.agent_id,
session.tunnel_id,
),
level="WARNING",
)
try:
self._service_log_event(
"vpn_tunnel_stop_emit agent_id={0} tunnel_id={1} transport=broadcast".format(
session.agent_id,
session.tunnel_id,
)
)
self.socketio.emit(
"vpn_tunnel_stop",
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
@@ -413,6 +544,13 @@ class VpnTunnelService:
)
except Exception:
self.logger.debug("vpn_tunnel_stop emit failed", exc_info=True)
self._service_log_event(
"vpn_tunnel_stop_emit_failed agent_id={0} tunnel_id={1} transport=broadcast".format(
session.agent_id,
session.tunnel_id,
),
level="WARNING",
)
def _log_device_activity(self, session: VpnSession, *, event: str, reason: Optional[str] = None) -> None:
if self.db_conn_factory is None:
@@ -573,3 +711,24 @@ class VpnTunnelService:
if include_token:
payload["token"] = session.token
return payload
def _session_summary(self, session: VpnSession) -> Mapping[str, Any]:
    """Build a token-free status snapshot of *session*.

    The endpoint uses the session's own endpoint host when set and falls
    back to the engine's virtual IP otherwise. Each timestamp is exposed
    both as whole epoch seconds and as an ISO string from ``_ts_to_iso``.
    """
    engine_ip = str(self._engine_ip.ip)
    host = self._format_endpoint_host(session.endpoint_host or engine_ip)
    # Blank or None operator ids are not counted as connected operators.
    operator_count = sum(1 for operator in session.operator_ids if operator)
    created = session.created_at
    last_seen = session.last_activity
    expiry = session.expires_at
    summary = {
        "tunnel_id": session.tunnel_id,
        "agent_id": session.agent_id,
        "virtual_ip": session.virtual_ip,
        "engine_virtual_ip": engine_ip,
        "endpoint": f"{host}:{self.context.wireguard_port}",
        "allowed_ports": list(session.allowed_ports),
        "connected_operators": operator_count,
        "created_at": int(created),
        "created_at_iso": self._ts_to_iso(created),
        "last_activity": int(last_seen),
        "last_activity_iso": self._ts_to_iso(last_seen),
        "expires_at": int(expiry),
        "expires_at_iso": self._ts_to_iso(expiry),
        "idle_seconds": self.idle_seconds,
        "status": "up",
    }
    return summary

View File

@@ -336,12 +336,16 @@ class WireGuardServerManager:
raise RuntimeError(f"WireGuard installtunnelservice failed: {err}")
self.logger.info("WireGuard listener installed (service=%s)", config_path.stem)
def stop_listener(self, *, ignore_missing: bool = False) -> None:
    """Stop and remove the WireGuard tunnel service.

    Runs the WireGuard executable with ``/uninstalltunnelservice`` for this
    manager's service name. A non-zero exit code is logged, never raised.

    Args:
        ignore_missing: When true, a failure whose output says the service
            does not exist is logged at info level instead of as a warning.
    """
    args = [self._wireguard_exe, "/uninstalltunnelservice", self._service_name]
    code, out, err = self._run_command(args)
    if code == 0:
        self.logger.info("WireGuard tunnel service removed")
        return
    # The failure reason may arrive on either stream, so inspect both.
    err_text = " ".join([out or "", err or ""]).strip().lower()
    # "not exist" also matches "does not exist"; one test covers both.
    if ignore_missing and "not exist" in err_text:
        self.logger.info("WireGuard tunnel service already absent")
        return
    self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)