Additional Changes to VPN Tunneling

This commit is contained in:
2026-01-11 19:02:53 -07:00
parent 6ceb59f717
commit df14a1e26a
18 changed files with 681 additions and 175 deletions

View File

@@ -37,6 +37,7 @@ class VpnSession:
firewall_rules: List[str] = field(default_factory=list)
activity_id: Optional[int] = None
hostname: Optional[str] = None
endpoint_host: Optional[str] = None
class VpnTunnelService:
@@ -176,6 +177,37 @@ class VpnTunnelService:
self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
return token
def _ensure_token(self, session: VpnSession, *, now: Optional[float] = None) -> None:
if not session:
return
current = now if now is not None else time.time()
if session.expires_at > current + 30:
return
session.expires_at = current + 300
session.token = self._issue_token(session.agent_id, session.tunnel_id, session.expires_at)
def _normalize_endpoint_host(self, host: Optional[str]) -> Optional[str]:
if not host:
return None
try:
text = str(host).strip()
except Exception:
return None
return text or None
def _format_endpoint_host(self, host: str) -> str:
if ":" in host and not host.startswith("["):
return f"[{host}]"
return host
def _service_log_event(self, message: str, *, level: str = "INFO") -> None:
if not callable(self.service_log):
return
try:
self.service_log("reverse_tunnel", message, level=level)
except Exception:
self.logger.debug("Failed to write reverse_tunnel service log entry", exc_info=True)
def _refresh_listener(self) -> None:
peers: List[Mapping[str, object]] = []
for session in self._sessions_by_agent.values():
@@ -192,14 +224,24 @@ class VpnTunnelService:
return
self.wg.start_listener(peers)
def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]:
def connect(
self,
*,
agent_id: str,
operator_id: Optional[str],
endpoint_host: Optional[str] = None,
) -> Mapping[str, Any]:
now = time.time()
normalized_host = self._normalize_endpoint_host(endpoint_host)
with self._lock:
existing = self._sessions_by_agent.get(agent_id)
if existing:
if operator_id:
existing.operator_ids.add(operator_id)
if normalized_host and not existing.endpoint_host:
existing.endpoint_host = normalized_host
existing.last_activity = now
self._ensure_token(existing, now=now)
return self._session_payload(existing)
tunnel_id = uuid.uuid4().hex
@@ -220,6 +262,7 @@ class VpnTunnelService:
created_at=now,
expires_at=now + 300,
last_activity=now,
endpoint_host=normalized_host,
)
if operator_id:
session.operator_ids.add(operator_id)
@@ -247,6 +290,17 @@ class VpnTunnelService:
raise
payload = self._session_payload(session)
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
self._service_log_event(
"vpn_tunnel_start agent_id={0} tunnel_id={1} virtual_ip={2} endpoint={3} allowed_ports={4} operators={5}".format(
session.agent_id,
session.tunnel_id,
session.virtual_ip,
payload.get("endpoint", ""),
",".join(str(p) for p in session.allowed_ports),
operator_text,
)
)
self._emit_start(payload)
self._log_device_activity(session, event="start")
return payload
@@ -258,6 +312,22 @@ class VpnTunnelService:
return None
return self._session_payload(session, include_token=False)
def session_payload(self, agent_id: str, *, include_token: bool = True) -> Optional[Mapping[str, Any]]:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
if not session:
return None
if include_token:
self._ensure_token(session)
return self._session_payload(session, include_token=include_token)
def request_agent_start(self, agent_id: str) -> Optional[Mapping[str, Any]]:
payload = self.session_payload(agent_id, include_token=True)
if not payload:
return None
self._emit_start(payload)
return payload
def bump_activity(self, agent_id: str) -> None:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
@@ -283,6 +353,15 @@ class VpnTunnelService:
self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True)
self._refresh_listener()
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
self._service_log_event(
"vpn_tunnel_stop agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
session.agent_id,
session.tunnel_id,
reason,
operator_text,
)
)
self._emit_stop(session, reason)
self._log_device_activity(session, event="stop", reason=reason)
return True
@@ -297,6 +376,16 @@ class VpnTunnelService:
def _emit_start(self, payload: Mapping[str, Any]) -> None:
if not self.socketio:
return
agent_id = None
if isinstance(payload, Mapping):
agent_id = payload.get("agent_id")
emit_agent = getattr(self.context, "emit_agent_event", None)
if agent_id and callable(emit_agent):
try:
if emit_agent(agent_id, "vpn_tunnel_start", payload):
return
except Exception:
self.logger.debug("emit_agent_event failed for vpn_tunnel_start", exc_info=True)
try:
self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
except Exception:
@@ -305,6 +394,17 @@ class VpnTunnelService:
def _emit_stop(self, session: VpnSession, reason: str) -> None:
if not self.socketio:
return
emit_agent = getattr(self.context, "emit_agent_event", None)
if callable(emit_agent):
try:
if emit_agent(
session.agent_id,
"vpn_tunnel_stop",
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
):
return
except Exception:
self.logger.debug("emit_agent_event failed for vpn_tunnel_stop", exc_info=True)
try:
self.socketio.emit(
"vpn_tunnel_stop",
@@ -454,13 +554,15 @@ class VpnTunnelService:
pass
def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
endpoint_host = session.endpoint_host or str(self._engine_ip.ip)
endpoint_host = self._format_endpoint_host(endpoint_host)
payload: Dict[str, Any] = {
"tunnel_id": session.tunnel_id,
"agent_id": session.agent_id,
"virtual_ip": session.virtual_ip,
"engine_virtual_ip": str(self._engine_ip.ip),
"allowed_ips": f"{self._engine_ip.ip}/32",
"endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}",
"endpoint": f"{endpoint_host}:{self.context.wireguard_port}",
"server_public_key": self.wg.server_public_key,
"client_public_key": session.client_public_key,
"client_private_key": session.client_private_key,