Additional Changes to VPN Tunneling

This commit is contained in:
2026-01-11 19:02:53 -07:00
parent 6ceb59f717
commit df14a1e26a
18 changed files with 681 additions and 175 deletions

View File

@@ -20,7 +20,7 @@ from flask_socketio import SocketIO
from ...database import initialise_engine_database
from ...security import signing
from ...server import EngineContext
from ..VPN import VpnTunnelService
from ..VPN import WireGuardServerConfig, WireGuardServerManager, VpnTunnelService
from .vpn_shell import VpnShellBridge
@@ -63,12 +63,53 @@ class EngineRealtimeAdapters:
self.service_log = _make_service_logger(base, self.context.logger)
class AgentSocketRegistry:
def __init__(self, socketio: SocketIO, logger) -> None:
self.socketio = socketio
self.logger = logger
self._sid_by_agent: Dict[str, str] = {}
self._agent_by_sid: Dict[str, str] = {}
def register(self, agent_id: str, sid: str) -> None:
if not agent_id or not sid:
return
previous = self._sid_by_agent.get(agent_id)
if previous and previous != sid:
self._agent_by_sid.pop(previous, None)
self._sid_by_agent[agent_id] = sid
self._agent_by_sid[sid] = agent_id
def unregister(self, sid: str) -> Optional[str]:
agent_id = self._agent_by_sid.pop(sid, None)
if agent_id and self._sid_by_agent.get(agent_id) == sid:
self._sid_by_agent.pop(agent_id, None)
return agent_id
def emit(self, agent_id: str, event: str, payload: Any) -> bool:
sid = self._sid_by_agent.get(agent_id)
if not sid:
return False
try:
self.socketio.emit(event, payload, to=sid)
return True
except Exception:
self.logger.debug("Failed to emit %s to agent_id=%s", event, agent_id, exc_info=True)
return False
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
"""Register Socket.IO event handlers for the Engine runtime."""
adapters = EngineRealtimeAdapters(context)
logger = context.logger.getChild("realtime.quick_jobs")
agent_logger = context.logger.getChild("realtime.agents")
shell_bridge = VpnShellBridge(socket_server, context)
agent_registry = AgentSocketRegistry(socket_server, agent_logger)
def _emit_agent_event(agent_id: str, event: str, payload: Any) -> bool:
return agent_registry.emit(agent_id, event, payload)
setattr(context, "emit_agent_event", _emit_agent_event)
def _get_tunnel_service() -> Optional[VpnTunnelService]:
service = getattr(context, "vpn_tunnel_service", None)
@@ -76,7 +117,22 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
return service
manager = getattr(context, "wireguard_server_manager", None)
if manager is None:
return None
try:
manager = WireGuardServerManager(
WireGuardServerConfig(
port=context.wireguard_port,
engine_virtual_ip=context.wireguard_engine_virtual_ip,
peer_network=context.wireguard_peer_network,
private_key_path=Path(context.wireguard_server_private_key_path),
public_key_path=Path(context.wireguard_server_public_key_path),
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
log_path=Path(context.vpn_tunnel_log_path),
)
)
setattr(context, "wireguard_server_manager", manager)
except Exception:
context.logger.error("Failed to initialize WireGuard server manager on demand.", exc_info=True)
return None
try:
signer = signing.load_signer()
except Exception:
@@ -275,6 +331,29 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
service.bump_activity(agent_id)
return {"status": "ok"}
@socket_server.on("connect_agent")
def _connect_agent(data: Any) -> Dict[str, Any]:
agent_id = ""
service_mode = ""
if isinstance(data, dict):
agent_id = str(data.get("agent_id") or "").strip()
service_mode = str(data.get("service_mode") or "").strip().lower()
elif isinstance(data, str):
agent_id = data.strip()
if not agent_id:
return {"error": "agent_id_required"}
agent_registry.register(agent_id, request.sid)
agent_logger.info("Agent socket registered agent_id=%s service_mode=%s sid=%s", agent_id, service_mode, request.sid)
service = _get_tunnel_service()
if service:
payload = service.session_payload(agent_id, include_token=True)
if payload:
agent_registry.emit(agent_id, "vpn_tunnel_start", payload)
return {"status": "ok"}
@socket_server.on("vpn_shell_send")
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
payload = None
@@ -288,10 +367,13 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
return {"status": "ok"}
@socket_server.on("vpn_shell_close")
def _vpn_shell_close() -> Dict[str, Any]:
def _vpn_shell_close(data: Any = None) -> Dict[str, Any]:
shell_bridge.close(request.sid)
return {"status": "ok"}
@socket_server.on("disconnect")
def _ws_disconnect() -> None:
agent_id = agent_registry.unregister(request.sid)
if agent_id:
agent_logger.info("Agent socket disconnected agent_id=%s sid=%s", agent_id, request.sid)
shell_bridge.close(request.sid)