mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2026-02-04 09:00:30 -07:00
Additional Changes to VPN Tunneling
This commit is contained in:
@@ -20,7 +20,7 @@ from flask_socketio import SocketIO
|
||||
from ...database import initialise_engine_database
|
||||
from ...security import signing
|
||||
from ...server import EngineContext
|
||||
from ..VPN import VpnTunnelService
|
||||
from ..VPN import WireGuardServerConfig, WireGuardServerManager, VpnTunnelService
|
||||
from .vpn_shell import VpnShellBridge
|
||||
|
||||
|
||||
@@ -63,12 +63,53 @@ class EngineRealtimeAdapters:
|
||||
self.service_log = _make_service_logger(base, self.context.logger)
|
||||
|
||||
|
||||
class AgentSocketRegistry:
|
||||
def __init__(self, socketio: SocketIO, logger) -> None:
|
||||
self.socketio = socketio
|
||||
self.logger = logger
|
||||
self._sid_by_agent: Dict[str, str] = {}
|
||||
self._agent_by_sid: Dict[str, str] = {}
|
||||
|
||||
def register(self, agent_id: str, sid: str) -> None:
|
||||
if not agent_id or not sid:
|
||||
return
|
||||
previous = self._sid_by_agent.get(agent_id)
|
||||
if previous and previous != sid:
|
||||
self._agent_by_sid.pop(previous, None)
|
||||
self._sid_by_agent[agent_id] = sid
|
||||
self._agent_by_sid[sid] = agent_id
|
||||
|
||||
def unregister(self, sid: str) -> Optional[str]:
|
||||
agent_id = self._agent_by_sid.pop(sid, None)
|
||||
if agent_id and self._sid_by_agent.get(agent_id) == sid:
|
||||
self._sid_by_agent.pop(agent_id, None)
|
||||
return agent_id
|
||||
|
||||
def emit(self, agent_id: str, event: str, payload: Any) -> bool:
|
||||
sid = self._sid_by_agent.get(agent_id)
|
||||
if not sid:
|
||||
return False
|
||||
try:
|
||||
self.socketio.emit(event, payload, to=sid)
|
||||
return True
|
||||
except Exception:
|
||||
self.logger.debug("Failed to emit %s to agent_id=%s", event, agent_id, exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
"""Register Socket.IO event handlers for the Engine runtime."""
|
||||
|
||||
adapters = EngineRealtimeAdapters(context)
|
||||
logger = context.logger.getChild("realtime.quick_jobs")
|
||||
agent_logger = context.logger.getChild("realtime.agents")
|
||||
shell_bridge = VpnShellBridge(socket_server, context)
|
||||
agent_registry = AgentSocketRegistry(socket_server, agent_logger)
|
||||
|
||||
def _emit_agent_event(agent_id: str, event: str, payload: Any) -> bool:
|
||||
return agent_registry.emit(agent_id, event, payload)
|
||||
|
||||
setattr(context, "emit_agent_event", _emit_agent_event)
|
||||
|
||||
def _get_tunnel_service() -> Optional[VpnTunnelService]:
|
||||
service = getattr(context, "vpn_tunnel_service", None)
|
||||
@@ -76,7 +117,22 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
return service
|
||||
manager = getattr(context, "wireguard_server_manager", None)
|
||||
if manager is None:
|
||||
return None
|
||||
try:
|
||||
manager = WireGuardServerManager(
|
||||
WireGuardServerConfig(
|
||||
port=context.wireguard_port,
|
||||
engine_virtual_ip=context.wireguard_engine_virtual_ip,
|
||||
peer_network=context.wireguard_peer_network,
|
||||
private_key_path=Path(context.wireguard_server_private_key_path),
|
||||
public_key_path=Path(context.wireguard_server_public_key_path),
|
||||
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
|
||||
log_path=Path(context.vpn_tunnel_log_path),
|
||||
)
|
||||
)
|
||||
setattr(context, "wireguard_server_manager", manager)
|
||||
except Exception:
|
||||
context.logger.error("Failed to initialize WireGuard server manager on demand.", exc_info=True)
|
||||
return None
|
||||
try:
|
||||
signer = signing.load_signer()
|
||||
except Exception:
|
||||
@@ -275,6 +331,29 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
service.bump_activity(agent_id)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("connect_agent")
|
||||
def _connect_agent(data: Any) -> Dict[str, Any]:
|
||||
agent_id = ""
|
||||
service_mode = ""
|
||||
if isinstance(data, dict):
|
||||
agent_id = str(data.get("agent_id") or "").strip()
|
||||
service_mode = str(data.get("service_mode") or "").strip().lower()
|
||||
elif isinstance(data, str):
|
||||
agent_id = data.strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}
|
||||
|
||||
agent_registry.register(agent_id, request.sid)
|
||||
agent_logger.info("Agent socket registered agent_id=%s service_mode=%s sid=%s", agent_id, service_mode, request.sid)
|
||||
|
||||
service = _get_tunnel_service()
|
||||
if service:
|
||||
payload = service.session_payload(agent_id, include_token=True)
|
||||
if payload:
|
||||
agent_registry.emit(agent_id, "vpn_tunnel_start", payload)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("vpn_shell_send")
|
||||
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
|
||||
payload = None
|
||||
@@ -288,10 +367,13 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("vpn_shell_close")
|
||||
def _vpn_shell_close() -> Dict[str, Any]:
|
||||
def _vpn_shell_close(data: Any = None) -> Dict[str, Any]:
|
||||
shell_bridge.close(request.sid)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("disconnect")
|
||||
def _ws_disconnect() -> None:
|
||||
agent_id = agent_registry.unregister(request.sid)
|
||||
if agent_id:
|
||||
agent_logger.info("Agent socket disconnected agent_id=%s sid=%s", agent_id, request.sid)
|
||||
shell_bridge.close(request.sid)
|
||||
|
||||
@@ -13,6 +13,7 @@ import base64
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -42,7 +43,10 @@ class ShellSession:
|
||||
buffer = b""
|
||||
try:
|
||||
while True:
|
||||
data = self.tcp.recv(4096)
|
||||
try:
|
||||
data = self.tcp.recv(4096)
|
||||
except (socket.timeout, TimeoutError):
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
@@ -100,12 +104,28 @@ class VpnShellBridge:
|
||||
return None
|
||||
host = str(status.get("virtual_ip") or "").split("/")[0]
|
||||
port = int(self.context.wireguard_shell_port)
|
||||
try:
|
||||
tcp = socket.create_connection((host, port), timeout=5)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to connect vpn shell to %s:%s", host, port, exc_info=True)
|
||||
tcp = None
|
||||
last_error: Optional[Exception] = None
|
||||
for attempt in range(3):
|
||||
try:
|
||||
tcp = socket.create_connection((host, port), timeout=5)
|
||||
break
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if attempt == 0:
|
||||
try:
|
||||
service.request_agent_start(agent_id)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to re-emit vpn_tunnel_start for agent=%s", agent_id, exc_info=True)
|
||||
time.sleep(1)
|
||||
if tcp is None:
|
||||
self.logger.warning("Failed to connect vpn shell to %s:%s", host, port, exc_info=last_error)
|
||||
return None
|
||||
session = ShellSession(sid=sid, agent_id=agent_id, socketio=self.socketio, tcp=tcp)
|
||||
try:
|
||||
session.tcp.settimeout(15)
|
||||
except Exception:
|
||||
pass
|
||||
self._sessions[sid] = session
|
||||
session.start_reader()
|
||||
return session
|
||||
@@ -124,4 +144,3 @@ class VpnShellBridge:
|
||||
if not session:
|
||||
return
|
||||
session.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user