mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2026-02-04 10:00:30 -07:00
Additional Changes to VPN Tunneling
This commit is contained in:
@@ -37,6 +37,7 @@ class VpnSession:
|
||||
firewall_rules: List[str] = field(default_factory=list)
|
||||
activity_id: Optional[int] = None
|
||||
hostname: Optional[str] = None
|
||||
endpoint_host: Optional[str] = None
|
||||
|
||||
|
||||
class VpnTunnelService:
|
||||
@@ -176,6 +177,37 @@ class VpnTunnelService:
|
||||
self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
|
||||
return token
|
||||
|
||||
def _ensure_token(self, session: VpnSession, *, now: Optional[float] = None) -> None:
|
||||
if not session:
|
||||
return
|
||||
current = now if now is not None else time.time()
|
||||
if session.expires_at > current + 30:
|
||||
return
|
||||
session.expires_at = current + 300
|
||||
session.token = self._issue_token(session.agent_id, session.tunnel_id, session.expires_at)
|
||||
|
||||
def _normalize_endpoint_host(self, host: Optional[str]) -> Optional[str]:
|
||||
if not host:
|
||||
return None
|
||||
try:
|
||||
text = str(host).strip()
|
||||
except Exception:
|
||||
return None
|
||||
return text or None
|
||||
|
||||
def _format_endpoint_host(self, host: str) -> str:
|
||||
if ":" in host and not host.startswith("["):
|
||||
return f"[{host}]"
|
||||
return host
|
||||
|
||||
def _service_log_event(self, message: str, *, level: str = "INFO") -> None:
|
||||
if not callable(self.service_log):
|
||||
return
|
||||
try:
|
||||
self.service_log("reverse_tunnel", message, level=level)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to write reverse_tunnel service log entry", exc_info=True)
|
||||
|
||||
def _refresh_listener(self) -> None:
|
||||
peers: List[Mapping[str, object]] = []
|
||||
for session in self._sessions_by_agent.values():
|
||||
@@ -192,14 +224,24 @@ class VpnTunnelService:
|
||||
return
|
||||
self.wg.start_listener(peers)
|
||||
|
||||
def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]:
|
||||
def connect(
|
||||
self,
|
||||
*,
|
||||
agent_id: str,
|
||||
operator_id: Optional[str],
|
||||
endpoint_host: Optional[str] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
now = time.time()
|
||||
normalized_host = self._normalize_endpoint_host(endpoint_host)
|
||||
with self._lock:
|
||||
existing = self._sessions_by_agent.get(agent_id)
|
||||
if existing:
|
||||
if operator_id:
|
||||
existing.operator_ids.add(operator_id)
|
||||
if normalized_host and not existing.endpoint_host:
|
||||
existing.endpoint_host = normalized_host
|
||||
existing.last_activity = now
|
||||
self._ensure_token(existing, now=now)
|
||||
return self._session_payload(existing)
|
||||
|
||||
tunnel_id = uuid.uuid4().hex
|
||||
@@ -220,6 +262,7 @@ class VpnTunnelService:
|
||||
created_at=now,
|
||||
expires_at=now + 300,
|
||||
last_activity=now,
|
||||
endpoint_host=normalized_host,
|
||||
)
|
||||
if operator_id:
|
||||
session.operator_ids.add(operator_id)
|
||||
@@ -247,6 +290,17 @@ class VpnTunnelService:
|
||||
raise
|
||||
|
||||
payload = self._session_payload(session)
|
||||
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
|
||||
self._service_log_event(
|
||||
"vpn_tunnel_start agent_id={0} tunnel_id={1} virtual_ip={2} endpoint={3} allowed_ports={4} operators={5}".format(
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
session.virtual_ip,
|
||||
payload.get("endpoint", ""),
|
||||
",".join(str(p) for p in session.allowed_ports),
|
||||
operator_text,
|
||||
)
|
||||
)
|
||||
self._emit_start(payload)
|
||||
self._log_device_activity(session, event="start")
|
||||
return payload
|
||||
@@ -258,6 +312,22 @@ class VpnTunnelService:
|
||||
return None
|
||||
return self._session_payload(session, include_token=False)
|
||||
|
||||
def session_payload(self, agent_id: str, *, include_token: bool = True) -> Optional[Mapping[str, Any]]:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.get(agent_id)
|
||||
if not session:
|
||||
return None
|
||||
if include_token:
|
||||
self._ensure_token(session)
|
||||
return self._session_payload(session, include_token=include_token)
|
||||
|
||||
def request_agent_start(self, agent_id: str) -> Optional[Mapping[str, Any]]:
|
||||
payload = self.session_payload(agent_id, include_token=True)
|
||||
if not payload:
|
||||
return None
|
||||
self._emit_start(payload)
|
||||
return payload
|
||||
|
||||
def bump_activity(self, agent_id: str) -> None:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.get(agent_id)
|
||||
@@ -283,6 +353,15 @@ class VpnTunnelService:
|
||||
self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True)
|
||||
|
||||
self._refresh_listener()
|
||||
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
|
||||
self._service_log_event(
|
||||
"vpn_tunnel_stop agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
reason,
|
||||
operator_text,
|
||||
)
|
||||
)
|
||||
self._emit_stop(session, reason)
|
||||
self._log_device_activity(session, event="stop", reason=reason)
|
||||
return True
|
||||
@@ -297,6 +376,16 @@ class VpnTunnelService:
|
||||
def _emit_start(self, payload: Mapping[str, Any]) -> None:
|
||||
if not self.socketio:
|
||||
return
|
||||
agent_id = None
|
||||
if isinstance(payload, Mapping):
|
||||
agent_id = payload.get("agent_id")
|
||||
emit_agent = getattr(self.context, "emit_agent_event", None)
|
||||
if agent_id and callable(emit_agent):
|
||||
try:
|
||||
if emit_agent(agent_id, "vpn_tunnel_start", payload):
|
||||
return
|
||||
except Exception:
|
||||
self.logger.debug("emit_agent_event failed for vpn_tunnel_start", exc_info=True)
|
||||
try:
|
||||
self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
|
||||
except Exception:
|
||||
@@ -305,6 +394,17 @@ class VpnTunnelService:
|
||||
def _emit_stop(self, session: VpnSession, reason: str) -> None:
|
||||
if not self.socketio:
|
||||
return
|
||||
emit_agent = getattr(self.context, "emit_agent_event", None)
|
||||
if callable(emit_agent):
|
||||
try:
|
||||
if emit_agent(
|
||||
session.agent_id,
|
||||
"vpn_tunnel_stop",
|
||||
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
|
||||
):
|
||||
return
|
||||
except Exception:
|
||||
self.logger.debug("emit_agent_event failed for vpn_tunnel_stop", exc_info=True)
|
||||
try:
|
||||
self.socketio.emit(
|
||||
"vpn_tunnel_stop",
|
||||
@@ -454,13 +554,15 @@ class VpnTunnelService:
|
||||
pass
|
||||
|
||||
def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
|
||||
endpoint_host = session.endpoint_host or str(self._engine_ip.ip)
|
||||
endpoint_host = self._format_endpoint_host(endpoint_host)
|
||||
payload: Dict[str, Any] = {
|
||||
"tunnel_id": session.tunnel_id,
|
||||
"agent_id": session.agent_id,
|
||||
"virtual_ip": session.virtual_ip,
|
||||
"engine_virtual_ip": str(self._engine_ip.ip),
|
||||
"allowed_ips": f"{self._engine_ip.ip}/32",
|
||||
"endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}",
|
||||
"endpoint": f"{endpoint_host}:{self.context.wireguard_port}",
|
||||
"server_public_key": self.wg.server_public_key,
|
||||
"client_public_key": session.client_public_key,
|
||||
"client_private_key": session.client_private_key,
|
||||
|
||||
Reference in New Issue
Block a user