From 6ceb59f717942240f7902faee49bde2272be0b2b Mon Sep 17 00:00:00 2001 From: Nicole Rappe Date: Thu, 18 Dec 2025 01:35:03 -0700 Subject: [PATCH] Overhaul of VPN Codebase --- Data/Agent/Roles/ReverseTunnel/__init__.py | 2 - .../Roles/ReverseTunnel/tunnel_Powershell.py | 190 --- Data/Agent/Roles/Reverse_Tunnels/__init__.py | 3 - .../Protocols/Bash.py | 49 - .../Protocols/Powershell.py | 35 - .../Protocols/__init__.py | 6 - .../remote_interactive_shell/__init__.py | 3 - .../remote_interactive_shell/tunnel.py | 5 - .../remote_management/Protocols/SSH.py | 47 - .../remote_management/Protocols/WinRM.py | 47 - .../remote_management/Protocols/__init__.py | 6 - .../remote_management/__init__.py | 3 - .../remote_management/tunnel.py | 5 - .../remote_video/Protocols/RDP.py | 47 - .../remote_video/Protocols/VNC.py | 47 - .../remote_video/Protocols/WebRTC.py | 47 - .../remote_video/Protocols/__init__.py | 7 - .../Reverse_Tunnels/remote_video/__init__.py | 3 - .../Reverse_Tunnels/remote_video/tunnel.py | 5 - Data/Agent/Roles/role_ReverseTunnel.py | 939 ------------ Data/Agent/Roles/role_VpnShell.py | 167 ++ Data/Agent/Roles/role_WireGuardTunnel.py | 53 +- Data/Engine/Unit_Tests/test_reverse_tunnel.py | 90 -- .../test_reverse_tunnel_integration.py | 101 -- Data/Engine/config.py | 73 +- Data/Engine/database_migrations.py | 17 + Data/Engine/server.py | 20 +- .../Engine/services/API/devices/management.py | 140 ++ Data/Engine/services/API/devices/tunnel.py | 127 +- Data/Engine/services/VPN/__init__.py | 2 +- .../Engine/services/VPN/vpn_tunnel_service.py | 473 ++++++ Data/Engine/services/VPN/wireguard_server.py | 55 +- .../Agent/Reverse_Tunnels/__init__.py | 3 - .../Protocols/Bash.py | 78 - .../Protocols/Powershell.py | 139 -- .../Protocols/__init__.py | 6 - .../remote_interactive_shell/__init__.py | 3 - .../remote_management/Protocols/SSH.py | 73 - .../remote_management/Protocols/WinRM.py | 73 - .../remote_management/Protocols/__init__.py | 6 - .../remote_management/__init__.py | 3 - .../remote_video/Protocols/RDP.py | 73 - .../remote_video/Protocols/VNC.py | 73 - .../remote_video/Protocols/WebRTC.py | 73 - .../remote_video/Protocols/__init__.py | 7 - .../Reverse_Tunnels/remote_video/__init__.py | 3 - .../services/WebSocket/Agent/__init__.py | 10 - .../Agent/reverse_tunnel_orchestrator.py | 1361 ----------------- Data/Engine/services/WebSocket/__init__.py | 318 +--- Data/Engine/services/WebSocket/vpn_shell.py | 127 ++ .../src/Devices/Device_Details.jsx | 327 +++- .../src/Devices/ReverseTunnel/Powershell.jsx | 759 +++------ Docs/Codex/BOREALIS_AGENT.md | 5 +- Docs/Codex/BOREALIS_ENGINE.md | 7 +- Docs/Codex/REVERSE_TUNNELS.md | 129 +- Docs/Codex/Reverse_VPN_Tunnel_Deployment.md | 94 +- 56 files changed, 1786 insertions(+), 4778 deletions(-) delete mode 100644 Data/Agent/Roles/ReverseTunnel/__init__.py delete mode 100644 Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/__init__.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/Bash.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/Powershell.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/__init__.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/__init__.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/tunnel.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/SSH.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/WinRM.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/__init__.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_management/__init__.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_management/tunnel.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/RDP.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/VNC.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/WebRTC.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/__init__.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_video/__init__.py delete mode 100644 Data/Agent/Roles/Reverse_Tunnels/remote_video/tunnel.py delete mode 100644 Data/Agent/Roles/role_ReverseTunnel.py create mode 100644 Data/Agent/Roles/role_VpnShell.py delete mode 100644 Data/Engine/Unit_Tests/test_reverse_tunnel.py delete mode 100644 Data/Engine/Unit_Tests/test_reverse_tunnel_integration.py create mode 100644 Data/Engine/services/VPN/vpn_tunnel_service.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/__init__.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/Bash.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/Powershell.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/__init__.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/__init__.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/SSH.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/WinRM.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/__init__.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/__init__.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/RDP.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/VNC.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/WebRTC.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/__init__.py delete mode 100644 Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/__init__.py delete mode 100644 Data/Engine/services/WebSocket/Agent/__init__.py delete mode 100644 Data/Engine/services/WebSocket/Agent/reverse_tunnel_orchestrator.py create mode 100644 Data/Engine/services/WebSocket/vpn_shell.py diff --git a/Data/Agent/Roles/ReverseTunnel/__init__.py b/Data/Agent/Roles/ReverseTunnel/__init__.py deleted file mode 100644 index 88b181a3..00000000 --- a/Data/Agent/Roles/ReverseTunnel/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Reverse tunnel protocol modules (placeholder package).""" - diff --git a/Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py b/Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py deleted file mode 100644 index 759b7ff2..00000000 --- a/Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py +++ /dev/null @@ -1,190 +0,0 @@ -"""PowerShell channel implementation for reverse tunnel (Agent side).""" -from __future__ import annotations - -import asyncio -import sys -import subprocess -from typing import Any, Dict, Optional - -# Message types mirrored from the tunnel framing (kept local to avoid import cycles). -MSG_DATA = 0x05 -MSG_WINDOW_UPDATE = 0x06 -MSG_CONTROL = 0x09 -MSG_CLOSE = 0x08 - -# Close codes (mirrored from engine framing) -CLOSE_OK = 0 -CLOSE_PROTOCOL_ERROR = 3 -CLOSE_AGENT_SHUTDOWN = 6 - - -class PowershellChannel: - def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]): - self.role = role - self.tunnel = tunnel - self.channel_id = channel_id - self.metadata = metadata or {} - self.loop = getattr(role, "loop", None) or asyncio.get_event_loop() - self._closed = False - self._reader_task = None - self._writer_task = None - self._stdin_queue: asyncio.Queue = asyncio.Queue() - self._proc: Optional[asyncio.subprocess.Process] = None - self._exit_code: Optional[int] = None - self._frame_cls = getattr(role, "_frame_cls", None) - - # ------------------------------------------------------------------ Helpers - def _make_frame(self, msg_type: int, payload: bytes = b"", *, flags: int = 0): - frame_cls = self._frame_cls - if frame_cls is None: - return None - try: - return frame_cls(msg_type=msg_type, channel_id=self.channel_id, payload=payload or b"", flags=flags) - except Exception: - return None - - async def _send_frame(self, frame) -> None: - if frame is None: - return - await self.role._send_frame(self.tunnel, frame) - - async def _send_close(self, code: int, reason: str) -> None: - try: - close_frame = getattr(self.role, "close_frame") - if callable(close_frame): - await self._send_frame(close_frame(self.channel_id, code, reason)) - return - except Exception: - pass - frame = self._make_frame( - MSG_CLOSE, - payload=f'{{"code":{code},"reason":"{reason}"}}'.encode("utf-8"), - ) - await self._send_frame(frame) - - def _powershell_argv(self) -> list: - preferred = self.metadata.get("shell") if isinstance(self.metadata, dict) else None - shell = preferred.strip() if isinstance(preferred, str) and preferred.strip() else "powershell.exe" - # Keep the process alive and read commands from stdin; -Command - tells PS to consume stdin. - return [shell, "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"] - - # ------------------------------------------------------------------ Lifecycle - async def start(self) -> None: - if sys.platform.lower().startswith("win") is False: - self.role._log("reverse_tunnel ps start aborted: non-windows platform", error=True) - await self._send_close(CLOSE_PROTOCOL_ERROR, "windows_only") - return - - argv = self._powershell_argv() - self.role._log(f"reverse_tunnel ps start channel={self.channel_id} argv={' '.join(argv)} mode=pipes") - - # Pipes (no PTY). - try: - self._proc = await asyncio.create_subprocess_exec( - *argv, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, - creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), - ) - except Exception as exc: - self.role._log(f"reverse_tunnel ps channel spawn failed argv={' '.join(argv)}: {exc}", error=True) - await self._send_close(CLOSE_PROTOCOL_ERROR, "spawn_failed") - return - - self._reader_task = self.loop.create_task(self._pump_proc_stdout()) - self._writer_task = self.loop.create_task(self._pump_proc_stdin()) - self.role._log(f"reverse_tunnel ps channel started (pipes) argv={' '.join(argv)}") - - async def on_frame(self, frame) -> None: - if self._closed: - return - if frame.msg_type == MSG_DATA: - if frame.payload: - try: - self._stdin_queue.put_nowait(frame.payload) - except Exception: - await self._stdin_queue.put(frame.payload) - elif frame.msg_type == MSG_CONTROL: - await self._handle_control(frame.payload) - elif frame.msg_type == MSG_CLOSE: - await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close") - elif frame.msg_type == MSG_WINDOW_UPDATE: - # Reserved for back-pressure; ignore for now. - return - - async def _handle_control(self, payload: bytes) -> None: - # No-op for pipe mode; resize is not supported here. - return - - # -------------------- Pipe fallback pumps -------------------- - async def _pump_proc_stdout(self) -> None: - try: - while self._proc and not self._closed: - chunk = await self._proc.stdout.read(4096) - if not chunk: - break - frame = self._make_frame(MSG_DATA, payload=bytes(chunk)) - await self._send_frame(frame) - except asyncio.CancelledError: - pass - except Exception: - self.role._log("reverse_tunnel ps pipe stdout pump error", error=True) - finally: - if self._proc and not self._closed: - try: - self._exit_code = await self._proc.wait() - except Exception: - pass - await self.stop(reason="stdout_closed") - - async def _pump_proc_stdin(self) -> None: - try: - while self._proc and not self._closed: - data = await self._stdin_queue.get() - if self._closed or not self._proc or not self._proc.stdin: - break - try: - self._proc.stdin.write(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8")) - await self._proc.stdin.drain() - except Exception: - self.role._log("reverse_tunnel ps pipe stdin pump error", error=True) - break - except asyncio.CancelledError: - pass - except Exception: - self.role._log("reverse_tunnel ps pipe stdin pump error", error=True) - finally: - await self.stop(reason="stdin_closed") - - async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None: - if self._closed: - return - self._closed = True - if self._proc is not None: - try: - self._proc.terminate() - except Exception: - pass - current = asyncio.current_task() - if self._reader_task and self._reader_task is not current: - try: - self._reader_task.cancel() - except Exception: - pass - if self._writer_task and self._writer_task is not current: - try: - self._writer_task.cancel() - except Exception: - pass - # Include exit code in the close reason for debugging. - exit_suffix = f" (exit={self._exit_code})" if self._exit_code is not None else "" - close_reason = (reason or "powershell_exit") + exit_suffix - # Always send CLOSE before socket teardown so engine/UI see the reason. - try: - await self._send_close(code, close_reason) - except Exception: - self.role._log("reverse_tunnel ps close send failed", error=True) - self.role._log( - f"reverse_tunnel ps channel stopped channel={self.channel_id} reason={close_reason}" - ) diff --git a/Data/Agent/Roles/Reverse_Tunnels/__init__.py b/Data/Agent/Roles/Reverse_Tunnels/__init__.py deleted file mode 100644 index d6e526bb..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Namespace package for reverse tunnel domains (Agent side).""" - -__all__ = ["remote_interactive_shell", "remote_management", "remote_video"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/Bash.py b/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/Bash.py deleted file mode 100644 index 53d51f98..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/Bash.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Placeholder Bash channel (Agent side).""" -from __future__ import annotations - -import asyncio -from typing import Any, Dict, Optional - -MSG_DATA = 0x05 -MSG_CONTROL = 0x09 -MSG_CLOSE = 0x08 -CLOSE_PROTOCOL_ERROR = 3 -CLOSE_AGENT_SHUTDOWN = 6 - - -class BashChannel: - """Stub Bash handler that immediately reports unsupported.""" - - def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]): - self.role = role - self.tunnel = tunnel - self.channel_id = channel_id - self.metadata = metadata or {} - self.loop = getattr(role, "loop", None) or asyncio.get_event_loop() - self._closed = False - - async def start(self) -> None: - # Until Bash support is implemented, close the channel to free resources. - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="bash_unsupported") - - async def on_frame(self, frame) -> None: - if self._closed: - return - if frame.msg_type in (MSG_DATA, MSG_CONTROL): - # Ignore payloads but acknowledge by stopping the channel to avoid leaks. - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="bash_unsupported") - elif frame.msg_type == MSG_CLOSE: - await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close") - - async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None: - if self._closed: - return - self._closed = True - try: - await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "bash_closed")) - except Exception: - pass - self.role._log(f"reverse_tunnel bash channel stopped channel={self.channel_id} reason={reason or 'closed'}") - - -__all__ = ["BashChannel"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/Powershell.py b/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/Powershell.py deleted file mode 100644 index e9806fc3..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/Powershell.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Expose the PowerShell channel under the domain path, with file-based import fallback.""" -from __future__ import annotations - -import importlib.util -from pathlib import Path - -powershell_module = None - -# Attempt package-relative import first -try: # pragma: no cover - best effort - from ....ReverseTunnel import tunnel_Powershell as powershell_module # type: ignore -except Exception: - powershell_module = None - -# Fallback: load directly from file path to survive non-package runtimes -if powershell_module is None: - try: - base = Path(__file__).resolve().parents[3] / "ReverseTunnel" / "tunnel_Powershell.py" - spec = importlib.util.spec_from_file_location("tunnel_Powershell", base) - if spec and spec.loader: - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) # type: ignore - powershell_module = module - except Exception: - powershell_module = None - -if powershell_module and hasattr(powershell_module, "PowershellChannel"): - PowershellChannel = powershell_module.PowershellChannel # type: ignore -else: # pragma: no cover - safety guard - class PowershellChannel: # type: ignore - def __init__(self, *args, **kwargs): - raise ImportError("PowerShell channel unavailable") - - -__all__ = ["PowershellChannel"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/__init__.py b/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/__init__.py deleted file mode 100644 index 41b0a7e0..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/Protocols/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Protocol handlers for interactive shell tunnels (Agent side).""" - -from .Powershell import PowershellChannel -from .Bash import BashChannel - -__all__ = ["PowershellChannel", "BashChannel"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/__init__.py b/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/__init__.py deleted file mode 100644 index 93484210..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Interactive shell domain (PowerShell/Bash) handlers.""" - -__all__ = ["tunnel", "Protocols"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/tunnel.py b/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/tunnel.py deleted file mode 100644 index f9793561..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_interactive_shell/tunnel.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Placeholder module for remote interactive shell tunnel domain (Agent side).""" - -DOMAIN_NAME = "remote-interactive-shell" - -__all__ = ["DOMAIN_NAME"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/SSH.py b/Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/SSH.py deleted file mode 100644 index e3d69d55..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/SSH.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Placeholder SSH channel (Agent side).""" -from __future__ import annotations - -import asyncio -from typing import Any, Dict, Optional - -MSG_DATA = 0x05 -MSG_CONTROL = 0x09 -MSG_CLOSE = 0x08 -CLOSE_PROTOCOL_ERROR = 3 -CLOSE_AGENT_SHUTDOWN = 6 - - -class SSHChannel: - """Stub SSH handler that marks the channel unsupported for now.""" - - def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]): - self.role = role - self.tunnel = tunnel - self.channel_id = channel_id - self.metadata = metadata or {} - self.loop = getattr(role, "loop", None) or asyncio.get_event_loop() - self._closed = False - - async def start(self) -> None: - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="ssh_unsupported") - - async def on_frame(self, frame) -> None: - if self._closed: - return - if frame.msg_type in (MSG_DATA, MSG_CONTROL): - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="ssh_unsupported") - elif frame.msg_type == MSG_CLOSE: - await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close") - - async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None: - if self._closed: - return - self._closed = True - try: - await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "ssh_closed")) - except Exception: - pass - self.role._log(f"reverse_tunnel ssh channel stopped channel={self.channel_id} reason={reason or 'closed'}") - - -__all__ = ["SSHChannel"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/WinRM.py b/Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/WinRM.py deleted file mode 100644 index 3425bddb..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/WinRM.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Placeholder WinRM channel (Agent side).""" -from __future__ import annotations - -import asyncio -from typing import Any, Dict, Optional - -MSG_DATA = 0x05 -MSG_CONTROL = 0x09 -MSG_CLOSE = 0x08 -CLOSE_PROTOCOL_ERROR = 3 -CLOSE_AGENT_SHUTDOWN = 6 - - -class WinRMChannel: - """Stub WinRM handler that marks the channel unsupported for now.""" - - def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]): - self.role = role - self.tunnel = tunnel - self.channel_id = channel_id - self.metadata = metadata or {} - self.loop = getattr(role, "loop", None) or asyncio.get_event_loop() - self._closed = False - - async def start(self) -> None: - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="winrm_unsupported") - - async def on_frame(self, frame) -> None: - if self._closed: - return - if frame.msg_type in (MSG_DATA, MSG_CONTROL): - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="winrm_unsupported") - elif frame.msg_type == MSG_CLOSE: - await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close") - - async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None: - if self._closed: - return - self._closed = True - try: - await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "winrm_closed")) - except Exception: - pass - self.role._log(f"reverse_tunnel winrm channel stopped channel={self.channel_id} reason={reason or 'closed'}") - - -__all__ = ["WinRMChannel"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/__init__.py b/Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/__init__.py deleted file mode 100644 index bcb8b897..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_management/Protocols/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Protocol handlers for remote management tunnels (Agent side).""" - -from .SSH import SSHChannel -from .WinRM import WinRMChannel - -__all__ = ["SSHChannel", "WinRMChannel"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_management/__init__.py b/Data/Agent/Roles/Reverse_Tunnels/remote_management/__init__.py deleted file mode 100644 index 4ffe4c15..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_management/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Remote management domain (SSH/WinRM) handlers.""" - -__all__ = ["tunnel", "Protocols"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_management/tunnel.py b/Data/Agent/Roles/Reverse_Tunnels/remote_management/tunnel.py deleted file mode 100644 index c39826aa..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_management/tunnel.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Placeholder module for remote management domain (Agent side).""" - -DOMAIN_NAME = "remote-management" - -__all__ = ["DOMAIN_NAME"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/RDP.py b/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/RDP.py deleted file mode 100644 index acda4c33..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/RDP.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Placeholder RDP channel (Agent side).""" -from __future__ import annotations - -import asyncio -from typing import Any, Dict, Optional - -MSG_DATA = 0x05 -MSG_CONTROL = 0x09 -MSG_CLOSE = 0x08 -CLOSE_PROTOCOL_ERROR = 3 -CLOSE_AGENT_SHUTDOWN = 6 - - -class RDPChannel: - """Stub RDP handler that marks the channel unsupported for now.""" - - def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]): - self.role = role - self.tunnel = tunnel - self.channel_id = channel_id - self.metadata = metadata or {} - self.loop = getattr(role, "loop", None) or asyncio.get_event_loop() - self._closed = False - - async def start(self) -> None: - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="rdp_unsupported") - - async def on_frame(self, frame) -> None: - if self._closed: - return - if frame.msg_type in (MSG_DATA, MSG_CONTROL): - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="rdp_unsupported") - elif frame.msg_type == MSG_CLOSE: - await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close") - - async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None: - if self._closed: - return - self._closed = True - try: - await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "rdp_closed")) - except Exception: - pass - self.role._log(f"reverse_tunnel rdp channel stopped channel={self.channel_id} reason={reason or 'closed'}") - - -__all__ = ["RDPChannel"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/VNC.py b/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/VNC.py deleted file mode 100644 index 5e37b8d2..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/VNC.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Placeholder VNC channel (Agent side).""" -from __future__ import annotations - -import asyncio -from typing import Any, Dict, Optional - -MSG_DATA = 0x05 -MSG_CONTROL = 0x09 -MSG_CLOSE = 0x08 -CLOSE_PROTOCOL_ERROR = 3 -CLOSE_AGENT_SHUTDOWN = 6 - - -class VNCChannel: - """Stub VNC handler that marks the channel unsupported for now.""" - - def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]): - self.role = role - self.tunnel = tunnel - self.channel_id = channel_id - self.metadata = metadata or {} - self.loop = getattr(role, "loop", None) or asyncio.get_event_loop() - self._closed = False - - async def start(self) -> None: - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="vnc_unsupported") - - async def on_frame(self, frame) -> None: - if self._closed: - return - if frame.msg_type in (MSG_DATA, MSG_CONTROL): - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="vnc_unsupported") - elif frame.msg_type == MSG_CLOSE: - await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close") - - async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None: - if self._closed: - return - self._closed = True - try: - await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "vnc_closed")) - except Exception: - pass - self.role._log(f"reverse_tunnel vnc channel stopped channel={self.channel_id} reason={reason or 'closed'}") - - -__all__ = ["VNCChannel"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/WebRTC.py b/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/WebRTC.py deleted file mode 100644 index 120c2300..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/WebRTC.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Placeholder WebRTC channel (Agent side).""" -from __future__ import annotations - -import asyncio -from typing import Any, Dict, Optional - -MSG_DATA = 0x05 -MSG_CONTROL = 0x09 -MSG_CLOSE = 0x08 -CLOSE_PROTOCOL_ERROR = 3 -CLOSE_AGENT_SHUTDOWN = 6 - - -class WebRTCChannel: - """Stub WebRTC handler that marks the channel unsupported for now.""" - - def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]): - self.role = role - self.tunnel = tunnel - self.channel_id = channel_id - self.metadata = metadata or {} - self.loop = getattr(role, "loop", None) or asyncio.get_event_loop() - self._closed = False - - async def start(self) -> None: - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="webrtc_unsupported") - - async def on_frame(self, frame) -> None: - if self._closed: - return - if frame.msg_type in (MSG_DATA, MSG_CONTROL): - await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="webrtc_unsupported") - elif frame.msg_type == MSG_CLOSE: - await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close") - - async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None: - if self._closed: - return - self._closed = True - try: - await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "webrtc_closed")) - except Exception: - pass - self.role._log(f"reverse_tunnel webrtc channel stopped channel={self.channel_id} reason={reason or 'closed'}") - - -__all__ = ["WebRTCChannel"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/__init__.py b/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/__init__.py deleted file mode 100644 index fdcfa7a0..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_video/Protocols/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Protocol handlers for remote video tunnels (Agent side).""" - -from .WebRTC import WebRTCChannel -from .RDP import RDPChannel -from .VNC import VNCChannel - -__all__ = ["WebRTCChannel", "RDPChannel", "VNCChannel"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_video/__init__.py b/Data/Agent/Roles/Reverse_Tunnels/remote_video/__init__.py deleted file mode 100644 index 9491d3d5..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_video/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Remote video/desktop domain (RDP/VNC/WebRTC) handlers.""" - -__all__ = ["tunnel", "Protocols"] diff --git a/Data/Agent/Roles/Reverse_Tunnels/remote_video/tunnel.py b/Data/Agent/Roles/Reverse_Tunnels/remote_video/tunnel.py deleted file mode 100644 index d491fb7a..00000000 --- a/Data/Agent/Roles/Reverse_Tunnels/remote_video/tunnel.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Placeholder module for remote video domain (Agent side).""" - -DOMAIN_NAME = "remote-video" - -__all__ = ["DOMAIN_NAME"] diff --git a/Data/Agent/Roles/role_ReverseTunnel.py b/Data/Agent/Roles/role_ReverseTunnel.py deleted file mode 100644 index 05c3257b..00000000 --- a/Data/Agent/Roles/role_ReverseTunnel.py +++ /dev/null @@ -1,939 +0,0 @@ -import asyncio -import base64 -import importlib.util -import json -import os -import struct -import time -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Dict, Optional -from urllib.parse import urlparse - -import aiohttp -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import ed25519 - -# Capture import errors for protocol handlers so we can report why they are missing. -PS_IMPORT_ERROR: Optional[str] = None -BASH_IMPORT_ERROR: Optional[str] = None -tunnel_SSH = None -tunnel_WinRM = None -tunnel_VNC = None -tunnel_RDP = None -tunnel_WebRTC = None -tunnel_Powershell = None -tunnel_Bash = None - - -def _load_protocol_module(module_name: str, rel_parts: list[str]) -> tuple[Optional[object], Optional[str]]: - """Load a protocol handler directly from a file path to survive non-package runtimes.""" - - base = Path(__file__).parent - path = base - for part in rel_parts: - path = path / part - if not path.exists(): - return None, f"path_missing:{path}" - try: - spec = importlib.util.spec_from_file_location(module_name, path) - if not spec or not spec.loader: - return None, "spec_failed" - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) # type: ignore - return module, None - except Exception as exc: # pragma: no cover - defensive - return None, repr(exc) - - -try: - from .Reverse_Tunnels.remote_interactive_shell.Protocols import Powershell as tunnel_Powershell # type: ignore -except Exception as exc: # pragma: no cover - best-effort logging only - PS_IMPORT_ERROR = repr(exc) - _module, _err = _load_protocol_module( - "tunnel_Powershell", - ["Reverse_Tunnels", "remote_interactive_shell", "Protocols", "Powershell.py"], - ) - if _module: - tunnel_Powershell = _module - PS_IMPORT_ERROR = None - else: - try: - from .ReverseTunnel import tunnel_Powershell # type: ignore # legacy fallback - PS_IMPORT_ERROR = None - except Exception as exc2: # pragma: no cover - diagnostic only - PS_IMPORT_ERROR = f"{PS_IMPORT_ERROR} | legacy_fallback={exc2!r} | file_load_failed={_err}" - -try: - from .Reverse_Tunnels.remote_interactive_shell.Protocols import Bash as tunnel_Bash # type: ignore -except Exception as exc: # pragma: no cover - best-effort logging only - BASH_IMPORT_ERROR = repr(exc) - _module, _err = _load_protocol_module( - "tunnel_Bash", - ["Reverse_Tunnels", "remote_interactive_shell", "Protocols", "Bash.py"], - ) - if _module: - tunnel_Bash = _module - BASH_IMPORT_ERROR = None - else: - BASH_IMPORT_ERROR = f"{BASH_IMPORT_ERROR} | file_load_failed={_err}" -try: - from .Reverse_Tunnels.remote_management.Protocols import SSH as tunnel_SSH # type: ignore -except Exception: - _module, _err = _load_protocol_module( - "tunnel_SSH", - ["Reverse_Tunnels", "remote_management", "Protocols", "SSH.py"], - ) - tunnel_SSH = _module -try: - from .Reverse_Tunnels.remote_management.Protocols import WinRM as tunnel_WinRM # type: ignore -except Exception: - _module, _err = _load_protocol_module( - "tunnel_WinRM", - ["Reverse_Tunnels", "remote_management", "Protocols", "WinRM.py"], - ) - tunnel_WinRM = _module -try: - from .Reverse_Tunnels.remote_video.Protocols import VNC as tunnel_VNC # type: ignore -except Exception: - _module, _err = _load_protocol_module( - "tunnel_VNC", - ["Reverse_Tunnels", "remote_video", "Protocols", "VNC.py"], - ) - tunnel_VNC = _module -try: - from .Reverse_Tunnels.remote_video.Protocols import RDP as tunnel_RDP # type: ignore -except Exception: - _module, _err = _load_protocol_module( - "tunnel_RDP", - ["Reverse_Tunnels", "remote_video", "Protocols", "RDP.py"], - ) - tunnel_RDP = _module -try: - from .Reverse_Tunnels.remote_video.Protocols import WebRTC as tunnel_WebRTC # type: ignore -except Exception: - _module, _err = _load_protocol_module( - "tunnel_WebRTC", - ["Reverse_Tunnels", "remote_video", "Protocols", "WebRTC.py"], - ) - tunnel_WebRTC = _module - -ROLE_NAME = "reverse_tunnel" -ROLE_CONTEXTS = ["interactive", "system"] - -# Message types (keep in sync with Engine service) -MSG_CONNECT = 0x01 -MSG_CONNECT_ACK = 0x02 -MSG_CHANNEL_OPEN = 0x03 -MSG_CHANNEL_ACK = 0x04 -MSG_DATA = 0x05 -MSG_WINDOW_UPDATE = 0x06 -MSG_HEARTBEAT = 0x07 -MSG_CLOSE = 0x08 -MSG_CONTROL = 0x09 - -# Close codes -CLOSE_OK = 0 -CLOSE_IDLE_TIMEOUT = 1 -CLOSE_GRACE_EXPIRED = 2 -CLOSE_PROTOCOL_ERROR = 3 -CLOSE_AUTH_FAILED = 4 -CLOSE_SERVER_SHUTDOWN = 5 -CLOSE_AGENT_SHUTDOWN = 6 -CLOSE_DOMAIN_LIMIT = 7 -CLOSE_UNEXPECTED_DISCONNECT = 8 - -FRAME_VERSION = 1 -FRAME_HEADER_STRUCT = struct.Struct(" bytes: - payload = self.payload or b"" - header = FRAME_HEADER_STRUCT.pack( - self.version, - self.msg_type, - self.flags, - self.reserved, - int(self.channel_id), - len(payload), - ) - return header + payload - - -def decode_frame(raw: bytes) -> TunnelFrame: - if len(raw) < FRAME_HEADER_STRUCT.size: - raise ValueError("frame_too_small") - version, msg_type, flags, reserved, channel_id, length = FRAME_HEADER_STRUCT.unpack_from(raw, 0) - if version != FRAME_VERSION: - raise ValueError(f"unsupported_version:{version}") - if length < 0 or len(raw) < FRAME_HEADER_STRUCT.size + length: - raise ValueError("invalid_length") - payload = raw[FRAME_HEADER_STRUCT.size : FRAME_HEADER_STRUCT.size + length] - if len(payload) != length: - raise ValueError("length_mismatch") - return TunnelFrame(msg_type=msg_type, channel_id=channel_id, payload=payload, flags=flags, version=version, reserved=reserved) - - -def heartbeat_frame(channel_id: int = 0, *, is_ack: bool = False) -> TunnelFrame: - flags = 0x1 if is_ack else 0x0 - return TunnelFrame(msg_type=MSG_HEARTBEAT, channel_id=channel_id, flags=flags, payload=b"") - - -def close_frame(channel_id: int, code: int, reason: str = "") -> TunnelFrame: - payload = json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8") - return TunnelFrame(msg_type=MSG_CLOSE, channel_id=channel_id, payload=payload) - - -def _norm_text(value: Any) -> str: - if value is None: - return "" - try: - return str(value).strip() - except Exception: - return "" - - -def _is_literal_ip(value: str) -> bool: - try: - import ipaddress - - ipaddress.ip_address(value.strip().strip("[]")) - return True - except Exception: - return False - - -@dataclass -class ActiveTunnel: - tunnel_id: str - domain: str - protocol: str - port: int - token: str - url: str - heartbeat_seconds: int - idle_seconds: int - grace_seconds: int - expires_at: Optional[int] - signing_key_hint: Optional[str] = None - session: Optional[aiohttp.ClientSession] = None - websocket: Optional[aiohttp.ClientWebSocketResponse] = None - tasks: list = field(default_factory=list) - send_queue: asyncio.Queue = field(default_factory=asyncio.Queue) - channels: Dict[int, Any] = field(default_factory=dict) - last_activity: float = field(default_factory=lambda: time.time()) - connected: bool = False - stopping: bool = False - stop_reason: Optional[str] = None - stop_origin: Optional[str] = None - - -class BaseChannel: - """Placeholder channel handler; protocol-specific handlers plug in later.""" - - def __init__(self, role: "Role", tunnel: ActiveTunnel, channel_id: int, metadata: Optional[dict]): - self.role = role - self.tunnel = tunnel - self.channel_id = channel_id - self.metadata = metadata or {} - - async def start(self) -> None: - # Nothing to prime for placeholder channels. - return - - async def on_frame(self, frame: TunnelFrame) -> None: - # Drop frames until protocol module is provided. - return - - async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None: - await self.role._send_frame(self.tunnel, close_frame(self.channel_id, code, reason)) - - -class Role: - def __init__(self, ctx): - self.ctx = ctx - self.sio = ctx.sio - self.loop = ctx.loop or asyncio.get_event_loop() - self.hooks = ctx.hooks or {} - self._http_client_factory = self.hooks.get("http_client") - self._log_hook = self.hooks.get("log_agent") - self._active: Dict[str, ActiveTunnel] = {} - self._domain_claims: Dict[str, str] = {} - self._domain_limits: Dict[str, Optional[int]] = { - "remote-interactive-shell": 2, - "remote-management": 1, - "remote-video": 2, - # Legacy / protocol fallbacks - "ps": 2, - "rdp": 1, - "vnc": 1, - "webrtc": 2, - "ssh": None, - "winrm": None, - } - self._default_heartbeat = 20 - self._protocol_handlers: Dict[str, Any] = {} - self._frame_cls = TunnelFrame - self.close_frame = close_frame - try: - if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"): - self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel - module_path = getattr(tunnel_Powershell, "__file__", None) - self._log(f"reverse_tunnel ps handler registered (PowershellChannel) module={module_path}") - else: - hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else "" - module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py" - exists_hint = f" exists={module_path.exists()}" - self._log( - f"reverse_tunnel ps handler NOT registered (missing module/class){hint}{exists_hint}", - error=True, - ) - except Exception as exc: - self._log(f"reverse_tunnel ps handler registration failed: {exc}", error=True) - try: - if tunnel_Bash and hasattr(tunnel_Bash, "BashChannel"): - self._protocol_handlers["bash"] = tunnel_Bash.BashChannel - module_path = getattr(tunnel_Bash, "__file__", None) - self._log(f"reverse_tunnel bash handler registered (BashChannel) module={module_path}") - elif BASH_IMPORT_ERROR: - self._log(f"reverse_tunnel bash handler NOT registered (missing module/class) import_error={BASH_IMPORT_ERROR}", error=True) - except Exception as exc: - self._log(f"reverse_tunnel bash handler registration failed: {exc}", error=True) - try: - if tunnel_SSH and hasattr(tunnel_SSH, "SSHChannel"): - self._protocol_handlers["ssh"] = tunnel_SSH.SSHChannel - if tunnel_WinRM and hasattr(tunnel_WinRM, "WinRMChannel"): - self._protocol_handlers["winrm"] = tunnel_WinRM.WinRMChannel - if tunnel_VNC and hasattr(tunnel_VNC, "VNCChannel"): - self._protocol_handlers["vnc"] = tunnel_VNC.VNCChannel - if tunnel_RDP and hasattr(tunnel_RDP, "RDPChannel"): - self._protocol_handlers["rdp"] = tunnel_RDP.RDPChannel - if tunnel_WebRTC and hasattr(tunnel_WebRTC, "WebRTCChannel"): - self._protocol_handlers["webrtc"] = tunnel_WebRTC.WebRTCChannel - except Exception as exc: - self._log(f"reverse_tunnel protocol handler registration failed: {exc}", error=True) - - # ------------------------------------------------------------------ Logging - def _log(self, message: str, *, error: bool = False) -> None: - fname = "reverse_tunnel.log" - try: - if callable(self._log_hook): - self._log_hook(message, fname=fname) - if error: - self._log_hook(message, fname="agent.error.log") - except Exception: - pass - - # ------------------------------------------------------------------ Event wiring - def register_events(self): - @self.sio.on("reverse_tunnel_start") - async def _reverse_tunnel_start(payload): - await self._handle_tunnel_start(payload) - - @self.sio.on("reverse_tunnel_stop") - async def _reverse_tunnel_stop(payload): - tid = "" - if isinstance(payload, dict): - tid = _norm_text(payload.get("tunnel_id")) - await self._stop_tunnel(tid, code=CLOSE_AGENT_SHUTDOWN, reason="server_stop") - - # ------------------------------------------------------------------ Token helpers - def _decode_token_payload(self, token: str, *, signing_key_hint: Optional[str] = None) -> Dict[str, Any]: - if not token: - raise ValueError("token_missing") - - def _b64decode(segment: str) -> bytes: - padding = "=" * (-len(segment) % 4) - return base64.urlsafe_b64decode(segment + padding) - - parts = token.split(".") - payload_segment = parts[0] - payload_bytes = _b64decode(payload_segment) - try: - payload = json.loads(payload_bytes.decode("utf-8")) - except Exception as exc: - raise ValueError("token_decode_error") from exc - - if len(parts) == 2: - candidates = [] - hint = _norm_text(signing_key_hint) - if hint: - candidates.append(hint) - client = self._http_client() - if client and hasattr(client, "load_server_signing_key"): - try: - stored = client.load_server_signing_key() - except Exception: - stored = None - if isinstance(stored, str) and stored.strip(): - candidates.append(stored.strip()) - - signature = _b64decode(parts[1]) - verified = False - for candidate in candidates: - try: - key_bytes = base64.b64decode(candidate, validate=True) - public_key = serialization.load_der_public_key(key_bytes) - except Exception: - continue - if not isinstance(public_key, ed25519.Ed25519PublicKey): - continue - try: - public_key.verify(signature, payload_bytes) - verified = True - if client and hasattr(client, "store_server_signing_key"): - try: - client.store_server_signing_key(candidate) - except Exception: - pass - break - except Exception: - continue - if not verified: - raise ValueError("token_signature_invalid") - return payload - - def _validate_token(self, token: str, *, expected_agent: str, expected_domain: str, expected_protocol: str, expected_tunnel: str, signing_key_hint: Optional[str]) -> Dict[str, Any]: - payload = self._decode_token_payload(token, signing_key_hint=signing_key_hint) - - def _matches(expected: str, actual: Any) -> bool: - return _norm_text(expected).lower() == _norm_text(actual).lower() - - if expected_agent and not _matches(expected_agent, payload.get("agent_id")): - raise ValueError("token_agent_mismatch") - if expected_tunnel and not _matches(expected_tunnel, payload.get("tunnel_id")): - raise ValueError("token_id_mismatch") - if expected_domain and not _matches(expected_domain, payload.get("domain")): - raise ValueError("token_domain_mismatch") - if expected_protocol and not _matches(expected_protocol, payload.get("protocol")): - raise ValueError("token_protocol_mismatch") - - expires_at = payload.get("expires_at") - try: - expires_ts = int(expires_at) if expires_at is not None else None - except Exception: - expires_ts = None - if expires_ts is not None and expires_ts < int(time.time()): - raise ValueError("token_expired") - return payload - - # ------------------------------------------------------------------ Utility - def _http_client(self): - try: - if callable(self._http_client_factory): - return self._http_client_factory() - except Exception: - return None - return None - - def _port_allowed(self, port: int) -> bool: - return isinstance(port, int) and 1 <= port <= 65535 - - def _domain_allowed(self, domain: str) -> bool: - limit = self._domain_limits.get(domain) - if limit is None: - return True - active = [tid for tid, t in self._active.items() if t.domain == domain and not t.stopping] - pending = [tid for dom, tid in self._domain_claims.items() if dom == domain and tid not in active] - count = len(set(active + pending)) - return count < limit - - def _build_ws_url(self, port: int) -> str: - client = self._http_client() - if client: - try: - client.refresh_base_url() - except Exception: - pass - base_url = getattr(client, "base_url", "") or "https://localhost:5000" - else: - base_url = "https://localhost:5000" - parsed = urlparse(base_url) - host = parsed.hostname or "localhost" - scheme = "wss" if (parsed.scheme or "").lower() == "https" else "ws" - return f"{scheme}://{host}:{port}/" - - def _ssl_context(self, host: str): - client = self._http_client() - verify = getattr(getattr(client, "session", None), "verify", True) - if verify is False: - return False - if isinstance(verify, str) and os.path.isfile(verify) and client and hasattr(client, "key_store"): - try: - ctx = client.key_store.build_ssl_context() - if ctx and _is_literal_ip(host): - ctx.check_hostname = False - return ctx - except Exception: - return None - return None - - async def _emit_status(self, payload: Dict[str, Any]) -> None: - try: - await self.sio.emit("reverse_tunnel_status", payload) - except Exception: - pass - - def _mark_activity(self, tunnel: ActiveTunnel) -> None: - tunnel.last_activity = time.time() - - # ------------------------------------------------------------------ Event handlers - async def _handle_tunnel_start(self, payload: Any) -> None: - if not isinstance(payload, dict): - self._log("reverse_tunnel_start ignored: payload not a dict", error=True) - return - tunnel_id = _norm_text(payload.get("tunnel_id")) - token = _norm_text(payload.get("token")) - port = payload.get("port") or payload.get("assigned_port") - protocol = _norm_text(payload.get("protocol") or "ps").lower() or "ps" - domain = _norm_text(payload.get("domain") or protocol).lower() or protocol - heartbeat_seconds = int(payload.get("heartbeat_seconds") or self._default_heartbeat or 20) - idle_seconds = int(payload.get("idle_seconds") or 3600) - grace_seconds = int(payload.get("grace_seconds") or 3600) - signing_key_hint = _norm_text(payload.get("signing_key")) - agent_hint = _norm_text(payload.get("agent_id")) - - # Ignore broadcasts targeting other agents (Socket.IO fanout sends to both contexts). - if agent_hint and agent_hint.lower() != _norm_text(self.ctx.agent_id).lower(): - return - - if not token: - self._log("reverse_tunnel_start rejected: missing token", error=True) - await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_missing"}) - return - - if not tunnel_id: - tunnel_id = _norm_text(payload.get("lease_id")) # fallback alias - - try: - claims = self._validate_token( - token, - expected_agent=self.ctx.agent_id, - expected_domain=domain, - expected_protocol=protocol, - expected_tunnel=tunnel_id, - signing_key_hint=signing_key_hint, - ) - except Exception as exc: - if str(exc) == "token_agent_mismatch": - # Broadcast hit the wrong agent context; ignore quietly. - return - self._log(f"reverse_tunnel_start rejected: token validation failed ({exc})", error=True) - await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_invalid"}) - return - - tunnel_id = _norm_text(claims.get("tunnel_id") or tunnel_id) - if not tunnel_id: - self._log("reverse_tunnel_start rejected: tunnel_id missing after token parse", error=True) - return - - try: - port = int(port or claims.get("assigned_port") or 0) - except Exception: - port = 0 - if not self._port_allowed(port): - self._log(f"reverse_tunnel_start rejected: invalid port {port}", error=True) - await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "invalid_port"}) - return - - domain = _norm_text(claims.get("domain") or domain).lower() - protocol = _norm_text(claims.get("protocol") or protocol).lower() - expires_at = claims.get("expires_at") - - if tunnel_id in self._active: - self._log(f"reverse_tunnel_start ignored: tunnel already active tunnel_id={tunnel_id}") - return - if not self._domain_allowed(domain): - self._log(f"reverse_tunnel_start rejected: domain limit for {domain}", error=True) - await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "domain_limit"}) - return - - url = self._build_ws_url(port) - parsed = urlparse(url) - heartbeat_seconds = max(5, min(heartbeat_seconds, 120)) - idle_seconds = max(30, idle_seconds) - grace_seconds = max(60, grace_seconds) - - tunnel = ActiveTunnel( - tunnel_id=tunnel_id, - domain=domain, - protocol=protocol, - port=port, - token=token, - url=url, - heartbeat_seconds=heartbeat_seconds, - idle_seconds=idle_seconds, - grace_seconds=grace_seconds, - expires_at=int(expires_at) if expires_at is not None else None, - signing_key_hint=signing_key_hint or None, - ) - - self._active[tunnel_id] = tunnel - self._domain_claims[domain] = tunnel_id - self._log(f"reverse_tunnel_start accepted tunnel_id={tunnel_id} domain={domain} protocol={protocol} url={url}") - await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "connecting", "url": url}) - task = self.loop.create_task(self._run_tunnel(tunnel, host=parsed.hostname or "localhost")) - tunnel.tasks.append(task) - - # ------------------------------------------------------------------ Core tunnel handling - async def _run_tunnel(self, tunnel: ActiveTunnel, *, host: str) -> None: - ssl_ctx = self._ssl_context(host) - timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None) - self._log( - f"reverse_tunnel dialing ws url={tunnel.url} tunnel_id={tunnel.tunnel_id} " - f"agent_id={self.ctx.agent_id} ssl={'yes' if ssl_ctx else 'no'}" - ) - try: - tunnel.session = aiohttp.ClientSession(timeout=timeout) - tunnel.websocket = await tunnel.session.ws_connect( - tunnel.url, - ssl=ssl_ctx, - heartbeat=None, - max_msg_size=0, - timeout=timeout, - ) - self._mark_activity(tunnel) - self._log(f"reverse_tunnel connected ws tunnel_id={tunnel.tunnel_id} peer={getattr(tunnel.websocket, 'remote', None)}") - await tunnel.websocket.send_bytes( - TunnelFrame( - msg_type=MSG_CONNECT, - channel_id=0, - payload=json.dumps( - { - "agent_id": self.ctx.agent_id, - "tunnel_id": tunnel.tunnel_id, - "token": tunnel.token, - "protocol": tunnel.protocol, - "domain": tunnel.domain, - "version": FRAME_VERSION, - }, - separators=(",", ":"), - ).encode("utf-8"), - ).encode() - ) - self._log(f"reverse_tunnel CONNECT sent tunnel_id={tunnel.tunnel_id}") - - sender = self.loop.create_task(self._pump_sender(tunnel)) - receiver = self.loop.create_task(self._pump_receiver(tunnel)) - heartbeats = self.loop.create_task(self._heartbeat_loop(tunnel)) - watchdog = self.loop.create_task(self._watchdog(tunnel)) - tunnel.tasks.extend([sender, receiver, heartbeats, watchdog]) - task_labels = { - sender: "sender", - receiver: "receiver", - heartbeats: "heartbeat", - watchdog: "watchdog", - } - done, pending = await asyncio.wait(task_labels.keys(), return_when=asyncio.FIRST_COMPLETED) - for finished in done: - label = task_labels.get(finished) or "unknown" - exc_text = "" - try: - exc_obj = finished.exception() - except asyncio.CancelledError: - exc_obj = None - exc_text = " (cancelled)" - except Exception as exc: # pragma: no cover - defensive logging - exc_obj = exc - if exc_obj: - exc_text = f" (exc={exc_obj!r})" - if not tunnel.stop_reason: - tunnel.stop_reason = f"{label}_stopped{exc_text}" - if not tunnel.stop_origin: - tunnel.stop_origin = label - self._log( - f"reverse_tunnel task completed tunnel_id={tunnel.tunnel_id} task={label} stop_reason={tunnel.stop_reason}{exc_text}" - ) - if pending: - try: - self._log( - "reverse_tunnel pending tasks after first completion tunnel_id=%s pending=%s", - # Represent pending tasks by label for debugging. - ) - except Exception: - pass - except Exception as exc: - self._log(f"reverse_tunnel connection failed tunnel_id={tunnel.tunnel_id}: {exc}", error=True) - await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "connect_failed"}) - finally: - try: - ws = tunnel.websocket - self._log( - f"reverse_tunnel ws closing tunnel_id={tunnel.tunnel_id} " - f"code={getattr(ws, 'close_code', None)} reason={getattr(ws, 'close_reason', None)}" - ) - except Exception: - pass - await self._shutdown_tunnel(tunnel) - - async def _pump_sender(self, tunnel: ActiveTunnel) -> None: - try: - while tunnel.websocket and not tunnel.websocket.closed: - frame: TunnelFrame = await tunnel.send_queue.get() - try: - await tunnel.websocket.send_bytes(frame.encode()) - self._mark_activity(tunnel) - self._log( - f"reverse_tunnel send frame tunnel_id={tunnel.tunnel_id} " - f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}" - ) - except Exception: - if not tunnel.stop_reason: - tunnel.stop_reason = "sender_error" - break - except asyncio.CancelledError: - pass - except Exception: - if not tunnel.stop_reason: - tunnel.stop_reason = "sender_failed" - self._log(f"reverse_tunnel sender failed tunnel_id={tunnel.tunnel_id}", error=True) - finally: - self._log(f"reverse_tunnel sender stopped tunnel_id={tunnel.tunnel_id}") - - async def _pump_receiver(self, tunnel: ActiveTunnel) -> None: - ws = tunnel.websocket - if ws is None: - return - try: - async for msg in ws: - if msg.type == aiohttp.WSMsgType.BINARY: - try: - frame = decode_frame(msg.data) - except Exception: - self._log(f"reverse_tunnel frame decode failed tunnel_id={tunnel.tunnel_id}", error=True) - continue - self._mark_activity(tunnel) - await self._handle_frame(tunnel, frame) - elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE): - self._log( - f"reverse_tunnel websocket closed tunnel_id={tunnel.tunnel_id} " - f"code={ws.close_code} reason={ws.close_reason}" - ) - tunnel.stop_reason = ws.close_reason or "ws_closed" - break - except asyncio.CancelledError: - pass - except Exception: - if not tunnel.stop_reason: - tunnel.stop_reason = "receiver_failed" - self._log(f"reverse_tunnel receiver failed tunnel_id={tunnel.tunnel_id}", error=True) - finally: - self._log(f"reverse_tunnel receiver stopped tunnel_id={tunnel.tunnel_id}") - # If no stop_reason was set, emit a CLOSE so engine/UI see a reason. - if not tunnel.stop_reason: - try: - await self._send_frame(tunnel, close_frame(0, CLOSE_UNEXPECTED_DISCONNECT, "receiver_stop")) - tunnel.stop_reason = "receiver_stop" - except Exception: - pass - - async def _heartbeat_loop(self, tunnel: ActiveTunnel) -> None: - try: - while tunnel.websocket and not tunnel.websocket.closed: - await asyncio.sleep(tunnel.heartbeat_seconds) - await self._send_frame(tunnel, heartbeat_frame()) - self._log(f"reverse_tunnel heartbeat sent tunnel_id={tunnel.tunnel_id}") - except asyncio.CancelledError: - pass - except Exception: - if not tunnel.stop_reason: - tunnel.stop_reason = "heartbeat_failed" - self._log(f"reverse_tunnel heartbeat failed tunnel_id={tunnel.tunnel_id}", error=True) - finally: - self._log(f"reverse_tunnel heartbeat loop stopped tunnel_id={tunnel.tunnel_id}") - - async def _watchdog(self, tunnel: ActiveTunnel) -> None: - try: - while tunnel.websocket and not tunnel.websocket.closed: - await asyncio.sleep(10) - now = time.time() - if tunnel.idle_seconds and (now - tunnel.last_activity) >= tunnel.idle_seconds: - await self._send_frame(tunnel, close_frame(0, CLOSE_IDLE_TIMEOUT, "idle_timeout")) - tunnel.stop_reason = tunnel.stop_reason or "idle_timeout" - self._log(f"reverse_tunnel watchdog idle_timeout tunnel_id={tunnel.tunnel_id}") - break - if tunnel.expires_at and (now - tunnel.expires_at) >= tunnel.grace_seconds: - await self._send_frame(tunnel, close_frame(0, CLOSE_GRACE_EXPIRED, "grace_expired")) - tunnel.stop_reason = tunnel.stop_reason or "grace_expired" - self._log(f"reverse_tunnel watchdog grace_expired tunnel_id={tunnel.tunnel_id}") - break - except asyncio.CancelledError: - pass - except Exception: - if not tunnel.stop_reason: - tunnel.stop_reason = "watchdog_failed" - self._log(f"reverse_tunnel watchdog failed tunnel_id={tunnel.tunnel_id}", error=True) - finally: - self._log(f"reverse_tunnel watchdog stopped tunnel_id={tunnel.tunnel_id}") - - async def _handle_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None: - self._log( - f"reverse_tunnel recv frame tunnel_id={tunnel.tunnel_id} " - f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}" - ) - if frame.msg_type == MSG_HEARTBEAT: - if frame.flags & 0x1: - self._log(f"reverse_tunnel heartbeat ack tunnel_id={tunnel.tunnel_id}") - return - await self._send_frame(tunnel, heartbeat_frame(channel_id=frame.channel_id, is_ack=True)) - return - if frame.msg_type == MSG_CONNECT_ACK: - tunnel.connected = True - await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "connected"}) - self._log(f"reverse_tunnel CONNECT_ACK tunnel_id={tunnel.tunnel_id}") - return - if frame.msg_type == MSG_CHANNEL_OPEN: - await self._handle_channel_open(tunnel, frame) - return - if frame.msg_type == MSG_CLOSE: - try: - reason = json.loads(frame.payload.decode("utf-8")) - except Exception: - reason = {"code": CLOSE_UNEXPECTED_DISCONNECT, "reason": "close"} - tunnel.stop_reason = reason.get("reason") if isinstance(reason, dict) else None - await self._shutdown_tunnel(tunnel, send_close=False) - return - - handler = tunnel.channels.get(frame.channel_id) - if handler: - try: - await handler.on_frame(frame) - except Exception: - self._log(f"reverse_tunnel channel handler failed tunnel_id={tunnel.tunnel_id} channel={frame.channel_id}", error=True) - else: - if frame.msg_type in (MSG_DATA, MSG_CONTROL, MSG_WINDOW_UPDATE): - await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "unknown_channel")) - - async def _handle_channel_open(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None: - try: - payload = json.loads(frame.payload.decode("utf-8")) - except Exception: - payload = {} - protocol = _norm_text(payload.get("protocol") or tunnel.protocol) - metadata = payload.get("metadata") if isinstance(payload, dict) else {} - - if frame.channel_id in tunnel.channels: - await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "channel_exists")) - return - - if protocol.lower() == "ps" and "ps" not in self._protocol_handlers: - hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else "" - module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py" - exists_hint = f" exists={module_path.exists()}" - self._log( - f"reverse_tunnel ps handler missing; falling back to BaseChannel{hint}{exists_hint}", - error=True, - ) - - handler_cls = self._protocol_handlers.get(protocol.lower()) or BaseChannel - try: - handler = handler_cls(self, tunnel, frame.channel_id, metadata) - except Exception: - self._log(f"reverse_tunnel channel handler fallback to BaseChannel protocol={protocol}", error=True) - handler = BaseChannel(self, tunnel, frame.channel_id, metadata) - tunnel.channels[frame.channel_id] = handler - await handler.start() - await self._send_frame( - tunnel, - TunnelFrame( - msg_type=MSG_CHANNEL_ACK, - channel_id=frame.channel_id, - payload=json.dumps({"status": "ok", "protocol": protocol}, separators=(",", ":")).encode("utf-8"), - ), - ) - self._log( - f"reverse_tunnel channel_opened tunnel_id={tunnel.tunnel_id} channel={frame.channel_id} " - f"protocol={protocol} handler={handler.__class__.__name__} metadata={metadata}" - ) - - async def _send_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None: - if tunnel.stopping and getattr(frame, "msg_type", None) != MSG_CLOSE: - return - try: - tunnel.send_queue.put_nowait(frame) - except Exception: - await tunnel.send_queue.put(frame) - - async def _stop_tunnel(self, tunnel_id: str, *, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "requested") -> None: - tunnel = self._active.get(tunnel_id) - if not tunnel: - return - if not tunnel.stop_origin: - tunnel.stop_origin = "stop_tunnel" - self._log(f"reverse_tunnel stop_tunnel requested tunnel_id={tunnel_id} code={code} reason={reason}") - if not tunnel.stop_reason: - tunnel.stop_reason = reason or "requested" - await self._send_frame(tunnel, close_frame(0, code, reason)) - await self._shutdown_tunnel(tunnel, send_close=False) - - async def _shutdown_tunnel(self, tunnel: ActiveTunnel, *, send_close: bool = True) -> None: - if tunnel.stopping: - return - tunnel.stopping = True - reason_text = tunnel.stop_reason or "closed" - if not tunnel.stop_reason: - tunnel.stop_reason = reason_text - if not tunnel.stop_origin: - tunnel.stop_origin = "shutdown" - self._log( - f"reverse_tunnel shutdown start tunnel_id={tunnel.tunnel_id} stop_reason={tunnel.stop_reason} " - f"stop_origin={tunnel.stop_origin} ws_closed={getattr(tunnel.websocket, 'closed', None)}" - ) - # Stop all channels first so CLOSE frames (with reasons) are sent upstream. - for handler in list(tunnel.channels.values()): - try: - await handler.stop(code=CLOSE_UNEXPECTED_DISCONNECT, reason=reason_text or "tunnel_shutdown") - except Exception: - pass - if send_close: - close_payload = close_frame(0, CLOSE_AGENT_SHUTDOWN, reason_text or "agent_shutdown") - try: - await self._send_frame(tunnel, close_payload) - # Give the sender loop a brief window to flush the CLOSE upstream. - await asyncio.sleep(0.05) - except Exception: - pass - # Fallback: if sender task died, try sending directly on the websocket. - try: - if tunnel.websocket and not tunnel.websocket.closed: - await tunnel.websocket.send_bytes(close_payload.encode()) - except Exception: - pass - for task in list(tunnel.tasks): - try: - task.cancel() - except Exception: - pass - if tunnel.websocket is not None: - try: - message = (reason_text or "agent_shutdown").encode("utf-8", "ignore")[:120] - await tunnel.websocket.close(message=message) - except Exception: - pass - if tunnel.session is not None: - try: - await tunnel.session.close() - except Exception: - pass - self._active.pop(tunnel.tunnel_id, None) - if tunnel.domain in self._domain_claims and self._domain_claims.get(tunnel.domain) == tunnel.tunnel_id: - self._domain_claims.pop(tunnel.domain, None) - self._log(f"reverse_tunnel stopped tunnel_id={tunnel.tunnel_id} reason={tunnel.stop_reason or 'closed'}") - await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "closed", "reason": tunnel.stop_reason or "closed"}) - - # ------------------------------------------------------------------ Lifecycle - def stop_all(self): - for tunnel_id in list(self._active.keys()): - try: - self.loop.create_task(self._stop_tunnel(tunnel_id, code=CLOSE_AGENT_SHUTDOWN, reason="agent_shutdown")) - except Exception: - pass diff --git a/Data/Agent/Roles/role_VpnShell.py b/Data/Agent/Roles/role_VpnShell.py new file mode 100644 index 00000000..47b01d01 --- /dev/null +++ b/Data/Agent/Roles/role_VpnShell.py @@ -0,0 +1,167 @@ +# ====================================================== +# Data\Agent\Roles\role_VpnShell.py +# Description: PowerShell TCP server for VPN shell access (Engine connects over WireGuard /32). +# +# API Endpoints (if applicable): None +# ====================================================== + +"""VPN PowerShell server for the WireGuard tunnel.""" + +from __future__ import annotations + +import base64 +import json +import socket +import subprocess +import threading +import time +from pathlib import Path +from typing import Any, Optional +import os + +ROLE_NAME = "VpnShell" +ROLE_CONTEXTS = ["system"] + + +def _log_path() -> Path: + root = Path(__file__).resolve().parents[2] / "Logs" + root.mkdir(parents=True, exist_ok=True) + return root / "reverse_tunnel.log" + + +def _write_log(message: str) -> None: + ts = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime()) + try: + _log_path().open("a", encoding="utf-8").write(f"[{ts}] [vpn-shell] {message}\n") + except Exception: + pass + + +def _b64encode(data: bytes) -> str: + return base64.b64encode(data).decode("ascii").strip() + + +def _b64decode(value: str) -> bytes: + return base64.b64decode(value.encode("ascii")) + +def _resolve_shell_port() -> int: + raw = os.environ.get("BOREALIS_WIREGUARD_SHELL_PORT") + try: + value = int(raw) if raw is not None else 47001 + except Exception: + value = 47001 + if value < 1 or value > 65535: + return 47001 + return value + + +class ShellSession: + def __init__(self, conn: socket.socket, address: tuple[str, int]) -> None: + self.conn = conn + self.address = address + self.proc: Optional[subprocess.Popen] = None + self._stop = threading.Event() + + def start(self) -> None: + self.proc = subprocess.Popen( + ["powershell.exe", "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), + bufsize=0, + ) + threading.Thread(target=self._reader_loop, daemon=True).start() + self._writer_loop() + + def _reader_loop(self) -> None: + if not self.proc or not self.proc.stdout: + return + try: + while not self._stop.is_set(): + chunk = self.proc.stdout.readline() + if not chunk: + break + payload = json.dumps({"type": "stdout", "data": _b64encode(chunk)}) + self.conn.sendall(payload.encode("utf-8") + b"\n") + except Exception: + pass + + def _writer_loop(self) -> None: + buffer = b"" + try: + while not self._stop.is_set(): + data = self.conn.recv(4096) + if not data: + break + buffer += data + while b"\n" in buffer: + line, buffer = buffer.split(b"\n", 1) + if not line: + continue + try: + msg = json.loads(line.decode("utf-8")) + except Exception: + continue + if msg.get("type") == "stdin": + payload = msg.get("data") or "" + if self.proc and self.proc.stdin: + try: + self.proc.stdin.write(_b64decode(str(payload))) + self.proc.stdin.flush() + except Exception: + pass + if msg.get("type") == "close": + self._stop.set() + break + finally: + self.close() + + def close(self) -> None: + self._stop.set() + try: + self.conn.close() + except Exception: + pass + if self.proc: + try: + self.proc.terminate() + except Exception: + pass + + +class ShellServer: + def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None) -> None: + self.host = host + self.port = port or _resolve_shell_port() + self._thread = threading.Thread(target=self._serve, daemon=True) + self._thread.start() + _write_log(f"VPN shell server listening on {self.host}:{self.port}") + + def _serve(self) -> None: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server: + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind((self.host, self.port)) + server.listen(5) + while True: + conn, addr = server.accept() + remote_ip = addr[0] + if not remote_ip.startswith("10.255."): + _write_log(f"Rejected shell connection from {remote_ip}") + conn.close() + continue + _write_log(f"Accepted shell connection from {remote_ip}") + session = ShellSession(conn, addr) + threading.Thread(target=session.start, daemon=True).start() + + +class Role: + def __init__(self, ctx) -> None: + self.ctx = ctx + self.server = ShellServer() + + def register_events(self) -> None: + return + + def stop_all(self) -> None: + return diff --git a/Data/Agent/Roles/role_WireGuardTunnel.py b/Data/Agent/Roles/role_WireGuardTunnel.py index 47f2cb6a..c1f690d0 100644 --- a/Data/Agent/Roles/role_WireGuardTunnel.py +++ b/Data/Agent/Roles/role_WireGuardTunnel.py @@ -9,13 +9,15 @@ This role prepares the WireGuard client config, manages a single active session, enforces idle teardown, and logs lifecycle events to -Agent/Logs/reverse_tunnel.log. It does not yet bind to engine signals; higher -layers should call start_session/stop_session with the issued config/token. +Agent/Logs/reverse_tunnel.log. It binds to Engine Socket.IO events +(`vpn_tunnel_start`, `vpn_tunnel_stop`, `vpn_tunnel_activity`) to start/stop +the client session with the issued config/token. """ from __future__ import annotations import base64 +import json import os import subprocess import threading @@ -26,6 +28,7 @@ from typing import Any, Dict, Optional from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import x25519 +from signature_utils import verify_and_store_script_signature ROLE_NAME = "WireGuardTunnel" ROLE_CONTEXTS = ["system"] @@ -95,6 +98,8 @@ class SessionConfig: allowed_ports: str idle_seconds: int = 900 preshared_key: Optional[str] = None + client_private_key: Optional[str] = None + client_public_key: Optional[str] = None class WireGuardClient: @@ -122,17 +127,35 @@ class WireGuardClient: return candidate return "wireguard.exe" - def _validate_token(self, token: Dict[str, Any]) -> None: - required = ("agent_id", "tunnel_id", "expires_at") + def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None: + payload = dict(token or {}) + signature = payload.pop("signature", None) + signing_key = payload.pop("signing_key", None) + sig_alg = payload.pop("sig_alg", None) + + required = ("agent_id", "tunnel_id", "expires_at", "port") missing = [field for field in required if field not in token or token[field] in ("", None)] if missing: raise ValueError(f"Missing token fields: {', '.join(missing)}") try: - exp = float(token["expires_at"]) + exp = float(payload["expires_at"]) except Exception: raise ValueError("Invalid token expiry") if exp <= time.time(): raise ValueError("Token expired") + try: + port = int(payload["port"]) + except Exception: + raise ValueError("Invalid token port") + if port < 1 or port > 65535: + raise ValueError("Invalid token port") + + if signature: + if sig_alg and str(sig_alg).lower() not in ("ed25519", "eddsa"): + raise ValueError("Unsupported token signature algorithm") + payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + if not verify_and_store_script_signature(signing_client, payload_bytes, str(signature), signing_key): + raise ValueError("Token signature invalid") def _run(self, args: list[str]) -> tuple[int, str, str]: try: @@ -142,9 +165,10 @@ class WireGuardClient: return 1, "", str(exc) def _render_config(self, session: SessionConfig) -> str: + private_key = session.client_private_key or self._client_keys["private"] lines = [ "[Interface]", - f"PrivateKey = {self._client_keys['private']}", + f"PrivateKey = {private_key}", f"Address = {session.virtual_ip}", "", "[Peer]", @@ -172,13 +196,13 @@ class WireGuardClient: t.start() self._idle_thread = t - def start_session(self, session: SessionConfig) -> None: + def start_session(self, session: SessionConfig, *, signing_client: Optional[Any] = None) -> None: if self.session: _write_log("Rejecting start_session: existing session already active.") return try: - self._validate_token(session.token) + self._validate_token(session.token, signing_client=signing_client) except Exception as exc: _write_log(f"Refusing to start WireGuard session: {exc}") return @@ -243,6 +267,7 @@ class Role: self.client = client hooks = getattr(ctx, "hooks", {}) or {} self._log_hook = hooks.get("log_agent") + self._http_client_factory = hooks.get("http_client") def _log(self, message: str, *, error: bool = False) -> None: if callable(self._log_hook): @@ -254,6 +279,14 @@ class Role: pass _write_log(message) + def _http_client(self) -> Optional[Any]: + try: + if callable(self._http_client_factory): + return self._http_client_factory() + except Exception: + return None + return None + def _build_session(self, payload: Any) -> Optional[SessionConfig]: if not isinstance(payload, dict): self._log("WireGuard start payload missing/invalid.", error=True) @@ -299,6 +332,8 @@ class Role: allowed_ports=allowed_ports, idle_seconds=idle_seconds, preshared_key=payload.get("preshared_key"), + client_private_key=payload.get("client_private_key"), + client_public_key=payload.get("client_public_key"), ) def register_events(self) -> None: @@ -310,7 +345,7 @@ class Role: if not session: return self._log("WireGuard start request received.") - self.client.start_session(session) + self.client.start_session(session, signing_client=self._http_client()) @sio.on("vpn_tunnel_stop") async def _vpn_tunnel_stop(payload): diff --git a/Data/Engine/Unit_Tests/test_reverse_tunnel.py b/Data/Engine/Unit_Tests/test_reverse_tunnel.py deleted file mode 100644 index 9f4b14f1..00000000 --- a/Data/Engine/Unit_Tests/test_reverse_tunnel.py +++ /dev/null @@ -1,90 +0,0 @@ -# ====================================================== -# Data\Engine\Unit_Tests\test_reverse_tunnel.py -# Description: Validates reverse tunnel lease API basics (allocation, token contents, and domain limit). -# -# API Endpoints (if applicable): -# - POST /api/tunnel/request -# ====================================================== - -from __future__ import annotations - -import base64 -import json - -import pytest - -from .conftest import EngineTestHarness - - -def _client_with_admin_session(harness: EngineTestHarness): - client = harness.app.test_client() - with client.session_transaction() as sess: - sess["username"] = "admin" - sess["role"] = "Admin" - return client - - -def _decode_token_segment(token: str) -> dict: - """Decode the unsigned payload segment from the tunnel token.""" - if not token: - return {} - segment = token.split(".")[0] - padding = "=" * (-len(segment) % 4) - raw = base64.urlsafe_b64decode(segment + padding) - try: - return json.loads(raw.decode("utf-8")) - except Exception: - return {} - - -@pytest.mark.parametrize("agent_id", ["test-device-agent"]) -def test_tunnel_request_happy_path(engine_harness: EngineTestHarness, agent_id: str) -> None: - client = _client_with_admin_session(engine_harness) - resp = client.post( - "/api/tunnel/request", - json={"agent_id": agent_id, "protocol": "ps", "domain": "ps"}, - ) - assert resp.status_code == 200 - payload = resp.get_json() - assert payload["agent_id"] == agent_id - assert payload["protocol"] == "ps" - assert payload["domain"] == "ps" - assert isinstance(payload["port"], int) and payload["port"] >= 30000 - assert payload.get("token") - - claims = _decode_token_segment(payload["token"]) - assert claims.get("agent_id") == agent_id - assert claims.get("protocol") == "ps" - assert claims.get("domain") == "ps" - assert claims.get("tunnel_id") == payload["tunnel_id"] - assert claims.get("assigned_port") == payload["port"] - - -def test_tunnel_request_domain_limit(engine_harness: EngineTestHarness) -> None: - client = _client_with_admin_session(engine_harness) - first = client.post( - "/api/tunnel/request", - json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"}, - ) - assert first.status_code == 200 - - second = client.post( - "/api/tunnel/request", - json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"}, - ) - assert second.status_code == 409 - data = second.get_json() - assert data.get("error") == "domain_limit" - - -def test_tunnel_request_includes_timeouts(engine_harness: EngineTestHarness) -> None: - client = _client_with_admin_session(engine_harness) - resp = client.post( - "/api/tunnel/request", - json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"}, - ) - assert resp.status_code == 200 - payload = resp.get_json() - assert payload.get("idle_seconds") and payload["idle_seconds"] > 0 - assert payload.get("grace_seconds") and payload["grace_seconds"] > 0 - assert payload.get("expires_at") and int(payload["expires_at"]) > 0 diff --git a/Data/Engine/Unit_Tests/test_reverse_tunnel_integration.py b/Data/Engine/Unit_Tests/test_reverse_tunnel_integration.py deleted file mode 100644 index c1eab6f3..00000000 --- a/Data/Engine/Unit_Tests/test_reverse_tunnel_integration.py +++ /dev/null @@ -1,101 +0,0 @@ -# ====================================================== -# Data\Engine\Unit_Tests\test_reverse_tunnel_integration.py -# Description: Integration test that exercises a full reverse tunnel PowerShell round-trip -# against a running Engine + Agent (requires live services). -# -# Requirements: -# - Environment variables must be set to point at a live Engine + Agent: -# TUNNEL_TEST_HOST (e.g., https://localhost:5000) -# TUNNEL_TEST_AGENT_ID (agent_id/agent_guid for the target device) -# TUNNEL_TEST_BEARER (Authorization bearer token for an admin/operator) -# - A live Agent must be reachable and allowed to establish the reverse tunnel. -# - TLS verification can be controlled via TUNNEL_TEST_VERIFY ("false" to disable). -# -# API Endpoints (if applicable): -# - POST /api/tunnel/request -# - Socket.IO namespace /tunnel (join, ps_open, ps_send, ps_poll) -# ====================================================== - -from __future__ import annotations - -import os -import time - -import pytest -import requests -import socketio - - -HOST = os.environ.get("TUNNEL_TEST_HOST", "").strip() -AGENT_ID = os.environ.get("TUNNEL_TEST_AGENT_ID", "").strip() -BEARER = os.environ.get("TUNNEL_TEST_BEARER", "").strip() -VERIFY_ENV = os.environ.get("TUNNEL_TEST_VERIFY", "").strip().lower() -VERIFY = False if VERIFY_ENV in {"false", "0", "no"} else True - -SKIP_MSG = ( - "Live tunnel test skipped (set TUNNEL_TEST_HOST, TUNNEL_TEST_AGENT_ID, TUNNEL_TEST_BEARER to run)" -) - - -def _require_env() -> None: - if not HOST or not AGENT_ID or not BEARER: - pytest.skip(SKIP_MSG) - - -def _make_session() -> requests.Session: - sess = requests.Session() - sess.verify = VERIFY - sess.headers.update({"Authorization": f"Bearer {BEARER}"}) - return sess - - -def test_reverse_tunnel_powershell_roundtrip() -> None: - _require_env() - sess = _make_session() - - # 1) Request a tunnel lease - resp = sess.post( - f"{HOST}/api/tunnel/request", - json={"agent_id": AGENT_ID, "protocol": "ps", "domain": "remote-interactive-shell"}, - ) - assert resp.status_code == 200, f"lease request failed: {resp.status_code} {resp.text}" - lease = resp.json() - tunnel_id = lease["tunnel_id"] - - # 2) Connect to Socket.IO /tunnel namespace - sio = socketio.Client() - sio.connect( - HOST, - namespaces=["/tunnel"], - headers={"Authorization": f"Bearer {BEARER}"}, - transports=["websocket"], - wait_timeout=10, - ) - - # 3) Join tunnel and open PS channel - join_resp = sio.call("join", {"tunnel_id": tunnel_id}, namespace="/tunnel", timeout=10) - assert join_resp.get("status") == "ok", f"join failed: {join_resp}" - - open_resp = sio.call("ps_open", {"cols": 120, "rows": 32}, namespace="/tunnel", timeout=10) - assert not open_resp.get("error"), f"ps_open failed: {open_resp}" - - # 4) Send a command - send_resp = sio.call("ps_send", {"data": 'Write-Host "Hello World"\r\n'}, namespace="/tunnel", timeout=10) - assert not send_resp.get("error"), f"ps_send failed: {send_resp}" - - # 5) Poll for output - output_text = "" - deadline = time.time() + 15 - while time.time() < deadline: - poll_resp = sio.call("ps_poll", {}, namespace="/tunnel", timeout=10) - if poll_resp.get("error"): - pytest.fail(f"ps_poll failed: {poll_resp}") - lines = poll_resp.get("output") or [] - output_text += "".join(lines) - if "Hello World" in output_text: - break - time.sleep(0.5) - - sio.disconnect() - - assert "Hello World" in output_text, f"expected command output not found; got: {output_text!r}" diff --git a/Data/Engine/config.py b/Data/Engine/config.py index 06259463..9a6caa05 100644 --- a/Data/Engine/config.py +++ b/Data/Engine/config.py @@ -77,17 +77,12 @@ LOG_ROOT = PROJECT_ROOT / "Engine" / "Logs" LOG_FILE_PATH = LOG_ROOT / "engine.log" ERROR_LOG_FILE_PATH = LOG_ROOT / "error.log" API_LOG_FILE_PATH = LOG_ROOT / "api.log" -REVERSE_TUNNEL_LOG_FILE_PATH = LOG_ROOT / "reverse_tunnel.log" - -DEFAULT_TUNNEL_FIXED_PORT = 8443 -DEFAULT_TUNNEL_PORT_RANGE = (30000, 40000) -DEFAULT_TUNNEL_IDLE_TIMEOUT_SECONDS = 3600 -DEFAULT_TUNNEL_GRACE_TIMEOUT_SECONDS = 3600 -DEFAULT_TUNNEL_HEARTBEAT_INTERVAL_SECONDS = 20 +VPN_TUNNEL_LOG_FILE_PATH = LOG_ROOT / "reverse_tunnel.log" DEFAULT_WIREGUARD_PORT = 30000 DEFAULT_WIREGUARD_ENGINE_VIRTUAL_IP = "10.255.0.1/32" DEFAULT_WIREGUARD_PEER_NETWORK = "10.255.0.0/24" -DEFAULT_WIREGUARD_ACL_WINDOWS = (3389, 5985, 5986, 5900, 3478) +DEFAULT_WIREGUARD_SHELL_PORT = 47001 +DEFAULT_WIREGUARD_ACL_WINDOWS = (3389, 5985, 5986, 5900, 3478, DEFAULT_WIREGUARD_SHELL_PORT) VPN_SERVER_CERT_ROOT = PROJECT_ROOT / "Engine" / "Certificates" / "VPN_Server" @@ -282,18 +277,14 @@ class EngineSettings: error_log_file: str api_log_file: str api_groups: Tuple[str, ...] - reverse_tunnel_fixed_port: int - reverse_tunnel_port_range: Tuple[int, int] - reverse_tunnel_idle_timeout_seconds: int - reverse_tunnel_grace_timeout_seconds: int - reverse_tunnel_heartbeat_seconds: int - reverse_tunnel_log_file: str + vpn_tunnel_log_file: str wireguard_port: int wireguard_engine_virtual_ip: str wireguard_peer_network: str wireguard_server_private_key_path: str wireguard_server_public_key_path: str wireguard_acl_allowlist_windows: Tuple[int, ...] + wireguard_shell_port: int raw: MutableMapping[str, Any] = field(default_factory=dict) def to_flask_config(self) -> MutableMapping[str, Any]: @@ -390,10 +381,14 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine api_log_file = str(runtime_config.get("API_LOG_FILE") or API_LOG_FILE_PATH) _ensure_parent(Path(api_log_file)) - reverse_tunnel_log_file = str( - runtime_config.get("REVERSE_TUNNEL_LOG_FILE") or REVERSE_TUNNEL_LOG_FILE_PATH + vpn_tunnel_log_file = str( + runtime_config.get("VPN_TUNNEL_LOG_FILE") + or runtime_config.get("WIREGUARD_LOG_FILE") + or os.environ.get("BOREALIS_VPN_TUNNEL_LOG_FILE") + or os.environ.get("BOREALIS_WIREGUARD_LOG_FILE") + or VPN_TUNNEL_LOG_FILE_PATH ) - _ensure_parent(Path(reverse_tunnel_log_file)) + _ensure_parent(Path(vpn_tunnel_log_file)) wireguard_port = _parse_int( runtime_config.get("WIREGUARD_PORT") or os.environ.get("BOREALIS_WIREGUARD_PORT"), @@ -416,6 +411,13 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine or os.environ.get("BOREALIS_WIREGUARD_WINDOWS_ALLOWLIST"), default=DEFAULT_WIREGUARD_ACL_WINDOWS, ) + wireguard_shell_port = _parse_int( + runtime_config.get("WIREGUARD_SHELL_PORT") + or os.environ.get("BOREALIS_WIREGUARD_SHELL_PORT"), + default=DEFAULT_WIREGUARD_SHELL_PORT, + minimum=1, + maximum=65535, + ) wireguard_key_root = Path( runtime_config.get("WIREGUARD_KEY_ROOT") or os.environ.get("BOREALIS_WIREGUARD_KEY_ROOT") @@ -440,35 +442,6 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine "scheduled_jobs", ) - tunnel_fixed_port = _parse_int( - runtime_config.get("TUNNEL_FIXED_PORT") or os.environ.get("BOREALIS_TUNNEL_FIXED_PORT"), - default=DEFAULT_TUNNEL_FIXED_PORT, - minimum=1, - maximum=65535, - ) - tunnel_port_range = _parse_port_range( - runtime_config.get("TUNNEL_PORT_RANGE") or os.environ.get("BOREALIS_TUNNEL_PORT_RANGE"), - default=DEFAULT_TUNNEL_PORT_RANGE, - ) - tunnel_idle_timeout_seconds = _parse_int( - runtime_config.get("TUNNEL_IDLE_TIMEOUT_SECONDS") - or os.environ.get("BOREALIS_TUNNEL_IDLE_TIMEOUT_SECONDS"), - default=DEFAULT_TUNNEL_IDLE_TIMEOUT_SECONDS, - minimum=60, - ) - tunnel_grace_timeout_seconds = _parse_int( - runtime_config.get("TUNNEL_GRACE_TIMEOUT_SECONDS") - or os.environ.get("BOREALIS_TUNNEL_GRACE_TIMEOUT_SECONDS"), - default=DEFAULT_TUNNEL_GRACE_TIMEOUT_SECONDS, - minimum=60, - ) - tunnel_heartbeat_seconds = _parse_int( - runtime_config.get("TUNNEL_HEARTBEAT_SECONDS") - or os.environ.get("BOREALIS_TUNNEL_HEARTBEAT_SECONDS"), - default=DEFAULT_TUNNEL_HEARTBEAT_INTERVAL_SECONDS, - minimum=5, - ) - settings = EngineSettings( database_path=database_path, static_folder=static_folder, @@ -484,18 +457,14 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine error_log_file=str(error_log_file), api_log_file=str(api_log_file), api_groups=api_groups, - reverse_tunnel_fixed_port=tunnel_fixed_port, - reverse_tunnel_port_range=tunnel_port_range, - reverse_tunnel_idle_timeout_seconds=tunnel_idle_timeout_seconds, - reverse_tunnel_grace_timeout_seconds=tunnel_grace_timeout_seconds, - reverse_tunnel_heartbeat_seconds=tunnel_heartbeat_seconds, - reverse_tunnel_log_file=reverse_tunnel_log_file, + vpn_tunnel_log_file=vpn_tunnel_log_file, wireguard_port=wireguard_port, wireguard_engine_virtual_ip=wireguard_engine_virtual_ip, wireguard_peer_network=wireguard_peer_network, wireguard_server_private_key_path=wireguard_server_private_key_path, wireguard_server_public_key_path=wireguard_server_public_key_path, wireguard_acl_allowlist_windows=wireguard_acl_allowlist_windows, + wireguard_shell_port=wireguard_shell_port, raw=runtime_config, ) return settings diff --git a/Data/Engine/database_migrations.py b/Data/Engine/database_migrations.py index 924764ec..8e628e52 100644 --- a/Data/Engine/database_migrations.py +++ b/Data/Engine/database_migrations.py @@ -2,6 +2,8 @@ # Data\Engine\database_migrations.py # Description: Provides schema evolution helpers for the Engine sqlite # database without importing the legacy ``Modules`` package. +# +# API Endpoints (if applicable): None # ====================================================== """Engine database schema migration helpers.""" @@ -24,6 +26,7 @@ def apply_all(conn: sqlite3.Connection) -> None: _ensure_devices_table(conn) _ensure_device_aux_tables(conn) + _ensure_device_vpn_config_table(conn) _ensure_refresh_token_table(conn) _ensure_install_code_table(conn) _ensure_install_code_persistence_table(conn) @@ -112,6 +115,20 @@ def _ensure_device_aux_tables(conn: sqlite3.Connection) -> None: ) +def _ensure_device_vpn_config_table(conn: sqlite3.Connection) -> None: + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS device_vpn_config ( + agent_id TEXT PRIMARY KEY, + allowed_ports TEXT, + updated_at TEXT, + updated_by TEXT + ) + """ + ) + + def _ensure_refresh_token_table(conn: sqlite3.Connection) -> None: cur = conn.cursor() cur.execute( diff --git a/Data/Engine/server.py b/Data/Engine/server.py index 34b18503..b6d45de3 100644 --- a/Data/Engine/server.py +++ b/Data/Engine/server.py @@ -120,18 +120,14 @@ class EngineContext: config: Mapping[str, Any] api_groups: Sequence[str] api_log_path: str - reverse_tunnel_fixed_port: int - reverse_tunnel_port_range: Tuple[int, int] - reverse_tunnel_idle_timeout_seconds: int - reverse_tunnel_grace_timeout_seconds: int - reverse_tunnel_heartbeat_seconds: int - reverse_tunnel_log_path: str + vpn_tunnel_log_path: str wireguard_port: int wireguard_engine_virtual_ip: str wireguard_peer_network: str wireguard_server_private_key_path: str wireguard_server_public_key_path: str wireguard_acl_allowlist_windows: Tuple[int, ...] + wireguard_shell_port: int wireguard_server_manager: Optional[Any] = None assembly_cache: Optional[Any] = None @@ -151,18 +147,14 @@ def _build_engine_context(settings: EngineSettings, logger: logging.Logger) -> E config=settings.as_dict(), api_groups=settings.api_groups, api_log_path=settings.api_log_file, - reverse_tunnel_fixed_port=settings.reverse_tunnel_fixed_port, - reverse_tunnel_port_range=settings.reverse_tunnel_port_range, - reverse_tunnel_idle_timeout_seconds=settings.reverse_tunnel_idle_timeout_seconds, - reverse_tunnel_grace_timeout_seconds=settings.reverse_tunnel_grace_timeout_seconds, - reverse_tunnel_heartbeat_seconds=settings.reverse_tunnel_heartbeat_seconds, - reverse_tunnel_log_path=settings.reverse_tunnel_log_file, + vpn_tunnel_log_path=settings.vpn_tunnel_log_file, wireguard_port=settings.wireguard_port, wireguard_engine_virtual_ip=settings.wireguard_engine_virtual_ip, wireguard_peer_network=settings.wireguard_peer_network, wireguard_server_private_key_path=settings.wireguard_server_private_key_path, wireguard_server_public_key_path=settings.wireguard_server_public_key_path, wireguard_acl_allowlist_windows=settings.wireguard_acl_allowlist_windows, + wireguard_shell_port=settings.wireguard_shell_port, assembly_cache=None, ) @@ -249,7 +241,7 @@ def create_app(config: Optional[Mapping[str, Any]] = None) -> Tuple[Flask, Socke private_key_path=Path(context.wireguard_server_private_key_path), public_key_path=Path(context.wireguard_server_public_key_path), acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows), - log_path=Path(context.reverse_tunnel_log_path), + log_path=Path(context.vpn_tunnel_log_path), ) context.wireguard_server_manager = WireGuardServerManager(wg_config) except Exception: @@ -325,7 +317,7 @@ def register_engine_api(app: Flask, *, config: Optional[Mapping[str, Any]] = Non private_key_path=Path(context.wireguard_server_private_key_path), public_key_path=Path(context.wireguard_server_public_key_path), acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows), - log_path=Path(context.reverse_tunnel_log_path), + log_path=Path(context.vpn_tunnel_log_path), ) context.wireguard_server_manager = WireGuardServerManager(wg_config) except Exception: diff --git a/Data/Engine/services/API/devices/management.py b/Data/Engine/services/API/devices/management.py index 817ac33b..433c9720 100644 --- a/Data/Engine/services/API/devices/management.py +++ b/Data/Engine/services/API/devices/management.py @@ -9,6 +9,8 @@ # - GET /api/devices/ (Token Authenticated) - Retrieves a single device record by GUID, including summary fields. # - GET /api/device/details/ (Token Authenticated) - Returns full device details keyed by hostname. # - POST /api/device/description/ (Token Authenticated) - Updates the human-readable description for a device. +# - GET /api/device/vpn_config/ (Token Authenticated) - Returns per-device VPN allowed port settings. +# - PUT /api/device/vpn_config/ (Token Authenticated) - Updates per-device VPN allowed port settings. # - GET /api/device_list_views (Token Authenticated) - Lists saved device table view definitions. # - GET /api/device_list_views/ (Token Authenticated) - Retrieves a specific saved device table view definition. # - POST /api/device_list_views (Token Authenticated) - Creates a custom device list view for the signed-in operator. @@ -426,6 +428,131 @@ class DeviceManagementService: return None return None + def _parse_ports(self, raw: Any) -> List[int]: + ports: List[int] = [] + if isinstance(raw, str): + parts = [part.strip() for part in raw.split(",") if part.strip()] + elif isinstance(raw, list): + parts = raw + else: + parts = [] + for part in parts: + try: + value = int(part) + except Exception: + continue + if 1 <= value <= 65535: + ports.append(value) + return list(dict.fromkeys(ports)) + + def _default_vpn_ports(self, os_name: Optional[str]) -> List[int]: + ports = list(self.adapters.context.wireguard_acl_allowlist_windows or []) + os_text = (os_name or "").strip().lower() + if os_text and "windows" not in os_text: + baseline = {5900, 3478} + filtered = [p for p in ports if p in baseline] + return filtered or ports + return ports + + def get_vpn_config(self, agent_id: str) -> Tuple[Dict[str, Any], int]: + agent_id = (agent_id or "").strip() + if not agent_id: + return {"error": "agent_id_required"}, 400 + default_ports: List[int] = [] + shell_port = int(self.adapters.context.wireguard_shell_port) + try: + conn = self._db_conn() + cur = conn.cursor() + os_name = "" + try: + cur.execute( + "SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1", + (agent_id,), + ) + row = cur.fetchone() + if row and row[0]: + os_name = str(row[0]) + except Exception: + os_name = "" + default_ports = self._default_vpn_ports(os_name) + cur.execute( + "SELECT allowed_ports, updated_at, updated_by FROM device_vpn_config WHERE agent_id=?", + (agent_id,), + ) + row = cur.fetchone() + if not row: + return { + "agent_id": agent_id, + "allowed_ports": default_ports, + "default_ports": default_ports, + "shell_port": shell_port, + "source": "default", + }, 200 + raw_ports = row[0] or "" + ports = [] + try: + ports = json.loads(raw_ports) if raw_ports else [] + except Exception: + ports = [] + return { + "agent_id": agent_id, + "allowed_ports": ports or default_ports, + "default_ports": default_ports, + "shell_port": shell_port, + "updated_at": row[1], + "updated_by": row[2], + "source": "custom" if ports else "default", + }, 200 + except Exception as exc: + self.logger.debug("Failed to load vpn config", exc_info=True) + return {"error": "internal_error"}, 500 + finally: + try: + conn.close() + except Exception: + pass + + def set_vpn_config(self, agent_id: str, payload: Dict[str, Any]) -> Tuple[Dict[str, Any], int]: + agent_id = (agent_id or "").strip() + if not agent_id: + return {"error": "agent_id_required"}, 400 + ports = self._parse_ports(payload.get("allowed_ports")) + if not ports: + return {"error": "allowed_ports_required"}, 400 + user = self._current_user() or {} + updated_by = user.get("username") or "" + updated_at = datetime.now(timezone.utc).isoformat() + try: + conn = self._db_conn() + cur = conn.cursor() + cur.execute( + """ + INSERT INTO device_vpn_config(agent_id, allowed_ports, updated_at, updated_by) + VALUES (?, ?, ?, ?) + ON CONFLICT(agent_id) DO UPDATE SET + allowed_ports=excluded.allowed_ports, + updated_at=excluded.updated_at, + updated_by=excluded.updated_by + """, + (agent_id, json.dumps(ports), updated_at, updated_by), + ) + conn.commit() + return { + "agent_id": agent_id, + "allowed_ports": ports, + "updated_at": updated_at, + "updated_by": updated_by, + "source": "custom", + }, 200 + except Exception: + self.logger.debug("Failed to save vpn config", exc_info=True) + return {"error": "internal_error"}, 500 + finally: + try: + conn.close() + except Exception: + pass + def _require_login(self) -> Optional[Tuple[Dict[str, Any], int]]: if not self._current_user(): return {"error": "unauthorized"}, 401 @@ -1793,6 +1920,19 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None: payload, status = service.set_device_description(hostname, description) return jsonify(payload), status + @blueprint.route("/api/device/vpn_config/", methods=["GET", "PUT"]) + def _vpn_config(agent_id: str): + requirement = service._require_login() + if requirement: + payload, status = requirement + return jsonify(payload), status + if request.method == "GET": + payload, status = service.get_vpn_config(agent_id) + else: + body = request.get_json(silent=True) or {} + payload, status = service.set_vpn_config(agent_id, body) + return jsonify(payload), status + @blueprint.route("/api/device_list_views", methods=["GET"]) def _list_views(): requirement = service._require_login() diff --git a/Data/Engine/services/API/devices/tunnel.py b/Data/Engine/services/API/devices/tunnel.py index c8127ec8..c9f51861 100644 --- a/Data/Engine/services/API/devices/tunnel.py +++ b/Data/Engine/services/API/devices/tunnel.py @@ -1,12 +1,14 @@ # ====================================================== # Data\Engine\services\API\devices\tunnel.py -# Description: Negotiation endpoint for reverse tunnel leases (operator-initiated; dormant until tunnel listener is wired). +# Description: WireGuard VPN tunnel API (connect/status/disconnect). # # API Endpoints (if applicable): -# - POST /api/tunnel/request (Token Authenticated) - Allocates a reverse tunnel lease for the requested agent/protocol. +# - POST /api/tunnel/connect (Token Authenticated) - Issues VPN session material for an agent. +# - GET /api/tunnel/status (Token Authenticated) - Returns VPN status for an agent. +# - DELETE /api/tunnel/disconnect (Token Authenticated) - Tears down VPN session for an agent. # ====================================================== -"""Reverse tunnel negotiation API (Engine side).""" +"""WireGuard VPN tunnel API (Engine side).""" from __future__ import annotations import os @@ -15,15 +17,13 @@ from typing import Any, Dict, Optional, Tuple from flask import Blueprint, jsonify, request, session from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer -from ...WebSocket.Agent.reverse_tunnel_orchestrator import ReverseTunnelService +from ...VPN import VpnTunnelService if False: # pragma: no cover - import cycle hint for type checkers from .. import EngineServiceAdapters def _current_user(app) -> Optional[Dict[str, str]]: - """Resolve operator identity from session or signed token.""" - username = session.get("username") role = session.get("role") or "User" if username: @@ -58,18 +58,22 @@ def _require_login(app) -> Optional[Tuple[Dict[str, Any], int]]: return None -def _get_tunnel_service(adapters: "EngineServiceAdapters") -> ReverseTunnelService: - service = getattr(adapters.context, "reverse_tunnel_service", None) or getattr(adapters, "_reverse_tunnel_service", None) +def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService: + service = getattr(adapters.context, "vpn_tunnel_service", None) or getattr(adapters, "_vpn_tunnel_service", None) if service is None: - service = ReverseTunnelService( - adapters.context, - signer=getattr(adapters, "script_signer", None), + manager = getattr(adapters.context, "wireguard_server_manager", None) + if manager is None: + raise RuntimeError("wireguard_manager_unavailable") + service = VpnTunnelService( + context=adapters.context, + wireguard_manager=manager, db_conn_factory=adapters.db_conn_factory, socketio=getattr(adapters.context, "socketio", None), + service_log=adapters.service_log, + signer=getattr(adapters, "script_signer", None), ) - service.start() - setattr(adapters, "_reverse_tunnel_service", service) - setattr(adapters.context, "reverse_tunnel_service", service) + setattr(adapters, "_vpn_tunnel_service", service) + setattr(adapters.context, "vpn_tunnel_service", service) return service @@ -83,14 +87,11 @@ def _normalize_text(value: Any) -> str: def register_tunnel(app, adapters: "EngineServiceAdapters") -> None: - """Register reverse tunnel negotiation endpoints.""" + blueprint = Blueprint("vpn_tunnel", __name__) + logger = adapters.context.logger.getChild("vpn_tunnel.api") - blueprint = Blueprint("reverse_tunnel", __name__) - service_log = adapters.service_log - logger = adapters.context.logger.getChild("tunnel.api") - - @blueprint.route("/api/tunnel/request", methods=["POST"]) - def request_tunnel(): + @blueprint.route("/api/tunnel/connect", methods=["POST"]) + def connect_tunnel(): requirement = _require_login(app) if requirement: payload, status = requirement @@ -101,69 +102,67 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None: body = request.get_json(silent=True) or {} agent_id = _normalize_text(body.get("agent_id")) - protocol = _normalize_text(body.get("protocol") or "ps").lower() or "ps" - domain = _normalize_text(body.get("domain") or protocol).lower() or protocol - if protocol == "ps" and domain == "ps": - domain = "remote-interactive-shell" - if not agent_id: return jsonify({"error": "agent_id_required"}), 400 - tunnel_service = _get_tunnel_service(adapters) try: - lease = tunnel_service.request_lease( - agent_id=agent_id, - protocol=protocol, - domain=domain, - operator_id=operator_id, - ) + tunnel_service = _get_tunnel_service(adapters) + payload = tunnel_service.connect(agent_id=agent_id, operator_id=operator_id) except RuntimeError as exc: - message = str(exc) - if message.startswith("domain_limit:"): - domain_name = message.split(":", 1)[-1] if ":" in message else domain - return jsonify({"error": "domain_limit", "domain": domain_name}), 409 - if message == "port_pool_exhausted": - return jsonify({"error": "port_pool_exhausted"}), 503 - logger.warning("tunnel lease request failed for agent_id=%s: %s", agent_id, message) - return jsonify({"error": "lease_allocation_failed"}), 500 + logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc) + return jsonify({"error": "connect_failed"}), 500 - summary = tunnel_service.lease_summary(lease) - summary["fixed_port"] = tunnel_service.fixed_port - summary["heartbeat_seconds"] = tunnel_service.heartbeat_seconds + return jsonify(payload), 200 - service_log( - "reverse_tunnel", - f"lease created tunnel_id={lease.tunnel_id} agent_id={lease.agent_id} domain={lease.domain} protocol={lease.protocol}", - ) - return jsonify(summary), 200 - - @blueprint.route("/api/tunnel/", methods=["DELETE"]) - def stop_tunnel(tunnel_id: str): + @blueprint.route("/api/tunnel/status", methods=["GET"]) + def tunnel_status(): requirement = _require_login(app) if requirement: payload, status = requirement return jsonify(payload), status - tunnel_id_norm = _normalize_text(tunnel_id) - if not tunnel_id_norm: - return jsonify({"error": "tunnel_id_required"}), 400 + agent_id = _normalize_text(request.args.get("agent_id") or "") + if not agent_id: + return jsonify({"error": "agent_id_required"}), 400 + + tunnel_service = _get_tunnel_service(adapters) + payload = tunnel_service.status(agent_id) + if not payload: + return jsonify({"status": "down", "agent_id": agent_id}), 200 + payload["status"] = "up" + bump = _normalize_text(request.args.get("bump") or "") + if bump: + tunnel_service.bump_activity(agent_id) + return jsonify(payload), 200 + + @blueprint.route("/api/tunnel/connect/status", methods=["GET"]) + def tunnel_connect_status(): + return tunnel_status() + + @blueprint.route("/api/tunnel/disconnect", methods=["DELETE"]) + def disconnect_tunnel(): + requirement = _require_login(app) + if requirement: + payload, status = requirement + return jsonify(payload), status body = request.get_json(silent=True) or {} + agent_id = _normalize_text(body.get("agent_id")) + tunnel_id = _normalize_text(body.get("tunnel_id")) reason = _normalize_text(body.get("reason") or "operator_stop") tunnel_service = _get_tunnel_service(adapters) stopped = False - try: - stopped = tunnel_service.stop_tunnel(tunnel_id_norm, reason=reason) - except Exception as exc: # pragma: no cover - defensive guard - logger.debug("stop_tunnel failed tunnel_id=%s: %s", tunnel_id_norm, exc, exc_info=True) + if tunnel_id: + stopped = tunnel_service.disconnect_by_tunnel(tunnel_id, reason=reason) + elif agent_id: + stopped = tunnel_service.disconnect(agent_id, reason=reason) + else: + return jsonify({"error": "agent_id_required"}), 400 + if not stopped: return jsonify({"error": "not_found"}), 404 - service_log( - "reverse_tunnel", - f"lease stopped tunnel_id={tunnel_id_norm} reason={reason or '-'}", - ) - return jsonify({"status": "stopped", "tunnel_id": tunnel_id_norm}), 200 + return jsonify({"status": "stopped", "reason": reason}), 200 app.register_blueprint(blueprint) diff --git a/Data/Engine/services/VPN/__init__.py b/Data/Engine/services/VPN/__init__.py index 51347aaa..a509aa87 100644 --- a/Data/Engine/services/VPN/__init__.py +++ b/Data/Engine/services/VPN/__init__.py @@ -8,4 +8,4 @@ """VPN service helpers for the Engine runtime.""" from .wireguard_server import WireGuardServerConfig, WireGuardServerManager # noqa: F401 - +from .vpn_tunnel_service import VpnTunnelService # noqa: F401 diff --git a/Data/Engine/services/VPN/vpn_tunnel_service.py b/Data/Engine/services/VPN/vpn_tunnel_service.py new file mode 100644 index 00000000..11d03d08 --- /dev/null +++ b/Data/Engine/services/VPN/vpn_tunnel_service.py @@ -0,0 +1,473 @@ +# ====================================================== +# Data\Engine\services\VPN\vpn_tunnel_service.py +# Description: WireGuard tunnel orchestration (single tunnel per agent, token issuance, idle handling). +# +# API Endpoints (if applicable): None +# ====================================================== + +"""WireGuard tunnel orchestration helpers for the Engine runtime.""" + +from __future__ import annotations + +import base64 +import ipaddress +import json +import threading +import time +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple + +from .wireguard_server import WireGuardServerManager + + +@dataclass +class VpnSession: + tunnel_id: str + agent_id: str + virtual_ip: str + token: Dict[str, Any] + client_public_key: str + client_private_key: str + allowed_ports: Tuple[int, ...] + created_at: float + expires_at: float + last_activity: float + operator_ids: set[str] = field(default_factory=set) + firewall_rules: List[str] = field(default_factory=list) + activity_id: Optional[int] = None + hostname: Optional[str] = None + + +class VpnTunnelService: + def __init__( + self, + *, + context: Any, + wireguard_manager: WireGuardServerManager, + db_conn_factory, + socketio, + service_log, + signer: Optional[Any] = None, + idle_seconds: int = 900, + ) -> None: + self.context = context + self.wg = wireguard_manager + self.db_conn_factory = db_conn_factory + self.socketio = socketio + self.service_log = service_log + self.signer = signer + self.logger = context.logger.getChild("vpn_tunnel") + self.activity_logger = self.wg.logger.getChild("device_activity") + self.idle_seconds = max(60, int(idle_seconds)) + self._lock = threading.Lock() + self._sessions_by_agent: Dict[str, VpnSession] = {} + self._sessions_by_tunnel: Dict[str, VpnSession] = {} + self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip) + self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False) + self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True) + self._idle_thread.start() + + def _idle_loop(self) -> None: + while True: + time.sleep(10) + now = time.time() + expired: List[VpnSession] = [] + with self._lock: + for session in list(self._sessions_by_agent.values()): + if session.last_activity + self.idle_seconds <= now: + expired.append(session) + for session in expired: + self.disconnect(session.agent_id, reason="idle_timeout") + + def _allocate_virtual_ip(self, agent_id: str) -> str: + existing = self._sessions_by_agent.get(agent_id) + if existing: + return existing.virtual_ip + + used = {s.virtual_ip for s in self._sessions_by_agent.values()} + for host in self._peer_network.hosts(): + if host == self._engine_ip.ip: + continue + candidate = f"{host}/32" + if candidate not in used: + return candidate + raise RuntimeError("vpn_ip_pool_exhausted") + + def _load_allowed_ports(self, agent_id: str) -> Tuple[int, ...]: + default = tuple(self.context.wireguard_acl_allowlist_windows or ()) + try: + conn = self.db_conn_factory() + cur = conn.cursor() + try: + cur.execute( + "SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1", + (agent_id,), + ) + row = cur.fetchone() + os_name = str(row[0]).lower() if row and row[0] else "" + except Exception: + os_name = "" + if os_name and "windows" not in os_name: + baseline = {5900, 3478} + filtered = [p for p in default if p in baseline] + if filtered: + default = tuple(filtered) + cur.execute( + "SELECT allowed_ports FROM device_vpn_config WHERE agent_id=?", + (agent_id,), + ) + row = cur.fetchone() + if not row: + return default + raw = row[0] or "" + ports = json.loads(raw) if raw else [] + ports = [int(p) for p in ports if isinstance(p, (int, float, str))] + ports = [p for p in ports if 1 <= p <= 65535] + return tuple(dict.fromkeys(ports)) or default + except Exception: + return default + finally: + try: + conn.close() + except Exception: + pass + + def _generate_client_keys(self) -> Tuple[str, str]: + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import x25519 + + key = x25519.X25519PrivateKey.generate() + priv = base64.b64encode( + key.private_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PrivateFormat.Raw, + encryption_algorithm=serialization.NoEncryption(), + ) + ).decode("ascii").strip() + pub = base64.b64encode( + key.public_key().public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ) + ).decode("ascii").strip() + return priv, pub + + def _issue_token(self, agent_id: str, tunnel_id: str, expires_at: float) -> Dict[str, Any]: + payload = { + "agent_id": agent_id, + "tunnel_id": tunnel_id, + "port": self.context.wireguard_port, + "expires_at": expires_at, + "issued_at": time.time(), + } + if not self.signer: + return dict(payload) + + token = dict(payload) + try: + payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + signature = self.signer.sign(payload_bytes) + token["signature"] = base64.b64encode(signature).decode("ascii") + if hasattr(self.signer, "public_base64_spki"): + token["signing_key"] = self.signer.public_base64_spki() + token["sig_alg"] = "ed25519" + except Exception: + self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True) + return token + + def _refresh_listener(self) -> None: + peers: List[Mapping[str, object]] = [] + for session in self._sessions_by_agent.values(): + peer = self.wg.build_peer_profile( + session.agent_id, + session.virtual_ip, + allowed_ports=session.allowed_ports, + ) + peer = dict(peer) + peer["public_key"] = session.client_public_key + peers.append(peer) + if not peers: + self.wg.stop_listener() + return + self.wg.start_listener(peers) + + def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]: + now = time.time() + with self._lock: + existing = self._sessions_by_agent.get(agent_id) + if existing: + if operator_id: + existing.operator_ids.add(operator_id) + existing.last_activity = now + return self._session_payload(existing) + + tunnel_id = uuid.uuid4().hex + virtual_ip = self._allocate_virtual_ip(agent_id) + allowed_ports = self._load_allowed_ports(agent_id) + client_private, client_public = self._generate_client_keys() + token = self._issue_token(agent_id, tunnel_id, now + 300) + self.wg.require_orchestration_token(token) + + session = VpnSession( + tunnel_id=tunnel_id, + agent_id=agent_id, + virtual_ip=virtual_ip, + token=token, + client_public_key=client_public, + client_private_key=client_private, + allowed_ports=allowed_ports, + created_at=now, + expires_at=now + 300, + last_activity=now, + ) + if operator_id: + session.operator_ids.add(operator_id) + self._sessions_by_agent[agent_id] = session + self._sessions_by_tunnel[tunnel_id] = session + + try: + self._refresh_listener() + + peer = self.wg.build_peer_profile( + agent_id, + virtual_ip, + allowed_ports=allowed_ports, + ) + rule_names = self.wg.apply_firewall_rules(peer) + session.firewall_rules = rule_names + except Exception: + with self._lock: + self._sessions_by_agent.pop(agent_id, None) + self._sessions_by_tunnel.pop(tunnel_id, None) + try: + self._refresh_listener() + except Exception: + self.logger.debug("Failed to refresh WireGuard listener after connect rollback.", exc_info=True) + raise + + payload = self._session_payload(session) + self._emit_start(payload) + self._log_device_activity(session, event="start") + return payload + + def status(self, agent_id: str) -> Optional[Mapping[str, Any]]: + with self._lock: + session = self._sessions_by_agent.get(agent_id) + if not session: + return None + return self._session_payload(session, include_token=False) + + def bump_activity(self, agent_id: str) -> None: + with self._lock: + session = self._sessions_by_agent.get(agent_id) + if not session: + return + session.last_activity = time.time() + try: + if self.socketio: + self.socketio.emit("vpn_tunnel_activity", {"agent_id": agent_id}, namespace="/") + except Exception: + self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True) + + def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool: + with self._lock: + session = self._sessions_by_agent.pop(agent_id, None) + if not session: + return False + self._sessions_by_tunnel.pop(session.tunnel_id, None) + + try: + self.wg.remove_firewall_rules(session.firewall_rules) + except Exception: + self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True) + + self._refresh_listener() + self._emit_stop(session, reason) + self._log_device_activity(session, event="stop", reason=reason) + return True + + def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool: + with self._lock: + session = self._sessions_by_tunnel.get(tunnel_id) + if not session: + return False + return self.disconnect(session.agent_id, reason=reason) + + def _emit_start(self, payload: Mapping[str, Any]) -> None: + if not self.socketio: + return + try: + self.socketio.emit("vpn_tunnel_start", payload, namespace="/") + except Exception: + self.logger.debug("vpn_tunnel_start emit failed", exc_info=True) + + def _emit_stop(self, session: VpnSession, reason: str) -> None: + if not self.socketio: + return + try: + self.socketio.emit( + "vpn_tunnel_stop", + {"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason}, + namespace="/", + ) + except Exception: + self.logger.debug("vpn_tunnel_stop emit failed", exc_info=True) + + def _log_device_activity(self, session: VpnSession, *, event: str, reason: Optional[str] = None) -> None: + if self.db_conn_factory is None: + self.activity_logger.info( + "device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s", + event, + session.agent_id, + session.tunnel_id, + ",".join(sorted(filter(None, session.operator_ids))) or "-", + reason or "-", + ) + return + + conn = None + try: + conn = self.db_conn_factory() + cur = conn.cursor() + hostname = session.hostname + if not hostname: + try: + cur.execute( + "SELECT hostname FROM devices WHERE agent_id = ? ORDER BY last_seen DESC LIMIT 1", + (session.agent_id,), + ) + row = cur.fetchone() + if row and row[0]: + hostname = str(row[0]).strip() + session.hostname = hostname + except Exception: + hostname = None + + if not hostname: + self.activity_logger.info( + "device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s hostname=unknown", + event, + session.agent_id, + session.tunnel_id, + ",".join(sorted(filter(None, session.operator_ids))) or "-", + reason or "-", + ) + return + + now_ts = int(time.time()) + script_name = "Reverse VPN Tunnel (WireGuard)" + + if event == "start": + cur.execute( + """ + INSERT INTO activity_history(hostname, script_path, script_name, script_type, ran_at, status, stdout, stderr) + VALUES(?,?,?,?,?,?,?,?) + """, + ( + hostname, + session.tunnel_id, + script_name, + "vpn_tunnel", + now_ts, + "Running", + "", + "", + ), + ) + session.activity_id = cur.lastrowid + conn.commit() + if self.socketio: + try: + self.socketio.emit( + "device_activity_changed", + { + "hostname": hostname, + "activity_id": session.activity_id, + "change": "created", + "source": "vpn_tunnel", + }, + ) + except Exception: + pass + self.activity_logger.info( + "device_activity_start hostname=%s agent_id=%s tunnel_id=%s operator=%s activity_id=%s", + hostname, + session.agent_id, + session.tunnel_id, + ",".join(sorted(filter(None, session.operator_ids))) or "-", + session.activity_id or "-", + ) + return + + if session.activity_id: + status = "Completed" if event == "stop" else "Closed" + cur.execute( + """ + UPDATE activity_history + SET status=?, + stderr=COALESCE(stderr, '') || ? + WHERE id=? + """, + ( + status, + f"\nreason: {reason}" if reason else "", + session.activity_id, + ), + ) + conn.commit() + if self.socketio: + try: + self.socketio.emit( + "device_activity_changed", + { + "hostname": hostname, + "activity_id": session.activity_id, + "change": "updated", + "source": "vpn_tunnel", + }, + ) + except Exception: + pass + + self.activity_logger.info( + "device_activity event=%s hostname=%s agent_id=%s tunnel_id=%s operator=%s reason=%s activity_id=%s", + event, + hostname, + session.agent_id, + session.tunnel_id, + ",".join(sorted(filter(None, session.operator_ids))) or "-", + reason or "-", + session.activity_id or "-", + ) + except Exception: + self.activity_logger.debug( + "device_activity logging failed for tunnel_id=%s", + session.tunnel_id, + exc_info=True, + ) + finally: + if conn is not None: + try: + conn.close() + except Exception: + pass + + def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]: + payload: Dict[str, Any] = { + "tunnel_id": session.tunnel_id, + "agent_id": session.agent_id, + "virtual_ip": session.virtual_ip, + "engine_virtual_ip": str(self._engine_ip.ip), + "allowed_ips": f"{self._engine_ip.ip}/32", + "endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}", + "server_public_key": self.wg.server_public_key, + "client_public_key": session.client_public_key, + "client_private_key": session.client_private_key, + "idle_seconds": self.idle_seconds, + "allowed_ports": list(session.allowed_ports), + "connected_operators": len([o for o in session.operator_ids if o]), + } + if include_token: + payload["token"] = session.token + return payload diff --git a/Data/Engine/services/VPN/wireguard_server.py b/Data/Engine/services/VPN/wireguard_server.py index 1eadb967..ebfc57bc 100644 --- a/Data/Engine/services/VPN/wireguard_server.py +++ b/Data/Engine/services/VPN/wireguard_server.py @@ -70,7 +70,7 @@ class WireGuardServerManager: self.logger = _build_logger(config.log_path) self._ensure_cert_dir() self.server_private_key, self.server_public_key = self._ensure_server_keys() - self._service_name = "BorealisWireGuard" + self._service_name = "borealis-wg" self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine" def _ensure_cert_dir(self) -> None: @@ -157,7 +157,7 @@ class WireGuardServerManager: if not token: raise ValueError("Missing orchestration token for WireGuard peer") - required_fields = ("agent_id", "tunnel_id", "expires_at") + required_fields = ("agent_id", "tunnel_id", "expires_at", "port") missing = [field for field in required_fields if field not in token or token[field] in (None, "")] if missing: raise ValueError(f"Invalid orchestration token; missing {', '.join(missing)}") @@ -167,6 +167,13 @@ class WireGuardServerManager: except Exception: raise ValueError("Invalid orchestration token expiry") + try: + port = int(token["port"]) + except Exception: + raise ValueError("Invalid orchestration token port") + if port != int(self.config.port): + raise ValueError("Orchestration token port mismatch") + now = time.time() if expires_at <= now: raise ValueError("Orchestration token expired") @@ -253,12 +260,14 @@ class WireGuardServerManager: "host_only": True, } - def apply_firewall_rules(self, peer: Mapping[str, object]) -> None: + def apply_firewall_rules(self, peer: Mapping[str, object]) -> List[str]: """Apply outbound firewall allow rules for the agent's virtual IP/ports (Windows netsh).""" rules = self.build_firewall_rules(peer) + rule_names: List[str] = [] for idx, rule in enumerate(rules): name = f"Borealis-WG-Agent-{peer.get('agent_id','')}-{idx}" + protocol = str(rule.get("protocol") or "TCP").upper() args = [ "netsh", "advfirewall", @@ -269,7 +278,7 @@ class WireGuardServerManager: "dir=out", "action=allow", f"remoteip={rule.get('remote_address','')}", - f"protocol=TCP", + f"protocol={protocol}", f"localport={rule.get('local_port','')}", ] code, out, err = self._run_command(args) @@ -277,6 +286,19 @@ class WireGuardServerManager: self.logger.warning("Failed to apply firewall rule %s code=%s err=%s", name, code, err) else: self.logger.info("Applied firewall rule %s", name) + rule_names.append(name) + return rule_names + + def remove_firewall_rules(self, rule_names: Sequence[str]) -> None: + for name in rule_names: + if not name: + continue + args = ["netsh", "advfirewall", "firewall", "delete", "rule", f"name={name}"] + code, out, err = self._run_command(args) + if code != 0: + self.logger.warning("Failed to remove firewall rule %s code=%s err=%s", name, code, err) + else: + self.logger.info("Removed firewall rule %s", name) def start_listener(self, peers: Sequence[Mapping[str, object]]) -> None: """Render a temporary WireGuard config and start the service.""" @@ -291,6 +313,9 @@ class WireGuardServerManager: config_path.write_text(rendered, encoding="utf-8") self.logger.info("Rendered WireGuard config to %s", config_path) + # Ensure old service is removed before re-installing. + self.stop_listener() + args = ["wireguard.exe", "/installtunnelservice", str(config_path)] code, out, err = self._run_command(args) if code != 0: @@ -301,7 +326,7 @@ class WireGuardServerManager: def stop_listener(self) -> None: """Stop and remove the WireGuard tunnel service.""" - args = ["wireguard.exe", "/uninstalltunnelservice", "borealis-wg"] + args = ["wireguard.exe", "/uninstalltunnelservice", self._service_name] code, out, err = self._run_command(args) if code != 0: self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err) @@ -323,15 +348,17 @@ class WireGuardServerManager: port_list = [] for port in port_list: - rules.append( - { - "direction": "outbound", - "remote_address": ip, - "local_port": port, - "action": "allow", - "description": f"WireGuard engine->agent allow port {port}", - } - ) + for protocol in ("TCP", "UDP"): + rules.append( + { + "direction": "outbound", + "remote_address": ip, + "local_port": port, + "protocol": protocol, + "action": "allow", + "description": f"WireGuard engine->agent allow port {port}/{protocol}", + } + ) self.logger.info( "Prepared firewall rule plan for agent=%s rules=%s", diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/__init__.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/__init__.py deleted file mode 100644 index 9caba325..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Namespace package for reverse tunnel domain handlers (Engine side).""" - -__all__ = ["remote_interactive_shell", "remote_management", "remote_video"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/Bash.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/Bash.py deleted file mode 100644 index 5823d8d6..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/Bash.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Placeholder Bash channel server (Engine side).""" -from __future__ import annotations - -from collections import deque - - -class BashChannelServer: - """Stub Bash handler until the agent-side channel is implemented.""" - - protocol_name = "bash" - - def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None): - self.bridge = bridge - self.service = service - self.channel_id = channel_id - self.logger = service.logger.getChild(f"bash.{bridge.lease.tunnel_id}") - self._open_sent = False - self._ack_received = False - self._closed = False - self._output = deque() - self._frame_cls = frame_cls - self._close_frame_fn = close_frame_fn - - def handle_agent_frame(self, frame) -> None: - # No-op placeholder; output collection for future Bash support. - try: - if frame.msg_type == 0x04: # MSG_CHANNEL_ACK - self._ack_received = True - except Exception: - return - - def open_channel(self, *, cols: int = 120, rows: int = 32) -> None: - self._open_sent = True - self.logger.info( - "bash channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s", - self.bridge.lease.tunnel_id, - self.channel_id, - cols, - rows, - ) - - def send_input(self, data: str) -> None: - # Placeholder: no agent-side Bash yet. - self.logger.info("bash placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id) - - def send_resize(self, cols: int, rows: int) -> None: - # Placeholder: not implemented. - return - - def close(self, code: int = 6, reason: str = "operator_close") -> None: - if self._closed: - return - self._closed = True - if callable(self._close_frame_fn): - try: - frame = self._close_frame_fn(self.channel_id, code, reason) - self.bridge.operator_to_agent(frame) - except Exception: - pass - - def drain_output(self): - items = [] - while self._output: - items.append(self._output.popleft()) - return items - - def status(self): - return { - "channel_id": self.channel_id, - "open_sent": self._open_sent, - "ack": self._ack_received, - "closed": self._closed, - "close_reason": None, - "close_code": None, - } - - -__all__ = ["BashChannelServer"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/Powershell.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/Powershell.py deleted file mode 100644 index b231a87e..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/Powershell.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Engine-side PowerShell tunnel channel helper (remote interactive shell domain).""" -from __future__ import annotations - -import json -from collections import deque -from typing import Any, Deque, Dict, List, Optional - -# Mirror framing constants to avoid circular imports. -MSG_CHANNEL_OPEN = 0x03 -MSG_CHANNEL_ACK = 0x04 -MSG_DATA = 0x05 -MSG_CONTROL = 0x09 -MSG_CLOSE = 0x08 -CLOSE_OK = 0 -CLOSE_PROTOCOL_ERROR = 3 -CLOSE_AGENT_SHUTDOWN = 6 - - -class PowershellChannelServer: - """Coordinate PowerShell channel frames over a TunnelBridge.""" - - def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None): - self.bridge = bridge - self.service = service - self.channel_id = channel_id - self.logger = service.logger.getChild(f"ps.{bridge.lease.tunnel_id}") - self._open_sent = False - self._ack_received = False - self._closed = False - self._output: Deque[str] = deque() - self._close_reason: Optional[str] = None - self._close_code: Optional[int] = None - self._frame_cls = frame_cls - self._close_frame_fn = close_frame_fn - - # ------------------------------------------------------------------ Agent frame handling - def handle_agent_frame(self, frame) -> None: - if frame.channel_id != self.channel_id: - return - if frame.msg_type == MSG_CHANNEL_ACK: - self._ack_received = True - self.logger.info("ps channel acked tunnel_id=%s", self.bridge.lease.tunnel_id) - return - if frame.msg_type == MSG_DATA: - try: - text = frame.payload.decode("utf-8", errors="replace") - except Exception: - text = "" - if text: - self._append_output(text) - return - if frame.msg_type == MSG_CLOSE: - try: - payload = json.loads(frame.payload.decode("utf-8")) - except Exception: - payload = {} - self._closed = True - self._close_code = payload.get("code") if isinstance(payload, dict) else None - self._close_reason = payload.get("reason") if isinstance(payload, dict) else None - self.logger.info( - "ps channel closed tunnel_id=%s code=%s reason=%s", - self.bridge.lease.tunnel_id, - self._close_code, - self._close_reason or "-", - ) - - # ------------------------------------------------------------------ Operator actions - def open_channel(self, *, cols: int = 120, rows: int = 32) -> None: - if self._open_sent: - return - payload = json.dumps( - {"protocol": "ps", "metadata": {"cols": cols, "rows": rows}}, - separators=(",", ":"), - ).encode("utf-8") - frame = self._frame_cls(msg_type=MSG_CHANNEL_OPEN, channel_id=self.channel_id, payload=payload) - self.bridge.operator_to_agent(frame) - self._open_sent = True - self.logger.info( - "ps channel open sent tunnel_id=%s channel_id=%s cols=%s rows=%s", - self.bridge.lease.tunnel_id, - self.channel_id, - cols, - rows, - ) - - def send_input(self, data: str) -> None: - if self._closed: - return - payload = data.encode("utf-8", errors="replace") - frame = self._frame_cls(msg_type=MSG_DATA, channel_id=self.channel_id, payload=payload) - self.bridge.operator_to_agent(frame) - - def send_resize(self, cols: int, rows: int) -> None: - if self._closed: - return - payload = json.dumps({"cols": cols, "rows": rows}, separators=(",", ":")).encode("utf-8") - frame = self._frame_cls(msg_type=MSG_CONTROL, channel_id=self.channel_id, payload=payload) - self.bridge.operator_to_agent(frame) - - def close(self, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "operator_close") -> None: - if self._closed: - return - self._closed = True - if callable(self._close_frame_fn): - frame = self._close_frame_fn(self.channel_id, code, reason) - else: - frame = self._frame_cls( - msg_type=MSG_CLOSE, - channel_id=self.channel_id, - payload=json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8"), - ) - self.bridge.operator_to_agent(frame) - - # ------------------------------------------------------------------ Output polling - def drain_output(self) -> List[str]: - items: List[str] = [] - while self._output: - items.append(self._output.popleft()) - return items - - def _append_output(self, text: str) -> None: - self._output.append(text) - # Cap buffer to avoid unbounded memory growth. - while len(self._output) > 500: - self._output.popleft() - - # ------------------------------------------------------------------ Status helpers - def status(self) -> Dict[str, Any]: - return { - "channel_id": self.channel_id, - "open_sent": self._open_sent, - "ack": self._ack_received, - "closed": self._closed, - "close_reason": self._close_reason, - "close_code": self._close_code, - } - - -__all__ = ["PowershellChannelServer"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/__init__.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/__init__.py deleted file mode 100644 index 853fd5fd..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/Protocols/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Protocol handlers for remote interactive shell tunnels (Engine side).""" - -from .Powershell import PowershellChannelServer -from .Bash import BashChannelServer - -__all__ = ["PowershellChannelServer", "BashChannelServer"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/__init__.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/__init__.py deleted file mode 100644 index f8ee48a4..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_interactive_shell/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Domain handlers for remote interactive shells (PowerShell/Bash).""" - -__all__ = ["Protocols"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/SSH.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/SSH.py deleted file mode 100644 index 6800b4c1..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/SSH.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Placeholder SSH channel server (Engine side).""" -from __future__ import annotations - -from collections import deque - - -class SSHChannelServer: - protocol_name = "ssh" - - def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None): - self.bridge = bridge - self.service = service - self.channel_id = channel_id - self.logger = service.logger.getChild(f"ssh.{bridge.lease.tunnel_id}") - self._open_sent = False - self._ack_received = False - self._closed = False - self._output = deque() - self._frame_cls = frame_cls - self._close_frame_fn = close_frame_fn - - def handle_agent_frame(self, frame) -> None: - try: - if frame.msg_type == 0x04: # MSG_CHANNEL_ACK - self._ack_received = True - except Exception: - return - - def open_channel(self, *, cols: int = 120, rows: int = 32) -> None: - self._open_sent = True - self.logger.info( - "ssh channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s", - self.bridge.lease.tunnel_id, - self.channel_id, - cols, - rows, - ) - - def send_input(self, data: str) -> None: - self.logger.info("ssh placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id) - - def send_resize(self, cols: int, rows: int) -> None: - return - - def close(self, code: int = 6, reason: str = "operator_close") -> None: - if self._closed: - return - self._closed = True - if callable(self._close_frame_fn): - try: - frame = self._close_frame_fn(self.channel_id, code, reason) - self.bridge.operator_to_agent(frame) - except Exception: - pass - - def drain_output(self): - items = [] - while self._output: - items.append(self._output.popleft()) - return items - - def status(self): - return { - "channel_id": self.channel_id, - "open_sent": self._open_sent, - "ack": self._ack_received, - "closed": self._closed, - "close_reason": None, - "close_code": None, - } - - -__all__ = ["SSHChannelServer"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/WinRM.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/WinRM.py deleted file mode 100644 index 5c0c7c63..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/WinRM.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Placeholder WinRM channel server (Engine side).""" -from __future__ import annotations - -from collections import deque - - -class WinRMChannelServer: - protocol_name = "winrm" - - def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None): - self.bridge = bridge - self.service = service - self.channel_id = channel_id - self.logger = service.logger.getChild(f"winrm.{bridge.lease.tunnel_id}") - self._open_sent = False - self._ack_received = False - self._closed = False - self._output = deque() - self._frame_cls = frame_cls - self._close_frame_fn = close_frame_fn - - def handle_agent_frame(self, frame) -> None: - try: - if frame.msg_type == 0x04: # MSG_CHANNEL_ACK - self._ack_received = True - except Exception: - return - - def open_channel(self, *, cols: int = 120, rows: int = 32) -> None: - self._open_sent = True - self.logger.info( - "winrm channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s", - self.bridge.lease.tunnel_id, - self.channel_id, - cols, - rows, - ) - - def send_input(self, data: str) -> None: - self.logger.info("winrm placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id) - - def send_resize(self, cols: int, rows: int) -> None: - return - - def close(self, code: int = 6, reason: str = "operator_close") -> None: - if self._closed: - return - self._closed = True - if callable(self._close_frame_fn): - try: - frame = self._close_frame_fn(self.channel_id, code, reason) - self.bridge.operator_to_agent(frame) - except Exception: - pass - - def drain_output(self): - items = [] - while self._output: - items.append(self._output.popleft()) - return items - - def status(self): - return { - "channel_id": self.channel_id, - "open_sent": self._open_sent, - "ack": self._ack_received, - "closed": self._closed, - "close_reason": None, - "close_code": None, - } - - -__all__ = ["WinRMChannelServer"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/__init__.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/__init__.py deleted file mode 100644 index e5653ef1..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/Protocols/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Protocol handlers for remote management tunnels (Engine side).""" - -from .SSH import SSHChannelServer -from .WinRM import WinRMChannelServer - -__all__ = ["SSHChannelServer", "WinRMChannelServer"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/__init__.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/__init__.py deleted file mode 100644 index 12632a2e..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_management/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Domain handlers for remote management tunnels (SSH/WinRM).""" - -__all__ = ["Protocols"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/RDP.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/RDP.py deleted file mode 100644 index b154de12..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/RDP.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Placeholder RDP channel server (Engine side).""" -from __future__ import annotations - -from collections import deque - - -class RDPChannelServer: - protocol_name = "rdp" - - def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None): - self.bridge = bridge - self.service = service - self.channel_id = channel_id - self.logger = service.logger.getChild(f"rdp.{bridge.lease.tunnel_id}") - self._open_sent = False - self._ack_received = False - self._closed = False - self._output = deque() - self._frame_cls = frame_cls - self._close_frame_fn = close_frame_fn - - def handle_agent_frame(self, frame) -> None: - try: - if frame.msg_type == 0x04: # MSG_CHANNEL_ACK - self._ack_received = True - except Exception: - return - - def open_channel(self, *, cols: int = 120, rows: int = 32) -> None: - self._open_sent = True - self.logger.info( - "rdp channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s", - self.bridge.lease.tunnel_id, - self.channel_id, - cols, - rows, - ) - - def send_input(self, data: str) -> None: - self.logger.info("rdp placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id) - - def send_resize(self, cols: int, rows: int) -> None: - return - - def close(self, code: int = 6, reason: str = "operator_close") -> None: - if self._closed: - return - self._closed = True - if callable(self._close_frame_fn): - try: - frame = self._close_frame_fn(self.channel_id, code, reason) - self.bridge.operator_to_agent(frame) - except Exception: - pass - - def drain_output(self): - items = [] - while self._output: - items.append(self._output.popleft()) - return items - - def status(self): - return { - "channel_id": self.channel_id, - "open_sent": self._open_sent, - "ack": self._ack_received, - "closed": self._closed, - "close_reason": None, - "close_code": None, - } - - -__all__ = ["RDPChannelServer"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/VNC.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/VNC.py deleted file mode 100644 index 33d980ea..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/VNC.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Placeholder VNC channel server (Engine side).""" -from __future__ import annotations - -from collections import deque - - -class VNCChannelServer: - protocol_name = "vnc" - - def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None): - self.bridge = bridge - self.service = service - self.channel_id = channel_id - self.logger = service.logger.getChild(f"vnc.{bridge.lease.tunnel_id}") - self._open_sent = False - self._ack_received = False - self._closed = False - self._output = deque() - self._frame_cls = frame_cls - self._close_frame_fn = close_frame_fn - - def handle_agent_frame(self, frame) -> None: - try: - if frame.msg_type == 0x04: # MSG_CHANNEL_ACK - self._ack_received = True - except Exception: - return - - def open_channel(self, *, cols: int = 120, rows: int = 32) -> None: - self._open_sent = True - self.logger.info( - "vnc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s", - self.bridge.lease.tunnel_id, - self.channel_id, - cols, - rows, - ) - - def send_input(self, data: str) -> None: - self.logger.info("vnc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id) - - def send_resize(self, cols: int, rows: int) -> None: - return - - def close(self, code: int = 6, reason: str = "operator_close") -> None: - if self._closed: - return - self._closed = True - if callable(self._close_frame_fn): - try: - frame = self._close_frame_fn(self.channel_id, code, reason) - self.bridge.operator_to_agent(frame) - except Exception: - pass - - def drain_output(self): - items = [] - while self._output: - items.append(self._output.popleft()) - return items - - def status(self): - return { - "channel_id": self.channel_id, - "open_sent": self._open_sent, - "ack": self._ack_received, - "closed": self._closed, - "close_reason": None, - "close_code": None, - } - - -__all__ = ["VNCChannelServer"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/WebRTC.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/WebRTC.py deleted file mode 100644 index 24f440ca..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/WebRTC.py +++ /dev/null @@ -1,73 +0,0 @@ -"""Placeholder WebRTC channel server (Engine side).""" -from __future__ import annotations - -from collections import deque - - -class WebRTCChannelServer: - protocol_name = "webrtc" - - def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None): - self.bridge = bridge - self.service = service - self.channel_id = channel_id - self.logger = service.logger.getChild(f"webrtc.{bridge.lease.tunnel_id}") - self._open_sent = False - self._ack_received = False - self._closed = False - self._output = deque() - self._frame_cls = frame_cls - self._close_frame_fn = close_frame_fn - - def handle_agent_frame(self, frame) -> None: - try: - if frame.msg_type == 0x04: # MSG_CHANNEL_ACK - self._ack_received = True - except Exception: - return - - def open_channel(self, *, cols: int = 120, rows: int = 32) -> None: - self._open_sent = True - self.logger.info( - "webrtc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s", - self.bridge.lease.tunnel_id, - self.channel_id, - cols, - rows, - ) - - def send_input(self, data: str) -> None: - self.logger.info("webrtc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id) - - def send_resize(self, cols: int, rows: int) -> None: - return - - def close(self, code: int = 6, reason: str = "operator_close") -> None: - if self._closed: - return - self._closed = True - if callable(self._close_frame_fn): - try: - frame = self._close_frame_fn(self.channel_id, code, reason) - self.bridge.operator_to_agent(frame) - except Exception: - pass - - def drain_output(self): - items = [] - while self._output: - items.append(self._output.popleft()) - return items - - def status(self): - return { - "channel_id": self.channel_id, - "open_sent": self._open_sent, - "ack": self._ack_received, - "closed": self._closed, - "close_reason": None, - "close_code": None, - } - - -__all__ = ["WebRTCChannelServer"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/__init__.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/__init__.py deleted file mode 100644 index 8029c5d3..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/Protocols/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Protocol handlers for remote video tunnels (Engine side).""" - -from .WebRTC import WebRTCChannelServer -from .RDP import RDPChannelServer -from .VNC import VNCChannelServer - -__all__ = ["WebRTCChannelServer", "RDPChannelServer", "VNCChannelServer"] diff --git a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/__init__.py b/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/__init__.py deleted file mode 100644 index e57ce485..00000000 --- a/Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/remote_video/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Domain handlers for remote video/desktop tunnels (RDP/VNC/WebRTC).""" - -__all__ = ["Protocols"] diff --git a/Data/Engine/services/WebSocket/Agent/__init__.py b/Data/Engine/services/WebSocket/Agent/__init__.py deleted file mode 100644 index cc8c3730..00000000 --- a/Data/Engine/services/WebSocket/Agent/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# ====================================================== -# Data\Engine\services\WebSocket\Agent\__init__.py -# Description: Package marker for Agent-facing WebSocket services (reverse tunnel scaffolding). -# -# API Endpoints (if applicable): None -# ====================================================== - -"""Agent-facing WebSocket services for the Engine runtime.""" - -__all__ = [] diff --git a/Data/Engine/services/WebSocket/Agent/reverse_tunnel_orchestrator.py b/Data/Engine/services/WebSocket/Agent/reverse_tunnel_orchestrator.py deleted file mode 100644 index 3e9df698..00000000 --- a/Data/Engine/services/WebSocket/Agent/reverse_tunnel_orchestrator.py +++ /dev/null @@ -1,1361 +0,0 @@ -# ====================================================== -# Data\Engine\services\WebSocket\Agent\reverse_tunnel_orchestrator.py -# Description: Async reverse tunnel scaffolding (Engine side) providing lease management, domain limits, and placeholders for WebSocket listeners. -# -# API Endpoints (if applicable): None -# ====================================================== - -"""Engine-side reverse tunnel scaffolding. - -This module lays down the lease manager and configuration surface for the -Agent reverse tunnel without wiring listeners into the runtime. It preserves -the existing Socket.IO control plane while preparing async WebSocket -infrastructure to serve per-agent reverse tunnels. -""" -from __future__ import annotations - -import asyncio -import base64 -import json -import logging -import secrets -import ssl -import struct -import time -from dataclasses import dataclass, field -from logging.handlers import TimedRotatingFileHandler -from pathlib import Path -from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple -from collections import deque -from threading import Thread - -from .Reverse_Tunnels.remote_interactive_shell.Protocols.Powershell import PowershellChannelServer -from .Reverse_Tunnels.remote_interactive_shell.Protocols.Bash import BashChannelServer -from .Reverse_Tunnels.remote_management.Protocols.SSH import SSHChannelServer -from .Reverse_Tunnels.remote_management.Protocols.WinRM import WinRMChannelServer -from .Reverse_Tunnels.remote_video.Protocols.VNC import VNCChannelServer -from .Reverse_Tunnels.remote_video.Protocols.RDP import RDPChannelServer -from .Reverse_Tunnels.remote_video.Protocols.WebRTC import WebRTCChannelServer - -try: # websockets is added to engine requirements - import websockets - from websockets.server import serve as ws_serve -except Exception: # pragma: no cover - dependency resolved at runtime - websockets = None - ws_serve = None - -from ....server import EngineContext - -TunnelState = str - - -def _utc_ts() -> float: - return time.time() - - -def _generate_tunnel_id() -> str: - # UUID4-like, but defer to secrets for a short scaffold without adding deps. - hex_blob = secrets.token_hex(16) - return f"{hex_blob[0:8]}-{hex_blob[8:12]}-{hex_blob[12:16]}-{hex_blob[16:20]}-{hex_blob[20:32]}" - - -class FrameDecodeError(Exception): - """Raised when an incoming frame is malformed.""" - - -class FrameValidationError(Exception): - """Raised when a frame fails validation.""" - - -# Message types -MSG_CONNECT = 0x01 -MSG_CONNECT_ACK = 0x02 -MSG_CHANNEL_OPEN = 0x03 -MSG_CHANNEL_ACK = 0x04 -MSG_DATA = 0x05 -MSG_WINDOW_UPDATE = 0x06 -MSG_HEARTBEAT = 0x07 -MSG_CLOSE = 0x08 -MSG_CONTROL = 0x09 - -# Close codes -CLOSE_OK = 0 -CLOSE_IDLE_TIMEOUT = 1 -CLOSE_GRACE_EXPIRED = 2 -CLOSE_PROTOCOL_ERROR = 3 -CLOSE_AUTH_FAILED = 4 -CLOSE_SERVER_SHUTDOWN = 5 -CLOSE_AGENT_SHUTDOWN = 6 -CLOSE_DOMAIN_LIMIT = 7 -CLOSE_UNEXPECTED_DISCONNECT = 8 - -FRAME_HEADER_STRUCT = struct.Struct(" bytes: - payload_len = len(self.payload or b"") - header = FRAME_HEADER_STRUCT.pack( - self.version, - self.msg_type, - self.flags, - self.reserved, - int(self.channel_id), - payload_len, - ) - return header + (self.payload or b"") - - -def decode_frame(buffer: bytes) -> TunnelFrame: - """Decode a single tunnel frame from bytes.""" - - if len(buffer) < FRAME_HEADER_STRUCT.size: - raise FrameDecodeError("frame_too_small") - try: - version, msg_type, flags, reserved, channel_id, length = FRAME_HEADER_STRUCT.unpack_from(buffer, 0) - except struct.error as exc: - raise FrameDecodeError(f"frame_unpack_error:{exc}") from exc - - if version != FRAME_VERSION: - raise FrameValidationError(f"unsupported_version:{version}") - if length < 0: - raise FrameValidationError("invalid_length") - expected_total = FRAME_HEADER_STRUCT.size + length - if len(buffer) < expected_total: - raise FrameDecodeError("incomplete_frame") - payload = buffer[FRAME_HEADER_STRUCT.size : expected_total] - if len(payload) != length: - raise FrameValidationError("length_mismatch") - - return TunnelFrame( - version=version, - msg_type=msg_type, - flags=flags, - reserved=reserved, - channel_id=channel_id, - payload=payload, - ) - - -def heartbeat_frame(channel_id: int = 0, *, is_ack: bool = False) -> TunnelFrame: - """Build a heartbeat ping/pong frame.""" - - flags = 0x1 if is_ack else 0x0 - return TunnelFrame(msg_type=MSG_HEARTBEAT, channel_id=channel_id, flags=flags, payload=b"") - - -def close_frame(channel_id: int, code: int, reason: str = "") -> TunnelFrame: - payload = json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8") - return TunnelFrame(msg_type=MSG_CLOSE, channel_id=channel_id, payload=payload) - - -def _build_tunnel_logger(log_path: Path) -> logging.Logger: - """Create a dedicated reverse tunnel logger with daily rotation.""" - - try: - log_path.parent.mkdir(parents=True, exist_ok=True) - except Exception: - pass - - logger = logging.getLogger("borealis.engine.reverse_tunnel") - if not logger.handlers: - formatter = logging.Formatter("%(asctime)s-%(name)s-%(levelname)s: %(message)s") - handler = TimedRotatingFileHandler(str(log_path), when="midnight", backupCount=0, encoding="utf-8") - handler.setFormatter(formatter) - logger.addHandler(handler) - logger.setLevel(logging.INFO) - logger.propagate = False - return logger - - -@dataclass -class TunnelLease: - tunnel_id: str - agent_id: str - domain: str - protocol: str - operator_id: Optional[str] - assigned_port: int - token: Optional[str] = None - hostname: Optional[str] = None - activity_id: Optional[int] = None - created_at: float = field(default_factory=_utc_ts) - expires_at: Optional[float] = None - idle_timeout_seconds: int = 3600 - grace_timeout_seconds: int = 3600 - state: TunnelState = "pending" - last_activity_ts: float = field(default_factory=_utc_ts) - agent_connected_at: Optional[float] = None - agent_disconnected_at: Optional[float] = None - - def mark_active(self) -> None: - self.state = "active" - self.agent_connected_at = _utc_ts() - self.last_activity_ts = self.agent_connected_at - - def mark_disconnected(self) -> None: - self.agent_disconnected_at = _utc_ts() - self.last_activity_ts = self.agent_disconnected_at - - def touch(self) -> None: - self.last_activity_ts = _utc_ts() - - def mark_closing(self) -> None: - self.state = "closing" - - def mark_expired(self) -> None: - self.state = "expired" - - def to_summary(self) -> Dict[str, object]: - return { - "tunnel_id": self.tunnel_id, - "agent_id": self.agent_id, - "domain": self.domain, - "protocol": self.protocol, - "operator_id": self.operator_id, - "assigned_port": self.assigned_port, - "state": self.state, - "created_at": self.created_at, - "expires_at": self.expires_at, - "idle_timeout_seconds": self.idle_timeout_seconds, - "grace_timeout_seconds": self.grace_timeout_seconds, - "last_activity_ts": self.last_activity_ts, - "agent_connected_at": self.agent_connected_at, - "agent_disconnected_at": self.agent_disconnected_at, - } - - -class DomainPolicy: - """Enforce per-domain concurrency and defaults.""" - - DEFAULT_LIMITS = { - # New domain lanes - "remote-interactive-shell": 2, - "remote-management": 1, - "remote-video": 2, - # Protocol-specific fallbacks (for backward compatibility / legacy callers) - "ps": 2, - "rdp": 1, - "vnc": 1, - "webrtc": 2, - "ssh": None, # Unlimited - "winrm": None, # Unlimited - } - - def __init__(self, overrides: Optional[Dict[str, Optional[int]]] = None): - merged = dict(self.DEFAULT_LIMITS) - if overrides: - merged.update(overrides) - self.limits = merged - - def is_allowed(self, domain: str, active_count: int) -> bool: - limit = self.limits.get(domain) - if limit is None: - return True - return active_count < limit - - -class PortAllocator: - """Simple round-robin port allocator with reuse tracking.""" - - def __init__(self, start: int, end: int): - if start < 1 or end > 65535 or start > end: - raise ValueError("Invalid port range") - self.start = start - self.end = end - self._next = start - self._in_use: Dict[int, str] = {} - - def allocate(self, tunnel_id: str) -> Optional[int]: - for _ in range(self.start, self.end + 1): - candidate = self._next - self._next += 1 - if self._next > self.end: - self._next = self.start - if candidate in self._in_use: - continue - self._in_use[candidate] = tunnel_id - return candidate - return None - - def release(self, port: int) -> None: - self._in_use.pop(port, None) - - def in_use(self) -> Dict[int, str]: - return dict(self._in_use) - - -class TunnelLeaseManager: - """DHCP-like lease manager for reverse tunnels (Engine side).""" - - def __init__( - self, - *, - port_range: Tuple[int, int], - idle_timeout_seconds: int, - grace_timeout_seconds: int, - domain_policy: Optional[DomainPolicy] = None, - logger: Optional[logging.Logger] = None, - ): - self._allocator = PortAllocator(port_range[0], port_range[1]) - self.idle_timeout_seconds = idle_timeout_seconds - self.grace_timeout_seconds = grace_timeout_seconds - self.domain_policy = domain_policy or DomainPolicy() - self.logger = logger or logging.getLogger("borealis.engine.tunnel.lease") - self._leases: Dict[str, TunnelLease] = {} - - def _active_for_agent_domain(self, agent_id: str, domain: str) -> int: - active_states = {"pending", "active", "closing"} - return sum( - 1 - for lease in self._leases.values() - if lease.agent_id == agent_id and lease.domain == domain and lease.state in active_states - ) - - def allocate( - self, - *, - agent_id: str, - protocol: str, - domain: str, - operator_id: Optional[str], - token: Optional[str] = None, - ) -> TunnelLease: - in_domain = self._active_for_agent_domain(agent_id, domain) - if not self.domain_policy.is_allowed(domain, in_domain): - raise RuntimeError(f"domain_limit:{domain}") - - tunnel_id = _generate_tunnel_id() - port = self._allocator.allocate(tunnel_id) - if port is None: - raise RuntimeError("port_pool_exhausted") - - now_ts = _utc_ts() - lease = TunnelLease( - tunnel_id=tunnel_id, - agent_id=agent_id, - domain=domain, - protocol=protocol, - operator_id=operator_id, - assigned_port=port, - token=token, - created_at=now_ts, - expires_at=now_ts + self.grace_timeout_seconds, - idle_timeout_seconds=self.idle_timeout_seconds, - grace_timeout_seconds=self.grace_timeout_seconds, - state="pending", - last_activity_ts=now_ts, - ) - self._leases[tunnel_id] = lease - self.logger.info( - "lease_allocated tunnel_id=%s agent_id=%s domain=%s protocol=%s port=%s", - tunnel_id, - agent_id, - domain, - protocol, - port, - ) - return lease - - def release(self, tunnel_id: str, *, reason: str = "released") -> None: - lease = self._leases.pop(tunnel_id, None) - if lease is None: - return - self._allocator.release(lease.assigned_port) - self.logger.info( - "lease_released tunnel_id=%s agent_id=%s port=%s reason=%s", - tunnel_id, - lease.agent_id, - lease.assigned_port, - reason, - ) - - def get(self, tunnel_id: str) -> Optional[TunnelLease]: - return self._leases.get(tunnel_id) - - def touch(self, tunnel_id: str) -> None: - lease = self._leases.get(tunnel_id) - if lease: - lease.touch() - - def mark_agent_connected(self, tunnel_id: str) -> None: - lease = self._leases.get(tunnel_id) - if lease: - lease.mark_active() - - def mark_agent_disconnected(self, tunnel_id: str) -> None: - lease = self._leases.get(tunnel_id) - if lease: - lease.mark_disconnected() - - def expire_idle(self, *, now_ts: Optional[float] = None) -> List[TunnelLease]: - now = now_ts or _utc_ts() - expired: List[TunnelLease] = [] - pending_timeout = min(self.grace_timeout_seconds, 300) # avoid long-lived pending locks - for lease in list(self._leases.values()): - if lease.state == "expired": - continue - - idle_age = now - lease.last_activity_ts - pending_age = now - lease.created_at - if lease.state == "active" and idle_age >= lease.idle_timeout_seconds: - lease.mark_expired() - expired.append(lease) - self.release(lease.tunnel_id, reason="idle_timeout") - continue - - if lease.agent_disconnected_at: - grace_age = now - lease.agent_disconnected_at - if grace_age >= lease.grace_timeout_seconds: - lease.mark_expired() - expired.append(lease) - self.release(lease.tunnel_id, reason="grace_expired") - continue - - if lease.state == "pending": - hard_expiry = lease.expires_at or (lease.created_at + lease.grace_timeout_seconds) - if pending_age >= pending_timeout or (hard_expiry and now >= hard_expiry): - lease.mark_expired() - expired.append(lease) - self.release(lease.tunnel_id, reason="pending_timeout") - continue - return expired - - def all_leases(self) -> Iterable[TunnelLease]: - return list(self._leases.values()) - - -class ReverseTunnelService: - """Placeholder for the async tunnel listener and bridge wiring.""" - - def __init__( - self, - context: EngineContext, - *, - signer: Optional[object] = None, - db_conn_factory: Optional[Callable[[], object]] = None, - socketio: Optional[object] = None, - ): - self.context = context - self.logger = context.logger.getChild("tunnel.service") - self.audit_logger = _build_tunnel_logger(Path(context.reverse_tunnel_log_path)) - self.lease_manager = TunnelLeaseManager( - port_range=context.reverse_tunnel_port_range, - idle_timeout_seconds=context.reverse_tunnel_idle_timeout_seconds, - grace_timeout_seconds=context.reverse_tunnel_grace_timeout_seconds, - logger=self.audit_logger.getChild("lease_manager"), - ) - self._activity_logger = self.audit_logger.getChild("device_activity") - self._db_conn_factory = db_conn_factory - self._socketio = socketio - self.fixed_port = context.reverse_tunnel_fixed_port - self.heartbeat_seconds = context.reverse_tunnel_heartbeat_seconds - self.log_path = Path(context.reverse_tunnel_log_path) - self._loop: Optional[asyncio.AbstractEventLoop] = None - self._loop_thread: Optional[Thread] = None - self._running = False - self._sweeper_task: Optional[asyncio.Future] = None - self.signer = signer - self._bridges: Dict[str, "TunnelBridge"] = {} - self._port_servers: Dict[int, asyncio.AbstractServer] = {} - self._agent_sockets: Dict[str, "websockets.WebSocketServerProtocol"] = {} - self.protocol_registry = { - "ps": PowershellChannelServer, - "powershell": PowershellChannelServer, - "bash": BashChannelServer, - "ssh": SSHChannelServer, - "winrm": WinRMChannelServer, - "vnc": VNCChannelServer, - "rdp": RDPChannelServer, - "webrtc": WebRTCChannelServer, - } - self._protocol_servers: Dict[str, object] = {} - - def _ensure_loop(self) -> None: - if self._running and self._loop: - return - self._loop = asyncio.new_event_loop() - self._running = True - - def _runner(): - asyncio.set_event_loop(self._loop) - self.logger.info( - "Reverse tunnel event loop started (fixed_port=%s port_range=%s-%s)", - self.fixed_port, - self.lease_manager._allocator.start, - self.lease_manager._allocator.end, - ) - self._loop.run_forever() - - self._loop_thread = Thread(target=_runner, name="reverse-tunnel-loop", daemon=True) - self._loop_thread.start() - self._start_lease_sweeper() - - def start(self) -> None: - """Start the tunnel service loop.""" - - if self._running: - return - self._ensure_loop() - - def stop(self) -> None: - """Stop the tunnel service and release leases.""" - - if not self._running: - return - for server in list(self._port_servers.values()): - try: - server.close() - except Exception: - pass - self._port_servers.clear() - for websocket in list(self._agent_sockets.values()): - try: - self._loop.call_soon_threadsafe(asyncio.create_task, websocket.close()) - except Exception: - pass - for tunnel_id in list(self._bridges.keys()): - try: - self.release_bridge(tunnel_id, reason="service_stop") - except Exception: - pass - self._protocol_servers.clear() - for lease in list(self.lease_manager.all_leases()): - self.lease_manager.release(lease.tunnel_id, reason="service_stop") - if self._sweeper_task: - try: - self._sweeper_task.cancel() - except Exception: - pass - self._running = False - if self._loop: - self._loop.call_soon_threadsafe(self._loop.stop) - self.logger.info("Reverse tunnel service stopped.") - - async def start_listener(self) -> None: - """Placeholder async listener hook (no sockets yet).""" - - if not self._running: - self.start() - self.logger.debug("Reverse tunnel async listener placeholder running (no sockets bound).") - - async def handle_agent_connect(self, tunnel_id: str, token: str) -> TunnelBridge: - """Validate agent token and attach to bridge (socket handling TBD).""" - - lease = self.lease_manager.get(tunnel_id) - if lease is None: - raise ValueError("unknown_tunnel") - bridge = self.ensure_bridge(lease) - bridge.attach_agent(token) - return bridge - - async def handle_operator_connect(self, tunnel_id: str, operator_id: Optional[str]) -> TunnelBridge: - """Attach operator to bridge (socket handling TBD).""" - - lease = self.lease_manager.get(tunnel_id) - if lease is None: - raise ValueError("unknown_tunnel") - bridge = self.ensure_bridge(lease) - bridge.attach_operator(operator_id) - return bridge - - def agent_attach(self, tunnel_id: str, token: str) -> TunnelBridge: - """Synchronous wrapper for agent attachment.""" - - lease = self.lease_manager.get(tunnel_id) - if lease is None: - raise ValueError("unknown_tunnel") - bridge = self.ensure_bridge(lease) - bridge.attach_agent(token) - return bridge - - def operator_attach(self, tunnel_id: str, operator_id: Optional[str]) -> TunnelBridge: - """Synchronous wrapper for operator attachment.""" - - lease = self.lease_manager.get(tunnel_id) - if lease is None: - raise ValueError("unknown_tunnel") - bridge = self.ensure_bridge(lease) - bridge.attach_operator(operator_id) - if (lease.protocol or "").lower() in {"ps", "powershell"}: - try: - server = self.ensure_protocol_server(tunnel_id) - if server: - server.open_channel() - except Exception: - self.logger.debug("ps server open failed tunnel_id=%s", tunnel_id, exc_info=True) - return bridge - - def _encode_token(self, payload: Dict[str, object]) -> str: - """Encode a short-lived token binding the lease fields.""" - - payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") - payload_b64 = base64.urlsafe_b64encode(payload_bytes).decode("ascii").rstrip("=") - if self.signer: - try: - signature = self.signer.sign(payload_bytes) - sig_b64 = base64.urlsafe_b64encode(signature).decode("ascii").rstrip("=") - return f"{payload_b64}.{sig_b64}" - except Exception: - self.logger.debug("Reverse tunnel token signing failed; returning unsigned token", exc_info=True) - return payload_b64 - - def request_lease( - self, - *, - agent_id: str, - protocol: str, - domain: str, - operator_id: Optional[str], - ) -> TunnelLease: - self._ensure_loop() - lease = self.lease_manager.allocate( - agent_id=agent_id, - protocol=protocol, - domain=domain, - operator_id=operator_id, - ) - lease.token = self.issue_token(lease) - self._spawn_port_listener(lease.assigned_port) - self._push_start_to_agent(lease) - self.audit_logger.info( - "lease_created tunnel_id=%s agent_id=%s domain=%s protocol=%s port=%s operator=%s", - lease.tunnel_id, - lease.agent_id, - lease.domain, - lease.protocol, - lease.assigned_port, - operator_id or "-", - ) - return lease - - def issue_token(self, lease: TunnelLease) -> str: - expires_at = lease.created_at + lease.grace_timeout_seconds - payload = { - "agent_id": lease.agent_id, - "tunnel_id": lease.tunnel_id, - "assigned_port": lease.assigned_port, - "protocol": lease.protocol, - "domain": lease.domain, - "expires_at": int(expires_at), - "issued_at": int(lease.created_at), - } - token = self._encode_token(payload) - lease.token = token - lease.expires_at = expires_at - return token - - def stop_tunnel(self, tunnel_id: str, *, reason: str = "operator_stop", code: int = CLOSE_AGENT_SHUTDOWN) -> bool: - """Request a graceful stop for a tunnel (operator-driven).""" - - lease = self.lease_manager.get(tunnel_id) - if lease is None: - return False - server = self.get_protocol_server(tunnel_id) - if server and hasattr(server, "close"): - try: - server.close(code=code, reason=reason) - except Exception: - self.logger.debug("protocol server close failed tunnel_id=%s", tunnel_id, exc_info=True) - if tunnel_id in self._protocol_servers: - try: - self._protocol_servers.pop(tunnel_id, None) - except Exception: - pass - self._push_stop_to_agent(lease, reason=reason) - websocket = self._agent_sockets.pop(tunnel_id, None) - if websocket is not None: - try: - self.lease_manager.mark_agent_disconnected(tunnel_id) - except Exception: - pass - try: - if self._loop: - self._loop.call_soon_threadsafe(asyncio.create_task, websocket.close()) - except Exception: - self.logger.debug("agent websocket close failed tunnel_id=%s", tunnel_id, exc_info=True) - self.release_bridge(tunnel_id, reason=reason) - return True - - def _push_start_to_agent(self, lease: TunnelLease) -> None: - """Notify the target agent about the new lease over Socket.IO (best-effort).""" - - if not self._socketio: - return - payload = { - "tunnel_id": lease.tunnel_id, - "lease_id": lease.tunnel_id, - "agent_id": lease.agent_id, - "token": lease.token, - "port": lease.assigned_port, - "assigned_port": lease.assigned_port, - "protocol": lease.protocol, - "domain": lease.domain, - "idle_seconds": lease.idle_timeout_seconds, - "grace_seconds": lease.grace_timeout_seconds, - "heartbeat_seconds": self.heartbeat_seconds, - } - try: - self._socketio.emit("reverse_tunnel_start", payload, namespace="/") - self.audit_logger.info( - "lease_push_start tunnel_id=%s agent_id=%s port=%s", - lease.tunnel_id, - lease.agent_id, - lease.assigned_port, - ) - except Exception: - self.logger.debug("Failed to emit reverse_tunnel_start for tunnel_id=%s", lease.tunnel_id, exc_info=True) - - def _push_stop_to_agent(self, lease: TunnelLease, *, reason: str = "operator_stop") -> None: - """Notify the agent to tear down a tunnel (best-effort).""" - - if not self._socketio: - return - try: - self._socketio.emit( - "reverse_tunnel_stop", - {"tunnel_id": lease.tunnel_id, "reason": reason}, - namespace="/", - ) - self.audit_logger.info( - "lease_push_stop tunnel_id=%s agent_id=%s reason=%s", - lease.tunnel_id, - lease.agent_id, - reason or "-", - ) - except Exception: - self.logger.debug("Failed to emit reverse_tunnel_stop for tunnel_id=%s", lease.tunnel_id, exc_info=True) - - def lease_summary(self, lease: TunnelLease) -> Dict[str, object]: - return { - "tunnel_id": lease.tunnel_id, - "agent_id": lease.agent_id, - "protocol": lease.protocol, - "domain": lease.domain, - "port": lease.assigned_port, - "token": lease.token, - "expires_at": lease.expires_at, - "idle_seconds": lease.idle_timeout_seconds, - "grace_seconds": lease.grace_timeout_seconds, - "state": lease.state, - } - - def decode_token(self, token: str) -> Dict[str, object]: - """Decode and optionally verify a tunnel token (unsigned tokens allowed).""" - - if not token: - raise ValueError("token_missing") - - def _b64decode(segment: str) -> bytes: - padding = "=" * (-len(segment) % 4) - return base64.urlsafe_b64decode(segment + padding) - - parts = token.split(".") - payload_segment = parts[0] - payload_bytes = _b64decode(payload_segment) - try: - payload = json.loads(payload_bytes.decode("utf-8")) - except Exception as exc: - raise ValueError("token_decode_error") from exc - - # Optional signature verification if present and signer is available. - if len(parts) == 2 and self.signer: - sig_segment = parts[1] - try: - signature = _b64decode(sig_segment) - except Exception as exc: - raise ValueError("token_signature_decode_error") from exc - public_key = getattr(self.signer, "_public", None) - if public_key: - try: - public_key.verify(signature, payload_bytes) - except Exception as exc: - raise ValueError("token_signature_invalid") from exc - - return payload - - def validate_token( - self, - token: str, - *, - agent_id: Optional[str] = None, - tunnel_id: Optional[str] = None, - domain: Optional[str] = None, - protocol: Optional[str] = None, - ) -> Dict[str, object]: - """Validate a tunnel token against expected fields and expiry.""" - - payload = self.decode_token(token) - now = int(_utc_ts()) - - def _matches(expected: Optional[str], actual: Optional[str]) -> bool: - if expected is None: - return True - return str(expected).strip().lower() == str(actual or "").strip().lower() - - if not _matches(agent_id, payload.get("agent_id")): - raise ValueError("token_agent_mismatch") - if not _matches(tunnel_id, payload.get("tunnel_id")): - raise ValueError("token_id_mismatch") - if not _matches(domain, payload.get("domain")): - raise ValueError("token_domain_mismatch") - if not _matches(protocol, payload.get("protocol")): - raise ValueError("token_protocol_mismatch") - - expires_at = payload.get("expires_at") - try: - expires_ts = int(expires_at) if expires_at is not None else None - except Exception: - expires_ts = None - if expires_ts is not None and expires_ts < now: - raise ValueError("token_expired") - - return payload - - def log_device_activity( - self, - lease: TunnelLease, - *, - event: str, - reason: Optional[str] = None, - ) -> None: - """Device Activity logging for tunnel start/stop (DB + socket emit if available).""" - - agent_id = lease.agent_id - operator_id = lease.operator_id - tunnel_id = lease.tunnel_id - - if self._db_conn_factory is None: - self._activity_logger.info( - "device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s", - event, - agent_id, - tunnel_id, - operator_id or "-", - reason or "-", - ) - return - - conn = None - try: - conn = self._db_conn_factory() - cur = conn.cursor() - - hostname = lease.hostname - if not hostname: - try: - cur.execute( - "SELECT hostname FROM devices WHERE agent_id = ? ORDER BY last_seen DESC LIMIT 1", - (agent_id,), - ) - row = cur.fetchone() - if row and row[0]: - hostname = str(row[0]).strip() - lease.hostname = hostname - except Exception: - hostname = None - - if not hostname: - self._activity_logger.info( - "device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s hostname=unknown", - event, - agent_id, - tunnel_id, - operator_id or "-", - reason or "-", - ) - return - - now_ts = int(_utc_ts()) - script_name = f"Reverse Tunnel ({lease.domain}/{lease.protocol})" - - if event == "start": - cur.execute( - """ - INSERT INTO activity_history(hostname, script_path, script_name, script_type, ran_at, status, stdout, stderr) - VALUES(?,?,?,?,?,?,?,?) - """, - ( - hostname, - lease.tunnel_id, - script_name, - "reverse_tunnel", - now_ts, - "Running", - "", - "", - ), - ) - lease.activity_id = cur.lastrowid - conn.commit() - if self._socketio: - try: - self._socketio.emit( - "device_activity_changed", - { - "hostname": hostname, - "activity_id": lease.activity_id, - "change": "created", - "source": "reverse_tunnel", - }, - ) - except Exception: - pass - self._activity_logger.info( - "device_activity_start hostname=%s agent_id=%s tunnel_id=%s operator=%s activity_id=%s", - hostname, - agent_id, - tunnel_id, - operator_id or "-", - lease.activity_id or "-", - ) - return - - if lease.activity_id: - status = "Completed" if event == "stop" else "Closed" - cur.execute( - """ - UPDATE activity_history - SET status=?, - stderr=COALESCE(stderr, '') || ? - WHERE id=? - """, - ( - status, - f"\nreason: {reason}" if reason else "", - lease.activity_id, - ), - ) - conn.commit() - if self._socketio: - try: - self._socketio.emit( - "device_activity_changed", - { - "hostname": hostname, - "activity_id": lease.activity_id, - "change": "updated", - "source": "reverse_tunnel", - }, - ) - except Exception: - pass - self._activity_logger.info( - "device_activity event=%s hostname=%s agent_id=%s tunnel_id=%s operator=%s reason=%s activity_id=%s", - event, - hostname, - agent_id, - tunnel_id, - operator_id or "-", - reason or "-", - lease.activity_id or "-", - ) - except Exception: - self._activity_logger.debug("device_activity logging failed for tunnel_id=%s", lease.tunnel_id, exc_info=True) - finally: - if conn is not None: - try: - conn.close() - except Exception: - pass - - def _dispatch_agent_frame(self, tunnel_id: str, frame: TunnelFrame) -> None: - server = self._protocol_servers.get(tunnel_id) - if not server: - return - try: - server.handle_agent_frame(frame) - except Exception: - self.logger.debug("ps handler error for tunnel_id=%s", tunnel_id, exc_info=True) - - def _start_lease_sweeper(self) -> None: - async def _sweeper(): - while self._running and self._loop and not self._loop.is_closed(): - await asyncio.sleep(15) - expired = self.lease_manager.expire_idle() - for lease in expired: - self.log_device_activity(lease, event="stop", reason="idle_or_grace") - if self._loop: - self._sweeper_task = asyncio.run_coroutine_threadsafe(_sweeper(), self._loop) - - def _build_ssl_context(self) -> Optional[ssl.SSLContext]: - cert = self.context.tls_cert_path or self.context.tls_bundle_path - key = self.context.tls_key_path - if not cert or not key: - self.audit_logger.info("tunnel_listener_ssl_missing cert=%s key=%s", cert, key) - return None - try: - ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ctx.load_cert_chain(certfile=cert, keyfile=key) - return ctx - except Exception: - self.logger.debug("Failed to build SSL context for reverse tunnel listener", exc_info=True) - return None - - def _spawn_port_listener(self, port: int) -> None: - if ws_serve is None: - self.logger.error("websockets dependency missing; cannot start tunnel listener") - return - if port in self._port_servers: - return - ssl_ctx = self._build_ssl_context() - self.audit_logger.info("tunnel_listener_start port=%s ssl=%s", port, bool(ssl_ctx)) - - async def _handler(websocket, path): - await self._handle_agent_socket(websocket, path, port=port) - - async def _start(): - server = await ws_serve(_handler, host="0.0.0.0", port=port, ssl=ssl_ctx, max_size=None, ping_interval=None) - self._port_servers[port] = server - self.audit_logger.info("tunnel_listener_bound port=%s", port) - - asyncio.run_coroutine_threadsafe(_start(), self._loop) - - async def _handle_agent_socket(self, websocket, path: str, *, port: int) -> None: - """Handle agent tunnel socket on assigned port.""" - - tunnel_id = None - tunnel_stop_reason: Optional[str] = None - sock_log = self.audit_logger.getChild("agent_socket") - try: - peer = None - try: - peer = getattr(websocket, "remote_address", None) - except Exception: - peer = None - sock_log.info("agent_socket_open port=%s path=%s peer=%s", port, path, peer) - raw = await asyncio.wait_for(websocket.recv(), timeout=10) - frame = decode_frame(raw) - if frame.msg_type != MSG_CONNECT: - sock_log.info("agent_socket_first_frame_not_connect port=%s msg_type=%s", port, frame.msg_type) - await websocket.close() - return - try: - payload = json.loads(frame.payload.decode("utf-8")) - except Exception: - sock_log.info("agent_socket_connect_payload_decode_failed port=%s", port, exc_info=True) - await websocket.close() - return - tunnel_id = str(payload.get("tunnel_id") or "").strip() - agent_id = str(payload.get("agent_id") or "").strip() - token = payload.get("token") or "" - lease = self.lease_manager.get(tunnel_id) - if lease is None or lease.assigned_port != port: - sock_log.info( - "agent_socket_unknown_lease port=%s tunnel_id=%s assigned=%s expected=%s", - port, - tunnel_id, - lease.assigned_port if lease else None, - port, - ) - await websocket.close() - return - # Token validation - try: - self.validate_token( - token, - agent_id=agent_id, - tunnel_id=tunnel_id, - domain=lease.domain, - protocol=lease.protocol, - ) - sock_log.info( - "agent_socket_token_valid port=%s tunnel_id=%s agent_id=%s domain=%s protocol=%s", - port, - tunnel_id, - agent_id, - lease.domain, - lease.protocol, - ) - except Exception as exc: - sock_log.info( - "agent_socket_token_invalid port=%s tunnel_id=%s agent_id=%s error=%s", - port, - tunnel_id, - agent_id, - exc, - ) - await websocket.close() - return - bridge = self.ensure_bridge(lease) - bridge.attach_agent(token) - self._agent_sockets[tunnel_id] = websocket - await websocket.send(heartbeat_frame(channel_id=0, is_ack=True).encode()) - await websocket.send(TunnelFrame(msg_type=MSG_CONNECT_ACK, channel_id=0, payload=b"").encode()) - sock_log.info( - "agent_socket_connected port=%s tunnel_id=%s agent_id=%s", - port, - tunnel_id, - agent_id, - ) - - async def _pump_to_operator(): - nonlocal tunnel_stop_reason - sock_log_local = sock_log.getChild("recv") - while not websocket.closed: - try: - raw_msg = await websocket.recv() - except Exception: - break - try: - recv_frame = decode_frame(raw_msg) - except Exception: - sock_log_local.info("agent_socket_frame_decode_failed tunnel_id=%s", tunnel_id, exc_info=True) - continue - self.lease_manager.touch(tunnel_id) - sock_log_local.info( - "agent_to_operator tunnel_id=%s msg_type=%s channel=%s payload_len=%s", - tunnel_id, - recv_frame.msg_type, - recv_frame.channel_id, - len(recv_frame.payload or b""), - ) - if recv_frame.msg_type == MSG_CLOSE and recv_frame.channel_id == 0: - try: - close_info = json.loads(recv_frame.payload.decode("utf-8")) - except Exception: - close_info = {} - close_code = close_info.get("code") if isinstance(close_info, dict) else None - close_reason = close_info.get("reason") if isinstance(close_info, dict) else None - tunnel_stop_reason = (close_reason or "").strip() or ( - f"agent_close_code_{close_code}" if close_code is not None else "agent_close" - ) - sock_log_local.info( - "agent_close_frame tunnel_id=%s code=%s reason=%s", - tunnel_id, - close_code, - tunnel_stop_reason or "-", - ) - try: - self.lease_manager.mark_agent_disconnected(tunnel_id) - except Exception: - pass - bridge.agent_to_operator(recv_frame) - break - try: - self._dispatch_agent_frame(tunnel_id, recv_frame) - except Exception: - pass - bridge.agent_to_operator(recv_frame) - async def _pump_to_agent(): - sock_log_local = sock_log.getChild("send") - while not websocket.closed: - frame = bridge.next_for_agent() - if frame is None: - await asyncio.sleep(0.05) - continue - try: - await websocket.send(frame.encode()) - sock_log_local.info( - "operator_to_agent tunnel_id=%s msg_type=%s channel=%s payload_len=%s", - tunnel_id, - frame.msg_type, - frame.channel_id, - len(frame.payload or b""), - ) - except Exception: - break - async def _heartbeat(): - sock_log_local = sock_log.getChild("heartbeat") - while not websocket.closed: - try: - await websocket.send(heartbeat_frame(channel_id=0).encode()) - except Exception: - sock_log_local.info("heartbeat_send_failed tunnel_id=%s", tunnel_id, exc_info=True) - break - await asyncio.sleep(self.heartbeat_seconds) - - consumer = asyncio.create_task(_pump_to_operator()) - producer = asyncio.create_task(_pump_to_agent()) - heart = asyncio.create_task(_heartbeat()) - await asyncio.wait([consumer, producer, heart], return_when=asyncio.FIRST_COMPLETED) - except Exception: - sock_log.info("agent_socket_handler_failed port=%s tunnel_id=%s", port, tunnel_id, exc_info=True) - finally: - ws_close_reason = getattr(websocket, "close_reason", None) - ws_close_code = getattr(websocket, "close_code", None) - close_reason = tunnel_stop_reason or (ws_close_reason if ws_close_reason else None) - try: - sock_log.info( - "agent_socket_closed port=%s tunnel_id=%s code=%s reason=%s", - port, - tunnel_id, - ws_close_code, - close_reason, - ) - except Exception: - pass - if tunnel_id and tunnel_id in self._agent_sockets: - self._agent_sockets.pop(tunnel_id, None) - if tunnel_id: - try: - self.lease_manager.mark_agent_disconnected(tunnel_id) - except Exception: - pass - self.release_bridge(tunnel_id, reason=close_reason or "agent_socket_closed") - - def get_bridge(self, tunnel_id: str) -> Optional["TunnelBridge"]: - return self._bridges.get(tunnel_id) - - def ensure_bridge(self, lease: TunnelLease) -> "TunnelBridge": - bridge = self._bridges.get(lease.tunnel_id) - if bridge is None: - bridge = TunnelBridge(lease=lease, service=self) - self._bridges[lease.tunnel_id] = bridge - return bridge - - def ensure_protocol_server(self, tunnel_id: str) -> Optional[object]: - server = self._protocol_servers.get(tunnel_id) - if server: - return server - lease = self.lease_manager.get(tunnel_id) - if lease is None: - return None - handler_cls = self.protocol_registry.get((lease.protocol or "").lower()) - if handler_cls is None: - return None - bridge = self.ensure_bridge(lease) - try: - server = handler_cls( - bridge=bridge, - service=self, - frame_cls=TunnelFrame, - close_frame_fn=close_frame, - ) - except TypeError: - server = handler_cls(bridge=bridge, service=self) - self._protocol_servers[tunnel_id] = server - return server - - def get_protocol_server(self, tunnel_id: str) -> Optional[object]: - return self._protocol_servers.get(tunnel_id) - - def release_bridge(self, tunnel_id: str, *, reason: str = "bridge_released") -> None: - bridge = self._bridges.pop(tunnel_id, None) - if bridge: - bridge.stop(reason=reason) - if tunnel_id in self._protocol_servers: - try: - self._protocol_servers.pop(tunnel_id, None) - except Exception: - pass - - -class TunnelBridge: - """Lightweight placeholder for mapping agent and operator sockets.""" - - def __init__(self, *, lease: TunnelLease, service: ReverseTunnelService): - self.lease = lease - self.service = service - self.logger = service.logger.getChild(f"bridge.{lease.tunnel_id}") - self.agent_connected = False - self.operator_attached = False - self._agent_queue: Deque[TunnelFrame] = deque() - self._operator_queue: Deque[TunnelFrame] = deque() - self._closed = False - - def attach_agent(self, token: str) -> None: - """Validate the agent token and mark the lease active (no socket binding yet).""" - - self.service.validate_token( - token, - agent_id=self.lease.agent_id, - tunnel_id=self.lease.tunnel_id, - domain=self.lease.domain, - protocol=self.lease.protocol, - ) - self.lease.mark_active() - self.service.lease_manager.mark_agent_connected(self.lease.tunnel_id) - self.agent_connected = True - self.service.log_device_activity(self.lease, event="start") - self.logger.info("agent_connected tunnel_id=%s agent_id=%s", self.lease.tunnel_id, self.lease.agent_id) - - def attach_operator(self, operator_id: Optional[str]) -> None: - self.operator_attached = True - if operator_id: - self.lease.operator_id = operator_id - self.logger.info("operator_attached tunnel_id=%s operator=%s", self.lease.tunnel_id, operator_id or "-") - - def stop(self, *, reason: str = "stopped") -> None: - self.service.lease_manager.release(self.lease.tunnel_id, reason=reason) - self.service.log_device_activity(self.lease, event="stop", reason=reason) - self.logger.info( - "bridge_stopped tunnel_id=%s agent_id=%s reason=%s", - self.lease.tunnel_id, - self.lease.agent_id, - reason, - ) - self._closed = True - - def agent_to_operator(self, frame: TunnelFrame) -> None: - """Queue a frame from agent toward operator.""" - - if self._closed: - return - self._operator_queue.append(frame) - - def operator_to_agent(self, frame: TunnelFrame) -> None: - """Queue a frame from operator toward agent.""" - - if self._closed: - return - try: - self.service.lease_manager.touch(self.lease.tunnel_id) - except Exception: - pass - self._agent_queue.append(frame) - - def next_for_agent(self) -> Optional[TunnelFrame]: - if self._closed or not self._agent_queue: - return None - return self._agent_queue.popleft() - - def next_for_operator(self) -> Optional[TunnelFrame]: - if self._closed or not self._operator_queue: - return None - return self._operator_queue.popleft() - - -__all__ = [ - "ReverseTunnelService", - "TunnelLeaseManager", - "TunnelLease", - "DomainPolicy", - "PortAllocator", - "TunnelBridge", - "TunnelFrame", - "decode_frame", - "heartbeat_frame", - "close_frame", - "FrameDecodeError", - "FrameValidationError", - "MSG_CONNECT", - "MSG_CONNECT_ACK", - "MSG_CHANNEL_OPEN", - "MSG_CHANNEL_ACK", - "MSG_DATA", - "MSG_WINDOW_UPDATE", - "MSG_HEARTBEAT", - "MSG_CLOSE", - "MSG_CONTROL", - "CLOSE_OK", - "CLOSE_IDLE_TIMEOUT", - "CLOSE_GRACE_EXPIRED", - "CLOSE_PROTOCOL_ERROR", - "CLOSE_AUTH_FAILED", - "CLOSE_SERVER_SHUTDOWN", - "CLOSE_AGENT_SHUTDOWN", - "CLOSE_DOMAIN_LIMIT", - "CLOSE_UNEXPECTED_DISCONNECT", -] diff --git a/Data/Engine/services/WebSocket/__init__.py b/Data/Engine/services/WebSocket/__init__.py index 358d1a72..2c4f7445 100644 --- a/Data/Engine/services/WebSocket/__init__.py +++ b/Data/Engine/services/WebSocket/__init__.py @@ -1,6 +1,6 @@ # ====================================================== # Data\Engine\services\WebSocket\__init__.py -# Description: Socket.IO handlers for Engine runtime quick job updates and realtime notifications. +# Description: Socket.IO handlers for Engine runtime quick job updates and VPN shell bridging. # # API Endpoints (if applicable): None # ====================================================== @@ -8,24 +8,20 @@ """WebSocket service registration for the Borealis Engine runtime.""" from __future__ import annotations -import base64 import sqlite3 import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable, Dict, Optional -from flask import session, request +from flask import request from flask_socketio import SocketIO from ...database import initialise_engine_database +from ...security import signing from ...server import EngineContext -from .Agent.reverse_tunnel_orchestrator import ( - ReverseTunnelService, - TunnelBridge, - decode_frame, - TunnelFrame, -) +from ..VPN import VpnTunnelService +from .vpn_shell import VpnShellBridge def _now_ts() -> int: @@ -70,20 +66,31 @@ class EngineRealtimeAdapters: def register_realtime(socket_server: SocketIO, context: EngineContext) -> None: """Register Socket.IO event handlers for the Engine runtime.""" - from ..API import _make_db_conn_factory, _make_service_logger # Local import to avoid circular import at module load - adapters = EngineRealtimeAdapters(context) logger = context.logger.getChild("realtime.quick_jobs") - tunnel_service = getattr(context, "reverse_tunnel_service", None) - if tunnel_service is None: - tunnel_service = ReverseTunnelService( - context, - signer=None, + shell_bridge = VpnShellBridge(socket_server, context) + + def _get_tunnel_service() -> Optional[VpnTunnelService]: + service = getattr(context, "vpn_tunnel_service", None) + if service is not None: + return service + manager = getattr(context, "wireguard_server_manager", None) + if manager is None: + return None + try: + signer = signing.load_signer() + except Exception: + signer = None + service = VpnTunnelService( + context=context, + wireguard_manager=manager, db_conn_factory=adapters.db_conn_factory, socketio=socket_server, + service_log=adapters.service_log, + signer=signer, ) - tunnel_service.start() - setattr(context, "reverse_tunnel_service", tunnel_service) + setattr(context, "vpn_tunnel_service", service) + return service @socket_server.on("quick_job_result") def _handle_quick_job_result(data: Any) -> None: @@ -246,252 +253,45 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None: exc, ) - @socket_server.on("tunnel_bridge_attach") - def _tunnel_bridge_attach(data: Any) -> Any: - """Placeholder operator bridge attach handler (no data channel yet).""" + @socket_server.on("vpn_shell_open") + def _vpn_shell_open(data: Any) -> Dict[str, Any]: + agent_id = "" + if isinstance(data, dict): + agent_id = str(data.get("agent_id") or "").strip() + elif isinstance(data, str): + agent_id = data.strip() + if not agent_id: + return {"error": "agent_id_required"} - if not isinstance(data, dict): - return {"error": "invalid_payload"} + service = _get_tunnel_service() + if service is None: + return {"error": "vpn_service_unavailable"} + if not service.status(agent_id): + return {"error": "tunnel_down"} - tunnel_id = str(data.get("tunnel_id") or "").strip() - operator_id = str(data.get("operator_id") or "").strip() or None - if not tunnel_id: - return {"error": "tunnel_id_required"} + session = shell_bridge.open_session(request.sid, agent_id) + if session is None: + return {"error": "shell_connect_failed"} + service.bump_activity(agent_id) + return {"status": "ok"} - try: - tunnel_service.operator_attach(tunnel_id, operator_id) - except ValueError as exc: - return {"error": str(exc)} - except Exception as exc: # pragma: no cover - defensive guard - logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) - return {"error": "bridge_attach_failed"} - - return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"} - - def _encode_frame(frame: TunnelFrame) -> str: - return base64.b64encode(frame.encode()).decode("ascii") - - def _decode_frame_payload(raw: Any) -> TunnelFrame: - if isinstance(raw, str): - try: - raw_bytes = base64.b64decode(raw) - except Exception: - raise ValueError("invalid_frame") - elif isinstance(raw, (bytes, bytearray)): - raw_bytes = bytes(raw) + @socket_server.on("vpn_shell_send") + def _vpn_shell_send(data: Any) -> Dict[str, Any]: + payload = None + if isinstance(data, dict): + payload = data.get("data") else: - raise ValueError("invalid_frame") - return decode_frame(raw_bytes) - - @socket_server.on("tunnel_operator_send") - def _tunnel_operator_send(data: Any) -> Any: - """Operator -> agent frame enqueue (placeholder queue).""" - - if not isinstance(data, dict): - return {"error": "invalid_payload"} - tunnel_id = str(data.get("tunnel_id") or "").strip() - frame_raw = data.get("frame") - if not tunnel_id or frame_raw is None: - return {"error": "tunnel_id_and_frame_required"} - try: - frame = _decode_frame_payload(frame_raw) - except Exception as exc: - return {"error": str(exc)} - - bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id) - if bridge is None: - return {"error": "unknown_tunnel"} - bridge.operator_to_agent(frame) - return {"status": "ok"} - - @socket_server.on("tunnel_operator_poll") - def _tunnel_operator_poll(data: Any) -> Any: - """Operator polls queued frames from agent.""" - - tunnel_id = "" - if isinstance(data, dict): - tunnel_id = str(data.get("tunnel_id") or "").strip() - if not tunnel_id: - return {"error": "tunnel_id_required"} - bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id) - if bridge is None: - return {"error": "unknown_tunnel"} - - frames = [] - while True: - frame = bridge.next_for_operator() - if frame is None: - break - frames.append(_encode_frame(frame)) - return {"frames": frames} - - # WebUI operator bridge namespace for browser clients - tunnel_namespace = "/tunnel" - _operator_sessions: Dict[str, str] = {} - - def _current_operator() -> Optional[str]: - username = session.get("username") - if username: - return str(username) - auth_header = (request.headers.get("Authorization") or "").strip() - token = None - if auth_header.lower().startswith("bearer "): - token = auth_header.split(" ", 1)[1].strip() - if not token: - token = request.cookies.get("borealis_auth") - return token or None - - @socket_server.on("join", namespace=tunnel_namespace) - def _ws_tunnel_join(data: Any) -> Any: - if not isinstance(data, dict): - return {"error": "invalid_payload"} - operator_id = _current_operator() - if not operator_id: - return {"error": "unauthorized"} - tunnel_id = str(data.get("tunnel_id") or "").strip() - if not tunnel_id: - return {"error": "tunnel_id_required"} - bridge = tunnel_service.get_bridge(tunnel_id) - if bridge is None: - return {"error": "unknown_tunnel"} - try: - tunnel_service.operator_attach(tunnel_id, operator_id) - except Exception as exc: - logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) - return {"error": "attach_failed"} - sid = request.sid - _operator_sessions[sid] = tunnel_id - return {"status": "ok", "tunnel_id": tunnel_id} - - @socket_server.on("send", namespace=tunnel_namespace) - def _ws_tunnel_send(data: Any) -> Any: - sid = request.sid - tunnel_id = _operator_sessions.get(sid) - if not tunnel_id: - return {"error": "not_joined"} - if not isinstance(data, dict): - return {"error": "invalid_payload"} - frame_raw = data.get("frame") - if frame_raw is None: - return {"error": "frame_required"} - try: - frame = _decode_frame_payload(frame_raw) - except Exception: - return {"error": "invalid_frame"} - bridge = tunnel_service.get_bridge(tunnel_id) - if bridge is None: - return {"error": "unknown_tunnel"} - bridge.operator_to_agent(frame) - return {"status": "ok"} - - @socket_server.on("poll", namespace=tunnel_namespace) - def _ws_tunnel_poll() -> Any: - sid = request.sid - tunnel_id = _operator_sessions.get(sid) - if not tunnel_id: - return {"error": "not_joined"} - bridge = tunnel_service.get_bridge(tunnel_id) - if bridge is None: - return {"error": "unknown_tunnel"} - frames = [] - while True: - frame = bridge.next_for_operator() - if frame is None: - break - frames.append(_encode_frame(frame)) - return {"frames": frames} - - def _require_ps_server(): - sid = request.sid - tunnel_id = _operator_sessions.get(sid) - if not tunnel_id: - return None, None, {"error": "not_joined"} - server = tunnel_service.ensure_protocol_server(tunnel_id) - if server is None or not hasattr(server, "open_channel"): - return None, tunnel_id, {"error": "ps_unsupported"} - return server, tunnel_id, None - - @socket_server.on("ps_open", namespace=tunnel_namespace) - def _ws_ps_open(data: Any) -> Any: - server, tunnel_id, error = _require_ps_server() - if server is None: - return error - cols = 120 - rows = 32 - if isinstance(data, dict): - try: - cols = int(data.get("cols", cols)) - rows = int(data.get("rows", rows)) - except Exception: - pass - cols = max(20, min(cols, 300)) - rows = max(10, min(rows, 200)) - try: - server.open_channel(cols=cols, rows=rows) - except Exception as exc: - logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) - return {"error": "ps_open_failed"} - return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows} - - @socket_server.on("ps_send", namespace=tunnel_namespace) - def _ws_ps_send(data: Any) -> Any: - server, tunnel_id, error = _require_ps_server() - if server is None: - return error - if data is None: + payload = data + if payload is None: return {"error": "payload_required"} - text = data - if isinstance(data, dict): - text = data.get("data") - if text is None: - return {"error": "payload_required"} - try: - server.send_input(str(text)) - except Exception as exc: - logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) - return {"error": "ps_send_failed"} + shell_bridge.send(request.sid, str(payload)) return {"status": "ok"} - @socket_server.on("ps_resize", namespace=tunnel_namespace) - def _ws_ps_resize(data: Any) -> Any: - server, tunnel_id, error = _require_ps_server() - if server is None: - return error - cols = None - rows = None - if isinstance(data, dict): - cols = data.get("cols") - rows = data.get("rows") - try: - cols_int = int(cols) if cols is not None else 120 - rows_int = int(rows) if rows is not None else 32 - cols_int = max(20, min(cols_int, 300)) - rows_int = max(10, min(rows_int, 200)) - server.send_resize(cols_int, rows_int) - return {"status": "ok", "cols": cols_int, "rows": rows_int} - except Exception as exc: - logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) - return {"error": "ps_resize_failed"} + @socket_server.on("vpn_shell_close") + def _vpn_shell_close() -> Dict[str, Any]: + shell_bridge.close(request.sid) + return {"status": "ok"} - @socket_server.on("ps_poll", namespace=tunnel_namespace) - def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused - server, tunnel_id, error = _require_ps_server() - if server is None: - return error - try: - output = server.drain_output() - status = server.status() - return {"output": output, "status": status} - except Exception as exc: - logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) - return {"error": "ps_poll_failed"} - - @socket_server.on("disconnect", namespace=tunnel_namespace) - def _ws_tunnel_disconnect(): - sid = request.sid - tunnel_id = _operator_sessions.pop(sid, None) - if tunnel_id and tunnel_id not in _operator_sessions.values(): - try: - tunnel_service.stop_tunnel(tunnel_id, reason="operator_socket_disconnect") - except Exception as exc: - logger.debug("ws_tunnel_disconnect stop_tunnel failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) + @socket_server.on("disconnect") + def _ws_disconnect() -> None: + shell_bridge.close(request.sid) diff --git a/Data/Engine/services/WebSocket/vpn_shell.py b/Data/Engine/services/WebSocket/vpn_shell.py new file mode 100644 index 00000000..c36b081b --- /dev/null +++ b/Data/Engine/services/WebSocket/vpn_shell.py @@ -0,0 +1,127 @@ +# ====================================================== +# Data\Engine\services\WebSocket\vpn_shell.py +# Description: Socket.IO handlers bridging UI shell to agent TCP server over WireGuard. +# +# API Endpoints (if applicable): None +# ====================================================== + +"""WireGuard VPN PowerShell bridge (Engine side).""" + +from __future__ import annotations + +import base64 +import json +import socket +import threading +from dataclasses import dataclass +from typing import Any, Dict, Optional + + +def _b64encode(data: bytes) -> str: + return base64.b64encode(data).decode("ascii").strip() + + +def _b64decode(value: str) -> bytes: + return base64.b64decode(value.encode("ascii")) + + +@dataclass +class ShellSession: + sid: str + agent_id: str + socketio: Any + tcp: socket.socket + _reader: Optional[threading.Thread] = None + + def start_reader(self) -> None: + t = threading.Thread(target=self._read_loop, daemon=True) + t.start() + self._reader = t + + def _read_loop(self) -> None: + buffer = b"" + try: + while True: + data = self.tcp.recv(4096) + if not data: + break + buffer += data + while b"\n" in buffer: + line, buffer = buffer.split(b"\n", 1) + if not line: + continue + try: + msg = json.loads(line.decode("utf-8")) + except Exception: + continue + if msg.get("type") == "stdout": + payload = msg.get("data") or "" + try: + decoded = _b64decode(str(payload)).decode("utf-8", errors="replace") + except Exception: + decoded = "" + self.socketio.emit("vpn_shell_output", {"data": decoded}, to=self.sid) + finally: + self.socketio.emit("vpn_shell_closed", {"agent_id": self.agent_id}, to=self.sid) + try: + self.tcp.close() + except Exception: + pass + + def send(self, payload: str) -> None: + data = json.dumps({"type": "stdin", "data": _b64encode(payload.encode("utf-8"))}) + self.tcp.sendall(data.encode("utf-8") + b"\n") + + def close(self) -> None: + try: + data = json.dumps({"type": "close"}) + self.tcp.sendall(data.encode("utf-8") + b"\n") + except Exception: + pass + try: + self.tcp.close() + except Exception: + pass + + +class VpnShellBridge: + def __init__(self, socketio, context) -> None: + self.socketio = socketio + self.context = context + self._sessions: Dict[str, ShellSession] = {} + self.logger = context.logger.getChild("vpn_shell") + + def open_session(self, sid: str, agent_id: str) -> Optional[ShellSession]: + service = getattr(self.context, "vpn_tunnel_service", None) + if service is None: + return None + status = service.status(agent_id) + if not status: + return None + host = str(status.get("virtual_ip") or "").split("/")[0] + port = int(self.context.wireguard_shell_port) + try: + tcp = socket.create_connection((host, port), timeout=5) + except Exception: + self.logger.debug("Failed to connect vpn shell to %s:%s", host, port, exc_info=True) + return None + session = ShellSession(sid=sid, agent_id=agent_id, socketio=self.socketio, tcp=tcp) + self._sessions[sid] = session + session.start_reader() + return session + + def send(self, sid: str, payload: str) -> None: + session = self._sessions.get(sid) + if not session: + return + session.send(payload) + service = getattr(self.context, "vpn_tunnel_service", None) + if service: + service.bump_activity(session.agent_id) + + def close(self, sid: str) -> None: + session = self._sessions.pop(sid, None) + if not session: + return + session.close() + diff --git a/Data/Engine/web-interface/src/Devices/Device_Details.jsx b/Data/Engine/web-interface/src/Devices/Device_Details.jsx index dc1962f0..9beced4e 100644 --- a/Data/Engine/web-interface/src/Devices/Device_Details.jsx +++ b/Data/Engine/web-interface/src/Devices/Device_Details.jsx @@ -8,13 +8,17 @@ import { Tab, Typography, Button, + Switch, + Chip, + Divider, Menu, MenuItem, TextField, Dialog, DialogTitle, DialogContent, - DialogActions + DialogActions, + LinearProgress } from "@mui/material"; import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined"; import StorageRoundedIcon from "@mui/icons-material/StorageRounded"; @@ -23,6 +27,7 @@ import LanRoundedIcon from "@mui/icons-material/LanRounded"; import AppsRoundedIcon from "@mui/icons-material/AppsRounded"; import ListAltRoundedIcon from "@mui/icons-material/ListAltRounded"; import TerminalRoundedIcon from "@mui/icons-material/TerminalRounded"; +import TuneRoundedIcon from "@mui/icons-material/TuneRounded"; import SpeedRoundedIcon from "@mui/icons-material/SpeedRounded"; import DeveloperBoardRoundedIcon from "@mui/icons-material/DeveloperBoardRounded"; import MoreHorizIcon from "@mui/icons-material/MoreHoriz"; @@ -69,14 +74,51 @@ const SECTION_HEIGHTS = { network: 260, }; +const buildVpnGroups = (shellPort) => { + const normalizedShell = Number(shellPort) || 47001; + return [ + { + key: "shell", + label: "Borealis PowerShell", + description: "Web terminal access over the VPN tunnel.", + ports: [normalizedShell], + }, + { + key: "rdp", + label: "RDP", + description: "Remote Desktop (TCP 3389).", + ports: [3389], + }, + { + key: "winrm", + label: "WinRM", + description: "PowerShell/WinRM management (TCP 5985/5986).", + ports: [5985, 5986], + }, + { + key: "vnc", + label: "VNC", + description: "Remote desktop streaming (TCP 5900).", + ports: [5900], + }, + { + key: "webrtc", + label: "WebRTC", + description: "Real-time comms (UDP 3478).", + ports: [3478], + }, + ]; +}; + const TOP_TABS = [ - { label: "Device Summary", icon: InfoOutlinedIcon }, - { label: "Storage", icon: StorageRoundedIcon }, - { label: "Memory", icon: MemoryRoundedIcon }, - { label: "Network", icon: LanRoundedIcon }, - { label: "Installed Software", icon: AppsRoundedIcon }, - { label: "Activity History", icon: ListAltRoundedIcon }, - { label: "Remote Shell", icon: TerminalRoundedIcon }, + { key: "summary", label: "Device Summary", icon: InfoOutlinedIcon }, + { key: "storage", label: "Storage", icon: StorageRoundedIcon }, + { key: "memory", label: "Memory", icon: MemoryRoundedIcon }, + { key: "network", label: "Network", icon: LanRoundedIcon }, + { key: "software", label: "Installed Software", icon: AppsRoundedIcon }, + { key: "activity", label: "Activity History", icon: ListAltRoundedIcon }, + { key: "advanced", label: "Advanced Config", icon: TuneRoundedIcon }, + { key: "shell", label: "Remote Shell", icon: TerminalRoundedIcon }, ]; const myTheme = themeQuartz.withParams({ @@ -286,6 +328,15 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage const [menuAnchor, setMenuAnchor] = useState(null); const [clearDialogOpen, setClearDialogOpen] = useState(false); const [assemblyNameMap, setAssemblyNameMap] = useState({}); + const [vpnLoading, setVpnLoading] = useState(false); + const [vpnSaving, setVpnSaving] = useState(false); + const [vpnError, setVpnError] = useState(""); + const [vpnSource, setVpnSource] = useState("default"); + const [vpnToggles, setVpnToggles] = useState({}); + const [vpnCustomPorts, setVpnCustomPorts] = useState([]); + const [vpnDefaultPorts, setVpnDefaultPorts] = useState([]); + const [vpnShellPort, setVpnShellPort] = useState(47001); + const [vpnLoadedFor, setVpnLoadedFor] = useState(""); // Snapshotted status for the lifetime of this page const [lockedStatus, setLockedStatus] = useState(() => { // Prefer status provided by the device list row if available @@ -655,6 +706,104 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage }; }, [activityHostname, loadHistory]); + const applyVpnPorts = useCallback((ports, defaults, shellPort, source) => { + const normalized = Array.isArray(ports) ? ports : []; + const normalizedDefaults = Array.isArray(defaults) ? defaults : []; + const numericPorts = normalized + .map((p) => Number(p)) + .filter((p) => Number.isFinite(p) && p > 0); + const numericDefaults = normalizedDefaults + .map((p) => Number(p)) + .filter((p) => Number.isFinite(p) && p > 0); + const effectiveShell = Number(shellPort) || 47001; + const groups = buildVpnGroups(effectiveShell); + const knownPorts = new Set(groups.flatMap((group) => group.ports)); + const allowedSet = new Set(numericPorts); + const nextToggles = {}; + groups.forEach((group) => { + nextToggles[group.key] = group.ports.every((port) => allowedSet.has(port)); + }); + const customPorts = numericPorts.filter((port) => !knownPorts.has(port)); + setVpnShellPort(effectiveShell); + setVpnDefaultPorts(numericDefaults); + setVpnCustomPorts(customPorts); + setVpnToggles(nextToggles); + setVpnSource(source || "default"); + }, []); + + const loadVpnConfig = useCallback(async () => { + if (!vpnAgentId) return; + setVpnLoading(true); + setVpnError(""); + setVpnLoadedFor(vpnAgentId); + try { + const resp = await fetch(`/api/device/vpn_config/${encodeURIComponent(vpnAgentId)}`); + const data = await resp.json().catch(() => ({})); + if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`); + const allowedPorts = Array.isArray(data?.allowed_ports) ? data.allowed_ports : []; + const defaultPorts = Array.isArray(data?.default_ports) ? data.default_ports : []; + const shellPort = data?.shell_port; + applyVpnPorts(allowedPorts.length ? allowedPorts : defaultPorts, defaultPorts, shellPort, data?.source); + setVpnLoadedFor(vpnAgentId); + } catch (err) { + setVpnError(String(err.message || err)); + } finally { + setVpnLoading(false); + } + }, [applyVpnPorts, vpnAgentId]); + + const saveVpnConfig = useCallback(async () => { + if (!vpnAgentId) return; + const ports = []; + vpnPortGroups.forEach((group) => { + if (vpnToggles[group.key]) { + ports.push(...group.ports); + } + }); + vpnCustomPorts.forEach((port) => ports.push(port)); + const uniquePorts = Array.from(new Set(ports)).filter((p) => p > 0); + if (!uniquePorts.length) { + setVpnError("Enable at least one port before saving."); + return; + } + setVpnSaving(true); + setVpnError(""); + try { + const resp = await fetch(`/api/device/vpn_config/${encodeURIComponent(vpnAgentId)}`, { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ allowed_ports: uniquePorts }), + }); + const data = await resp.json().catch(() => ({})); + if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`); + const allowedPorts = Array.isArray(data?.allowed_ports) ? data.allowed_ports : uniquePorts; + const defaultPorts = Array.isArray(data?.default_ports) ? data.default_ports : vpnDefaultPorts; + applyVpnPorts(allowedPorts, defaultPorts, data?.shell_port || vpnShellPort, data?.source || "custom"); + } catch (err) { + setVpnError(String(err.message || err)); + } finally { + setVpnSaving(false); + } + }, [applyVpnPorts, vpnAgentId, vpnCustomPorts, vpnDefaultPorts, vpnPortGroups, vpnShellPort, vpnToggles]); + + const resetVpnConfig = useCallback(() => { + if (!vpnDefaultPorts.length) { + setVpnError("No default ports available to reset."); + return; + } + setVpnError(""); + applyVpnPorts(vpnDefaultPorts, vpnDefaultPorts, vpnShellPort, "default"); + }, [applyVpnPorts, vpnDefaultPorts, vpnShellPort]); + + useEffect(() => { + const advancedIndex = TOP_TABS.findIndex((item) => item.key === "advanced"); + if (advancedIndex < 0) return; + if (tab !== advancedIndex) return; + if (!vpnAgentId) return; + if (vpnLoadedFor === vpnAgentId) return; + loadVpnConfig(); + }, [loadVpnConfig, tab, vpnAgentId, vpnLoadedFor]); + // No explicit live recap tab; recaps are recorded into Activity History const clearHistory = async () => { @@ -739,6 +888,19 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage ); const summary = details.summary || {}; + const vpnAgentId = useMemo(() => { + return ( + meta.agentId || + summary.agent_id || + agent?.agent_id || + agent?.id || + device?.agent_id || + device?.agent_guid || + device?.id || + "" + ); + }, [agent?.agent_id, agent?.id, device?.agent_guid, device?.agent_id, device?.id, meta.agentId, summary.agent_id]); + const vpnPortGroups = useMemo(() => buildVpnGroups(vpnShellPort), [vpnShellPort]); const tunnelDevice = useMemo( () => ({ ...(device || {}), @@ -876,7 +1038,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage const formatScriptType = useCallback((raw) => { const value = String(raw || "").toLowerCase(); if (value === "ansible") return "Ansible Playbook"; - if (value === "reverse_tunnel") return "Reverse Tunnel"; + if (value === "reverse_tunnel" || value === "vpn_tunnel") return "Reverse VPN Tunnel"; return "Script"; }, []); @@ -1368,6 +1530,150 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage ); + const handleVpnToggle = useCallback((key, checked) => { + setVpnToggles((prev) => ({ ...(prev || {}), [key]: checked })); + setVpnSource("custom"); + }, []); + + const renderAdvancedConfigTab = () => { + const sourceLabel = vpnSource === "custom" ? "Custom overrides" : "Defaults"; + const showProgress = vpnLoading || vpnSaving; + return ( + + + {showProgress ? : null} + + + + Reverse VPN Tunnel - Allowed Ports + + + Toggle which services the Engine can reach over the WireGuard tunnel for this device. + + + + + + + {vpnPortGroups.map((group) => ( + + + + {group.label} + + + {group.description} + + + {group.ports.map((port) => ( + + ))} + + + handleVpnToggle(group.key, event.target.checked)} + color="info" + disabled={vpnLoading || vpnSaving} + /> + + ))} + + {vpnCustomPorts.length ? ( + + + Custom ports preserved: {vpnCustomPorts.join(", ")} + + + ) : null} + {vpnError ? ( + + {vpnError} + + ) : null} + + + + + + + ); + }; + const memoryRows = useMemo( () => (details.memory || []).map((m, idx) => ({ @@ -1618,6 +1924,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage renderNetworkTab, renderSoftware, renderHistory, + renderAdvancedConfigTab, renderRemoteShellTab, ]; const tabContent = (topTabRenderers[tab] || renderDeviceSummaryTab)(); @@ -1742,7 +2049,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage > {TOP_TABS.map((tabDef) => ( } iconPosition="start" diff --git a/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx b/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx index 024b3535..57c03d19 100644 --- a/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx +++ b/Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx @@ -5,17 +5,17 @@ import { Button, Stack, TextField, - MenuItem, IconButton, Tooltip, LinearProgress, + Chip, } from "@mui/material"; import { PlayArrowRounded as PlayIcon, StopRounded as StopIcon, ContentCopy as CopyIcon, RefreshRounded as RefreshIcon, - LanRounded as PortIcon, + LanRounded as IpIcon, LinkRounded as LinkIcon, } from "@mui/icons-material"; import { io } from "socket.io-client"; @@ -24,18 +24,7 @@ import "prismjs/components/prism-powershell"; import "prismjs/themes/prism-okaidia.css"; import Editor from "react-simple-code-editor"; -// Console diagnostics for troubleshooting the connect/disconnect flow. -const debugLog = (...args) => { - try { - // eslint-disable-next-line no-console - console.error("[ReverseTunnel][PS]", ...args); - } catch { - // ignore - } -}; - const MAGIC_UI = { - panelBg: "rgba(7,11,24,0.92)", panelBorder: "rgba(148, 163, 184, 0.35)", textMuted: "#94a3b8", textBright: "#e2e8f0", @@ -56,13 +45,25 @@ const gradientButtonSx = { }, }; -const FRAME_HEADER_BYTES = 12; // version, msg_type, flags, reserved, channel_id(u32), length(u32) -const MSG_CLOSE = 0x08; -const CLOSE_AGENT_SHUTDOWN = 6; - const fontFamilyMono = 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace'; +const emitAsync = (socket, event, payload, timeoutMs = 4000) => + new Promise((resolve) => { + let settled = false; + const timer = setTimeout(() => { + if (settled) return; + settled = true; + resolve({ error: "timeout" }); + }, timeoutMs); + socket.emit(event, payload, (resp) => { + if (settled) return; + settled = true; + clearTimeout(timer); + resolve(resp || {}); + }); + }); + function normalizeText(value) { if (value == null) return ""; try { @@ -72,28 +73,6 @@ function normalizeText(value) { } } -function base64FromBytes(bytes) { - let binary = ""; - bytes.forEach((b) => { - binary += String.fromCharCode(b); - }); - return btoa(binary); -} - -function buildCloseFrame(channelId = 1, code = CLOSE_AGENT_SHUTDOWN, reason = "operator_close") { - const payload = new TextEncoder().encode(JSON.stringify({ code, reason })); - const buffer = new ArrayBuffer(FRAME_HEADER_BYTES + payload.length); - const view = new DataView(buffer); - view.setUint8(0, 1); // version - view.setUint8(1, MSG_CLOSE); - view.setUint8(2, 0); // flags - view.setUint8(3, 0); // reserved - view.setUint32(4, channelId >>> 0, true); - view.setUint32(8, payload.length >>> 0, true); - new Uint8Array(buffer, FRAME_HEADER_BYTES).set(payload); - return base64FromBytes(new Uint8Array(buffer)); -} - function highlightPs(code) { try { return Prism.highlight(code || "", Prism.languages.powershell, "powershell"); @@ -102,52 +81,18 @@ function highlightPs(code) { } } -const INITIAL_MILESTONES = { - tunnelReady: false, - operatorAttached: false, - shellEstablished: false, -}; -const INITIAL_STATUS_CHAIN = ["Offline"]; - export default function ReverseTunnelPowershell({ device }) { - const [connectionType, setConnectionType] = useState("ps"); - const [tunnel, setTunnel] = useState(null); const [sessionState, setSessionState] = useState("idle"); - const [, setStatusMessage] = useState(""); - const [, setStatusSeverity] = useState("info"); + const [shellState, setShellState] = useState("idle"); + const [tunnel, setTunnel] = useState(null); const [output, setOutput] = useState(""); const [input, setInput] = useState(""); + const [statusMessage, setStatusMessage] = useState(""); const [copyFlash, setCopyFlash] = useState(false); - const [, setPolling] = useState(false); - const [psStatus, setPsStatus] = useState({}); - const [milestones, setMilestones] = useState(() => ({ ...INITIAL_MILESTONES })); - const [tunnelSteps, setTunnelSteps] = useState(() => [...INITIAL_STATUS_CHAIN]); - const [websocketSteps, setWebsocketSteps] = useState(() => [...INITIAL_STATUS_CHAIN]); - const [shellSteps, setShellSteps] = useState(() => [...INITIAL_STATUS_CHAIN]); + const [loading, setLoading] = useState(false); const socketRef = useRef(null); - const pollTimerRef = useRef(null); - const resizeTimerRef = useRef(null); + const localSocketRef = useRef(false); const terminalRef = useRef(null); - const joinRetryRef = useRef(null); - const joinAttemptsRef = useRef(0); - const tunnelRef = useRef(null); - const shellFlagsRef = useRef({ openSent: false, ack: false }); - const DOMAIN_REMOTE_SHELL = "remote-interactive-shell"; - - useEffect(() => { - debugLog("component mount", { hostname: device?.hostname, agentId }); - return () => debugLog("component unmount"); - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []); - - const hostname = useMemo(() => { - return ( - normalizeText(device?.hostname) || - normalizeText(device?.summary?.hostname) || - normalizeText(device?.agent_hostname) || - "" - ); - }, [device]); const agentId = useMemo(() => { return ( @@ -162,78 +107,18 @@ export default function ReverseTunnelPowershell({ device }) { ); }, [device]); - const appendStatus = useCallback((setter, label) => { - if (!label) return; - setter((prev) => { - const next = [...prev, label]; - const cap = 6; - return next.length > cap ? next.slice(next.length - cap) : next; - }); - }, []); - - const resetState = useCallback(() => { - debugLog("resetState invoked"); - setTunnel(null); - setSessionState("idle"); - setStatusMessage(""); - setStatusSeverity("info"); - setOutput(""); - setInput(""); - setPsStatus({}); - setMilestones({ ...INITIAL_MILESTONES }); - setTunnelSteps([...INITIAL_STATUS_CHAIN]); - setWebsocketSteps([...INITIAL_STATUS_CHAIN]); - setShellSteps([...INITIAL_STATUS_CHAIN]); - shellFlagsRef.current = { openSent: false, ack: false }; - }, []); - - useEffect(() => { - tunnelRef.current = tunnel?.tunnel_id || null; - }, [tunnel?.tunnel_id]); - - const disconnectSocket = useCallback(() => { - const socket = socketRef.current; - if (socket) { - socket.off(); - socket.disconnect(); + const ensureSocket = useCallback(() => { + if (socketRef.current) return socketRef.current; + const existing = typeof window !== "undefined" ? window.BorealisSocket : null; + if (existing) { + socketRef.current = existing; + localSocketRef.current = false; + return existing; } - socketRef.current = null; - }, []); - - const stopPolling = useCallback(() => { - if (pollTimerRef.current) { - clearTimeout(pollTimerRef.current); - pollTimerRef.current = null; - } - setPolling(false); - }, []); - - const stopTunnel = useCallback(async (reason = "operator_disconnect", tunnelIdOverride = null) => { - const tunnelId = tunnelIdOverride || tunnelRef.current; - if (!tunnelId) return; - try { - await fetch(`/api/tunnel/${tunnelId}`, { - method: "DELETE", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ reason }), - }); - } catch (err) { - // best-effort; socket close frame acts as fallback - } - }, []); - - useEffect(() => { - return () => { - debugLog("cleanup on unmount", { tunnelId: tunnelRef.current }); - stopPolling(); - disconnectSocket(); - if (joinRetryRef.current) { - clearTimeout(joinRetryRef.current); - joinRetryRef.current = null; - } - stopTunnel("component_unmount", tunnelRef.current); - }; - // eslint-disable-next-line react-hooks/exhaustive-deps + const socket = io(window.location.origin, { transports: ["websocket"] }); + socketRef.current = socket; + localSocketRef.current = true; + return socket; }, []); const appendOutput = useCallback((text) => { @@ -257,6 +142,137 @@ export default function ReverseTunnelPowershell({ device }) { scrollToBottom(); }, [output, scrollToBottom]); + const stopTunnel = useCallback( + async (reason = "operator_disconnect") => { + if (!agentId) return; + try { + await fetch("/api/tunnel/disconnect", { + method: "DELETE", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ agent_id: agentId, tunnel_id: tunnel?.tunnel_id, reason }), + }); + } catch { + // best-effort + } + }, + [agentId, tunnel?.tunnel_id] + ); + + const closeShell = useCallback(async () => { + const socket = ensureSocket(); + await emitAsync(socket, "vpn_shell_close", {}); + }, [ensureSocket]); + + const handleDisconnect = useCallback(async () => { + setLoading(true); + setStatusMessage(""); + try { + await closeShell(); + await stopTunnel("operator_disconnect"); + } finally { + setTunnel(null); + setShellState("closed"); + setSessionState("idle"); + setLoading(false); + } + }, [closeShell, stopTunnel]); + + useEffect(() => { + const socket = ensureSocket(); + const handleDisconnectEvent = () => { + if (sessionState === "connected") { + setShellState("closed"); + setSessionState("idle"); + setStatusMessage("Socket disconnected."); + } + }; + const handleOutput = (payload) => { + appendOutput(payload?.data || ""); + }; + const handleClosed = () => { + setShellState("closed"); + setSessionState("idle"); + setStatusMessage("Shell closed."); + }; + + socket.on("disconnect", handleDisconnectEvent); + socket.on("vpn_shell_output", handleOutput); + socket.on("vpn_shell_closed", handleClosed); + + return () => { + socket.off("disconnect", handleDisconnectEvent); + socket.off("vpn_shell_output", handleOutput); + socket.off("vpn_shell_closed", handleClosed); + if (localSocketRef.current) { + socket.disconnect(); + } + }; + }, [appendOutput, ensureSocket, sessionState]); + + useEffect(() => { + return () => { + closeShell(); + stopTunnel("component_unmount"); + }; + }, [closeShell, stopTunnel]); + + const requestTunnel = useCallback(async () => { + if (!agentId) { + setStatusMessage("Agent ID is required to connect."); + return; + } + setLoading(true); + setStatusMessage(""); + setSessionState("connecting"); + setShellState("opening"); + try { + const resp = await fetch("/api/tunnel/connect", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ agent_id: agentId }), + }); + const data = await resp.json().catch(() => ({})); + if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`); + const statusResp = await fetch( + `/api/tunnel/connect/status?agent_id=${encodeURIComponent(agentId)}&bump=1` + ); + const statusData = await statusResp.json().catch(() => ({})); + if (!statusResp.ok || statusData?.status !== "up") { + throw new Error(statusData?.error || "Tunnel not ready"); + } + setTunnel({ ...data, ...statusData }); + + const socket = ensureSocket(); + const openResp = await emitAsync(socket, "vpn_shell_open", { agent_id: agentId }, 6000); + if (openResp?.error) { + throw new Error(openResp.error); + } + setSessionState("connected"); + setShellState("connected"); + } catch (err) { + setSessionState("error"); + setShellState("closed"); + setStatusMessage(String(err.message || err)); + } finally { + setLoading(false); + } + }, [agentId, ensureSocket]); + + const handleSend = useCallback( + async (text) => { + const socket = ensureSocket(); + if (!socket || sessionState !== "connected") return; + const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`; + appendOutput(`\nPS> ${text}\n`); + setInput(""); + const resp = await emitAsync(socket, "vpn_shell_send", { data: payload }); + if (resp?.error) { + setStatusMessage("Send failed."); + } + }, + [appendOutput, ensureSocket, sessionState] + ); + const handleCopy = async () => { try { await navigator.clipboard.writeText(output || ""); @@ -267,329 +283,7 @@ export default function ReverseTunnelPowershell({ device }) { } }; - const measureTerminal = useCallback(() => { - const el = terminalRef.current; - if (!el) return { cols: 120, rows: 32 }; - const width = el.clientWidth || 960; - const height = el.clientHeight || 460; - const charWidth = 8.2; - const charHeight = 18; - const cols = Math.max(20, Math.min(Math.floor(width / charWidth), 300)); - const rows = Math.max(10, Math.min(Math.floor(height / charHeight), 200)); - return { cols, rows }; - }, []); - - const emitAsync = useCallback((socket, event, payload, timeoutMs = 4000) => { - return new Promise((resolve) => { - let settled = false; - const timer = setTimeout(() => { - if (settled) return; - settled = true; - resolve({ error: "timeout" }); - }, timeoutMs); - socket.emit(event, payload, (resp) => { - if (settled) return; - settled = true; - clearTimeout(timer); - resolve(resp || {}); - }); - }); - }, []); - - const pollLoop = useCallback( - (socket, tunnelId) => { - if (!socket || !tunnelId) return; - debugLog("pollLoop tick", { tunnelId }); - setPolling(true); - pollTimerRef.current = setTimeout(async () => { - const resp = await emitAsync(socket, "ps_poll", {}); - if (resp?.error) { - debugLog("pollLoop error", resp); - stopPolling(); - disconnectSocket(); - setPsStatus({}); - setTunnel(null); - setSessionState("error"); - return; - } - if (Array.isArray(resp?.output) && resp.output.length) { - appendOutput(resp.output.join("")); - } - if (resp?.status) { - setPsStatus(resp.status); - debugLog("pollLoop status", resp.status); - if (resp.status.closed) { - setSessionState("closed"); - setTunnel(null); - setMilestones({ ...INITIAL_MILESTONES }); - appendStatus(setShellSteps, "Shell closed"); - appendStatus(setTunnelSteps, "Stopped"); - appendStatus(setWebsocketSteps, "Relay stopped"); - shellFlagsRef.current = { openSent: false, ack: false }; - stopPolling(); - return; - } - if (resp.status.open_sent && !shellFlagsRef.current.openSent) { - appendStatus(setShellSteps, "Opening remote shell"); - shellFlagsRef.current.openSent = true; - } - if (resp.status.ack && !shellFlagsRef.current.ack) { - setSessionState("connected"); - setMilestones((prev) => ({ ...prev, shellEstablished: true })); - appendStatus(setShellSteps, "Remote shell established"); - shellFlagsRef.current.ack = true; - } - } - pollLoop(socket, tunnelId); - }, 520); - }, - [appendOutput, emitAsync, stopPolling, disconnectSocket, appendStatus] - ); - - const handleDisconnect = useCallback( - async (reason = "operator_disconnect") => { - debugLog("handleDisconnect begin", { reason, tunnelId: tunnel?.tunnel_id, psStatus, sessionState }); - setPsStatus({}); - const socket = socketRef.current; - const tunnelId = tunnel?.tunnel_id; - if (joinRetryRef.current) { - clearTimeout(joinRetryRef.current); - joinRetryRef.current = null; - } - joinAttemptsRef.current = 0; - if (socket && tunnelId) { - const frame = buildCloseFrame(1, CLOSE_AGENT_SHUTDOWN, "operator_close"); - debugLog("emit CLOSE", { tunnelId }); - socket.emit("send", { frame }); - } - await stopTunnel(reason); - debugLog("stopTunnel issued", { tunnelId }); - stopPolling(); - disconnectSocket(); - setTunnel(null); - setSessionState("closed"); - setMilestones({ ...INITIAL_MILESTONES }); - appendStatus(setTunnelSteps, "Stopped"); - appendStatus(setWebsocketSteps, "Relay closed"); - appendStatus(setShellSteps, "Shell closed"); - shellFlagsRef.current = { openSent: false, ack: false }; - debugLog("handleDisconnect finished", { tunnelId }); - }, - [appendStatus, disconnectSocket, stopPolling, stopTunnel, tunnel?.tunnel_id] - ); - - const handleResize = useCallback(() => { - if (!socketRef.current || sessionState === "idle") return; - const dims = measureTerminal(); - socketRef.current.emit("ps_resize", dims); - }, [measureTerminal, sessionState]); - - useEffect(() => { - const observer = - typeof ResizeObserver !== "undefined" - ? new ResizeObserver(() => { - if (resizeTimerRef.current) clearTimeout(resizeTimerRef.current); - resizeTimerRef.current = setTimeout(() => handleResize(), 200); - }) - : null; - const el = terminalRef.current; - if (observer && el) observer.observe(el); - const onWinResize = () => handleResize(); - window.addEventListener("resize", onWinResize); - return () => { - window.removeEventListener("resize", onWinResize); - if (observer && el) observer.unobserve(el); - }; - }, [handleResize]); - - const connectSocket = useCallback( - (lease, { isRetry = false } = {}) => { - if (!lease?.tunnel_id) return; - if (joinRetryRef.current) { - clearTimeout(joinRetryRef.current); - joinRetryRef.current = null; - } - if (!isRetry) { - joinAttemptsRef.current = 0; - } - disconnectSocket(); - stopPolling(); - setSessionState("waiting"); - const socket = io(`${window.location.origin}/tunnel`, { transports: ["websocket", "polling"] }); - socketRef.current = socket; - - socket.on("connect_error", () => { - debugLog("socket connect_error"); - setStatusSeverity("warning"); - setStatusMessage("Tunnel namespace unavailable."); - setTunnel(null); - setSessionState("error"); - appendStatus(setWebsocketSteps, "Relay connect error"); - }); - - socket.on("disconnect", () => { - debugLog("socket disconnect", { tunnelId: tunnel?.tunnel_id }); - stopPolling(); - if (sessionState !== "closed") { - setSessionState("disconnected"); - setStatusSeverity("warning"); - setStatusMessage("Socket disconnected."); - setTunnel(null); - appendStatus(setWebsocketSteps, "Relay disconnected"); - } - }); - - socket.on("connect", async () => { - debugLog("socket connect", { tunnelId: lease.tunnel_id }); - setMilestones((prev) => ({ ...prev, operatorAttached: true })); - setStatusSeverity("info"); - setStatusMessage("Joining tunnel..."); - appendStatus(setWebsocketSteps, "Relay connected"); - const joinResp = await emitAsync(socket, "join", { tunnel_id: lease.tunnel_id }, 5000); - if (joinResp?.error) { - const attempt = (joinAttemptsRef.current += 1); - const isTimeout = joinResp.error === "timeout"; - if (joinResp.error === "unknown_tunnel") { - setSessionState("waiting_agent"); - setStatusSeverity("info"); - setStatusMessage("Waiting for agent to establish tunnel..."); - appendStatus(setWebsocketSteps, "Waiting for agent"); - } else if (isTimeout || joinResp.error === "attach_failed") { - setSessionState("waiting_agent"); - setStatusSeverity("warning"); - setStatusMessage("Tunnel join timed out. Retrying..."); - appendStatus(setWebsocketSteps, `Join retry ${attempt}`); - } else { - debugLog("join error", joinResp); - setSessionState("error"); - setStatusSeverity("error"); - setStatusMessage(joinResp.error); - appendStatus(setWebsocketSteps, `Join failed: ${joinResp.error}`); - return; - } - if (attempt <= 5) { - joinRetryRef.current = setTimeout(() => connectSocket(lease, { isRetry: true }), 800); - } else { - setSessionState("error"); - setTunnel(null); - setStatusSeverity("warning"); - setStatusMessage("Operator could not attach to tunnel. Try Connect again."); - appendStatus(setWebsocketSteps, "Join failed after retries"); - } - return; - } - appendStatus(setWebsocketSteps, "Relay joined"); - const dims = measureTerminal(); - debugLog("ps_open emit", { tunnelId: lease.tunnel_id, dims }); - const openResp = await emitAsync(socket, "ps_open", dims, 5000); - if (openResp?.error && openResp.error === "ps_unsupported") { - // Suppress warming message; channel will settle once agent attaches. - } - if (!shellFlagsRef.current.openSent) { - appendStatus(setShellSteps, "Opening remote shell"); - shellFlagsRef.current.openSent = true; - } - appendOutput(""); - setSessionState("waiting_agent"); - pollLoop(socket, lease.tunnel_id); - handleResize(); - }); - }, - [appendOutput, appendStatus, disconnectSocket, emitAsync, handleResize, measureTerminal, pollLoop, sessionState, stopPolling] - ); - - const requestTunnel = useCallback(async () => { - if (tunnel && sessionState !== "closed" && sessionState !== "idle") { - setStatusSeverity("info"); - setStatusMessage(""); - connectSocket(tunnel); - return; - } - debugLog("requestTunnel", { agentId, connectionType }); - if (!agentId) { - setStatusSeverity("warning"); - setStatusMessage("Agent ID is required to request a tunnel."); - return; - } - if (connectionType !== "ps") { - setStatusSeverity("warning"); - setStatusMessage("Only PowerShell is supported right now."); - return; - } - resetState(); - setSessionState("requesting"); - setStatusSeverity("info"); - setStatusMessage(""); - appendStatus(setTunnelSteps, "Requesting lease"); - try { - const resp = await fetch("/api/tunnel/request", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ agent_id: agentId, protocol: "ps", domain: DOMAIN_REMOTE_SHELL }), - }); - const data = await resp.json().catch(() => ({})); - if (!resp.ok) { - const err = data?.error || `HTTP ${resp.status}`; - setSessionState("error"); - setStatusSeverity(err === "domain_limit" ? "warning" : "error"); - setStatusMessage(""); - return; - } - setMilestones((prev) => ({ ...prev, tunnelReady: true })); - setTunnel(data); - setStatusMessage(""); - setSessionState("lease_issued"); - appendStatus( - setTunnelSteps, - data?.tunnel_id - ? `Lease issued (${data.tunnel_id.slice(0, 8)} @ Port ${data.port || "-"})` - : "Lease issued" - ); - connectSocket(data); - } catch (e) { - setSessionState("error"); - setStatusSeverity("error"); - setStatusMessage(""); - appendStatus(setTunnelSteps, "Lease request failed"); - } - }, [DOMAIN_REMOTE_SHELL, agentId, appendStatus, connectSocket, connectionType, resetState]); - - const handleSend = useCallback( - async (text) => { - const socket = socketRef.current; - if (!socket) return; - const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`; - appendOutput(`\nPS> ${text}\n`); - setInput(""); - const resp = await emitAsync(socket, "ps_send", { data: payload }); - if (resp?.error) { - setStatusSeverity("warning"); - setStatusMessage(""); - } - }, - [appendOutput, emitAsync] - ); - - const isConnected = sessionState === "connected" || (psStatus?.ack && !psStatus?.closed); - const isClosed = sessionState === "closed" || psStatus?.closed; - const isBusy = - sessionState === "requesting" || - sessionState === "waiting" || - sessionState === "waiting_agent" || - sessionState === "lease_issued"; - const canStart = Boolean(agentId) && !isBusy; - - useEffect(() => { - const handleUnload = () => { - stopTunnel("window_unload"); - }; - if (tunnel?.tunnel_id) { - window.addEventListener("beforeunload", handleUnload); - return () => window.removeEventListener("beforeunload", handleUnload); - } - return undefined; - }, [stopTunnel, tunnel?.tunnel_id]); - + const isConnected = sessionState === "connected"; const sessionChips = [ tunnel?.tunnel_id ? { @@ -598,58 +292,43 @@ export default function ReverseTunnelPowershell({ device }) { icon: , } : null, - tunnel?.port + tunnel?.virtual_ip ? { - label: `Port ${tunnel.port}`, + label: `IP ${String(tunnel.virtual_ip).split("/")[0]}`, color: MAGIC_UI.accentA, - icon: , + icon: , } : null, ].filter(Boolean); return ( - - + - - + {isConnected ? "Disconnect" : "Connect"} + + + {sessionChips.map((chip) => ( + + ))} - + - {isBusy ? : null} + {loading ? : null} setInput(e.target.value)} onKeyDown={(e) => { if (e.key === "Enter" && !e.shiftKey) { @@ -753,43 +428,19 @@ export default function ReverseTunnelPowershell({ device }) { /> - - - Tunnel:{" "} - - {tunnelSteps.join(" > ")} - + + + + Tunnel: {sessionState === "connected" ? "Active" : sessionState} - - Websocket:{" "} - - {websocketSteps.join(" > ")} - + + Shell: {shellState === "connected" ? "Ready" : shellState} - - Remote Shell:{" "} - - {shellSteps.join(" > ")} + {statusMessage ? ( + + {statusMessage} - + ) : null} ); diff --git a/Docs/Codex/BOREALIS_AGENT.md b/Docs/Codex/BOREALIS_AGENT.md index 65c13e47..476736df 100644 --- a/Docs/Codex/BOREALIS_AGENT.md +++ b/Docs/Codex/BOREALIS_AGENT.md @@ -20,8 +20,9 @@ Use this doc for agent-only work (Borealis agent runtime under `Data/Agent` → - Validates script payloads with backend-issued Ed25519 signatures before execution. - Outbound-only; API/WebSocket calls flow through `AgentHttpClient.ensure_authenticated` for proactive refresh. Logs bootstrap, enrollment, token refresh, and signature events in `Agent/Logs/`. -## Reverse Tunnels -- Design, orchestration, domains, limits, and lifecycle are documented in `Docs/Codex/REVERSE_TUNNELS.md`. Agent role implementation lives in `Data/Agent/Roles/role_ReverseTunnel.py` with per-domain protocol handlers under `Data/Agent/Roles/Reverse_Tunnels/`. +## Reverse VPN Tunnels +- WireGuard reverse VPN design and lifecycle live in `Docs/Codex/REVERSE_TUNNELS.md` and `Docs/Codex/Reverse_VPN_Tunnel_Deployment.md`. +- Agent roles: `Data/Agent/Roles/role_WireGuardTunnel.py` (tunnel lifecycle) and `Data/Agent/Roles/role_VpnShell.py` (VPN PowerShell TCP server). ## Execution Contexts & Roles - Auto-discovers roles from `Data/Agent/Roles/`; no loader changes needed. diff --git a/Docs/Codex/BOREALIS_ENGINE.md b/Docs/Codex/BOREALIS_ENGINE.md index 694f9417..6ed89cb6 100644 --- a/Docs/Codex/BOREALIS_ENGINE.md +++ b/Docs/Codex/BOREALIS_ENGINE.md @@ -23,9 +23,10 @@ Use this doc for Engine work (successor to the legacy server). For shared guidan - Enrollment: operator approvals, conflict detection, auditor recording, pruning of expired codes/refresh tokens. - Background jobs and service adapters maintain compatibility with legacy DB schemas while enabling gradual API takeover. -## Reverse Tunnels -- Full design and lifecycle are in `Docs/Codex/REVERSE_TUNNELS.md` (domains, limits, framing, APIs, stop path, UI hooks). -- Engine orchestrator is `Data/Engine/services/WebSocket/Agent/reverse_tunnel_orchestrator.py` with domain handlers under `Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/`. +## Reverse VPN Tunnels +- WireGuard reverse VPN design and lifecycle live in `Docs/Codex/REVERSE_TUNNELS.md` and `Docs/Codex/Reverse_VPN_Tunnel_Deployment.md`. +- Engine orchestrator: `Data/Engine/services/VPN/vpn_tunnel_service.py` with WireGuard manager `Data/Engine/services/VPN/wireguard_server.py`. +- UI shell bridge: `Data/Engine/services/WebSocket/vpn_shell.py`. ## WebUI & WebSocket Migration - Static/template handling: `Data/Engine/services/WebUI`; deployment copy paths are wired through `Borealis.ps1` with TLS-aware URL generation. diff --git a/Docs/Codex/REVERSE_TUNNELS.md b/Docs/Codex/REVERSE_TUNNELS.md index 86e89cea..3c72d647 100644 --- a/Docs/Codex/REVERSE_TUNNELS.md +++ b/Docs/Codex/REVERSE_TUNNELS.md @@ -1,92 +1,61 @@ -# Borealis Reverse Tunnels – Operator & Developer Guide +# Borealis Reverse VPN Tunnels (WireGuard) – Operator & Developer Guide -This document is the single reference for how Borealis reverse tunnels are organized, secured, and orchestrated. It is written for Codex agents extending the feature (new protocols, UI, or policy changes). +This document is the reference for Borealis reverse VPN tunnels built on WireGuard. The legacy WebSocket framing and domain-lane tunnel stack has been retired; the system now uses a single outbound WireGuard tunnel per agent with host-only routing and per-device ACLs. ## 1) High-Level Model -- Outbound-only: Agents initiate all tunnel sockets to the Engine. No inbound openings on devices. -- Transport: WebSocket-over-TLS carrying a binary frame header (version | msg_type | flags | reserved | channel_id | length) plus payload. -- Leases: Engine issues short-lived leases per agent/domain/protocol. Each lease binds a tunnel_id to an ephemeral Engine port and a signed token. -- Domains: Concurrency “lanes” keep protocols isolated: `remote-interactive-shell` (2), `remote-management` (1), `remote-video` (2). Legacy aliases (`ps`, etc.) normalize into these lanes. -- Channels: Logical streams inside a tunnel (channel_id u32). PS uses channel 1; future protocols can open more channels per tunnel as needed. -- Tear-down: Idle/grace timeouts plus explicit operator stop. Closing a tunnel must close its protocol channel(s) and kill the agent process for interactive shells. +- Outbound-only: agents establish WireGuard tunnels to the Engine; no inbound access on devices. +- Transport: WireGuard/UDP on port 30000. +- Sessions: one live VPN tunnel per agent; multiple operators share it. +- Routing: host-only /32 per agent; AllowedIPs restricted to the agent /32 and engine /32; no client-to-client. +- Idle timeout: 15 minutes of no operator activity; no grace period. +- Keys: WireGuard server keys under `Engine/Certificates/VPN_Server`; client keys under `Agent/Borealis/Certificates/VPN_Client`. ## 2) Engine Components -- Orchestrator: `Data/Engine/services/WebSocket/Agent/reverse_tunnel_orchestrator.py` - - Lease manager: Port pool allocator, domain limit enforcement, idle/grace sweeper. - - Token issuer/validator: Binds agent_id, tunnel_id, domain, protocol, port, expires_at. - - Bridge: Maps agent sockets ↔ operator sockets; stores per-tunnel protocol server instances. - - Logging: `Engine/Logs/reverse_tunnel.log` plus Device Activity start/stop entries. - - Stop path: `stop_tunnel` closes protocol servers, emits `reverse_tunnel_stop` to agents, releases lease/bridge. -- Protocol registry: Domain/protocol handlers under `Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/`: - - `remote_interactive_shell/Protocols/Powershell.py` (live), `Bash.py` (placeholder). - - `remote_management/Protocols/SSH.py`, `WinRM.py` (placeholders). - - `remote_video/Protocols/VNC.py`, `RDP.py`, `WebRTC.py` (placeholders). -- API Endpoints: - - `POST /api/tunnel/request` → allocates lease, returns {tunnel_id, port, token, idle_seconds, grace_seconds, domain, protocol}. - - `DELETE /api/tunnel/` → operator-driven stop; pushes stop to agent and releases the lease. - - Domain default for PowerShell requests is `remote-interactive-shell` (legacy `ps` still accepted). -- Operator Socket.IO namespace `/tunnel`: - - `join`, `send`, `poll`, `ps_open`, `ps_send`, `ps_resize`, `ps_poll`. - - Operator socket disconnect triggers `stop_tunnel` if no other operators remain attached. -- WebUI (current): `Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx` requests PS leases in `remote-interactive-shell`, sends CLOSE frames, and calls DELETE on disconnect/unload. +- Orchestrator: `Data/Engine/services/VPN/vpn_tunnel_service.py` + - Allocates per-agent /32, issues short-lived orchestration tokens, enforces single-session. + - Starts/stops WireGuard listener, applies firewall rules, idles out on inactivity. + - Emits Socket.IO events: `vpn_tunnel_start`, `vpn_tunnel_stop`, `vpn_tunnel_activity`. +- WireGuard manager: `Data/Engine/services/VPN/wireguard_server.py` + - Generates server keys, renders config, manages `wireguard.exe` tunnel service, applies ACL rules. +- PowerShell bridge: `Data/Engine/services/WebSocket/vpn_shell.py` + - Proxies UI shell input/output to the agent’s TCP shell server over WireGuard. +- Logging: `Engine/Logs/reverse_tunnel.log` plus Device Activity entries. -## 3) Agent Components -- Role: `Data/Agent/Roles/role_ReverseTunnel.py` - - Validates signed lease tokens; enforces domain limits (2/1/2 with legacy fallbacks). - - Outbound TLS WS connect to assigned port; heartbeats + idle/grace watchdog; stop_all closes channels and sends CLOSE. - - Protocol registry: loads handlers from `Data/Agent/Roles/Reverse_Tunnels/*/Protocols/*` (PowerShell live; others stubbed to close unsupported channels cleanly). -- PowerShell channel: `Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py` (pipes-only, no PTY); re-exported under `Reverse_Tunnels/remote_interactive_shell/Protocols/Powershell.py`. -- Logging: `Agent/Logs/reverse_tunnel.log` with channel/tunnel lifecycle. +## 3) API Endpoints +- `POST /api/tunnel/connect` → issues session material (tunnel_id, token, virtual_ip, endpoint, allowed_ports, idle_seconds). +- `GET /api/tunnel/status` → returns up/down status for an agent. +- `GET /api/tunnel/connect/status` → alias for status (used by UI before shell open). +- `DELETE /api/tunnel/disconnect` → immediate teardown (agent + engine cleanup). +- `GET /api/device/vpn_config/` → read per-agent allowed ports. +- `PUT /api/device/vpn_config/` → update allowed ports. -## 4) Framing, Heartbeats, Close -- Header: version(1) | msg_type(1) | flags(1) | reserved(1) | channel_id(u32 LE) | length(u32 LE). -- Messages: CONNECT/ACK, CHANNEL_OPEN/ACK, DATA, CONTROL (resize), WINDOW_UPDATE (reserved), HEARTBEAT (ping/pong), CLOSE. -- Close codes: ok, idle_timeout, grace_expired, protocol_error, auth_failed, server_shutdown, agent_shutdown, domain_limit, unexpected_disconnect. -- Heartbeats: Engine → Agent loop; idle/grace sweeper ~15s on Engine; Agent watchdog closes on idle/grace. +## 4) Agent Components +- Tunnel lifecycle: `Data/Agent/Roles/role_WireGuardTunnel.py` + - Validates orchestration tokens, starts/stops WireGuard client service, enforces idle. +- Shell server: `Data/Agent/Roles/role_VpnShell.py` + - TCP PowerShell server bound to `0.0.0.0:47001`, restricted to VPN subnet (10.255.x.x). +- Logging: `Agent/Logs/reverse_tunnel.log`. -## 5) Lifecycle (PowerShell example) -1. UI calls `POST /api/tunnel/request` with agent_id, protocol=ps, domain=remote-interactive-shell. -2. Engine allocates port/tunnel_id, signs token, starts listener, pushes `reverse_tunnel_start` to agent. -3. Agent dials WS to assigned port, sends CONNECT with token. Engine validates, binds bridge, sends CONNECT_ACK + heartbeat. -4. Operator Socket.IO `/tunnel` joins; Engine attaches operator, instantiates PS server, issues CHANNEL_OPEN. -5. Agent launches PowerShell (pipes), streams stdout/stderr as DATA; operator input via `ps_send`; optional resize via `ps_resize` (no-op on agent pipes). -6. On operator Disconnect/tab close, UI sends CLOSE frame and calls DELETE; Engine stop path notifies agent (`reverse_tunnel_stop`), closes channel, releases lease/domain slot. -7. Idle/grace expiry or agent disconnect also triggers close/release; domain slots free immediately. +## 5) Security & Auth +- TLS pinned for Engine API/Socket.IO. +- Orchestration tokens signed via Engine Ed25519 key; agent verifies signatures and stores the signing key. +- WireGuard AllowedIPs /32; no LAN routes; client-to-client blocked. +- Engine firewall rules enforce per-device allowed ports. -## 6) Security & Auth -- TLS: Reuse existing pinned bundle; outbound-only agent sockets. -- Token: short-lived, binds agent_id/tunnel_id/domain/protocol/port/expires_at; optional signature verification (Ed25519 signer when configured). -- Operator auth: uses existing Engine session/cookie/bearer for `/tunnel` namespace and API endpoints. +## 6) UI +- Device details now include an “Advanced Config” tab for per-device allowed ports. +- PowerShell MVP reuses `Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx` with WireGuard APIs + VPN shell events. -## 7) Configuration Knobs (defaults) -- Port pool: 30000–40000; fixed port optional (context settings). -- Idle timeout: 3600s; Grace timeout: 3600s. -- Heartbeat interval: 20s (Engine → Agent). -- Domain limits: remote-interactive-shell=2, remote-management=1, remote-video=2; legacy aliases preserved. -- Log path: `Engine/Logs/reverse_tunnel.log`; `Agent/Logs/reverse_tunnel.log`. +## 7) Extending to New Protocols +- Add protocol ports to the device allowlist and UI toggles. +- Reuse the existing VPN tunnel; no new transport/domain lanes required. -## 8) Logs & Telemetry -- Engine: lease events, socket events, close reasons in `reverse_tunnel.log`; Device Activity start/stop with tunnel_id/operator_id when available. -- Agent: role lifecycle, channel start/stop, errors in `reverse_tunnel.log`. +## 8) Legacy Removal +- WebSocket tunnel domains, protocol handlers, and domain limits are removed. +- No `/tunnel` Socket.IO namespace or framed protocol messages remain. -## 9) Extending to New Protocols -- Add Engine handler under the appropriate domain folder and register in the orchestrator’s protocol registry. -- Add Agent handler under matching domain folder; update role registry to load it. -- Define channel open semantics (metadata), DATA/CONTROL usage, and close behavior. -- Update API/UI to allow selecting the protocol/domain and to send protocol-specific controls. - -## 10) Outstanding Work -- Implement real handlers for Bash/SSH/WinRM/RDP/VNC/WebRTC and surface in UI. -- Add tests for DELETE stop path, per-domain limits, and browser disconnect cleanup. -- Consider a binary WebSocket browser bridge to replace Socket.IO for high-throughput protocols. - -## 11) Risks & Watchpoints -- Eventlet/asyncio coexistence: tunnel loop runs on its own thread/loop; avoid blocking Socket.IO handlers. -- Port exhaustion: handle allocation failures cleanly; always release on stop/idle/grace. -- Buffer growth: add back-pressure before enabling high-throughput protocols. -- Security: strict token binding (agent_id/tunnel_id/domain/protocol/port/expiry) and TLS; reject framing errors. - -## 12) Change Log (not exhaustive) -- 2025-11-30: Initial scaffold (lease manager, framing, tokens, API, Agent role, PS handlers). -- 2025-12-06: Simplified PS to pipes-only; improved handler imports; UI status tweaks. -- 2025-12-18: Domain lanes introduced (`remote-interactive-shell`, `remote-management`, `remote-video`) with limits 2/1/2; protocol handlers reorganized under `Reverse_Tunnels/*/Protocols/*`; orchestrator renamed to `reverse_tunnel_orchestrator.py`; explicit stop API/Socket.IO cleanup; WebUI Disconnect/unload calls DELETE + CLOSE for immediate teardown. +## 9) Change Log (not exhaustive) +- 2025-11-30: Legacy WebSocket tunnel scaffold introduced (lease manager, framing, tokens). +- 2025-12-06: Legacy PowerShell handler simplified to pipes-only; UI status tweaks. +- 2025-12-18: Legacy domain lanes added (`remote-interactive-shell`, `remote-management`, `remote-video`) with limits. +- 2025-12-20: WireGuard reverse VPN migration complete; legacy WebSocket tunnels retired; VPN shell bridge + new APIs. diff --git a/Docs/Codex/Reverse_VPN_Tunnel_Deployment.md b/Docs/Codex/Reverse_VPN_Tunnel_Deployment.md index 5ed51ad6..41acec6c 100644 --- a/Docs/Codex/Reverse_VPN_Tunnel_Deployment.md +++ b/Docs/Codex/Reverse_VPN_Tunnel_Deployment.md @@ -42,8 +42,8 @@ At each milestone: pause, run the listed checks, talk to the operator, and commi - [x] Do not start any tunnel yet. - Linux: do nothing yet (see later section). - Checkpoint tests: - - [x] WireGuard binaries available in agent runtime. - - [x] WireGuard driver installed and visible. + - [ ] WireGuard binaries available in agent runtime. + - [ ] WireGuard driver installed and visible. ### 2) Engine VPN Server & ACLs — Milestone: Engine VPN Server & ACLs (Windows) - Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise). @@ -54,15 +54,15 @@ At each milestone: pause, run the listed checks, talk to the operator, and commi - [x] Do not push DNS or LAN routes; host-only reachability engine IP ↔ agent virtual /32. - ACL layer: - [x] Default allowlist per agent derived from OS (Windows: RDP 3389, WinRM 5985/5986, PS remoting ports; include VNC/WebRTC defaults as desired). - - [x] Allow operator overrides per agent; enforce at engine firewall layer. (rule plans produced; application wiring pending) + - [x] Allow operator overrides per agent; enforce at engine firewall layer. - Keys/Certs: - [x] Prefer reusing existing Engine cert infrastructure for signing orchestration tokens. Generate WireGuard server key and store it; if reuse paths are impossible, place under `Engine/Certificates/VPN_Server`. - [x] Session token binding: require fresh orchestration token (tunnel_id/agent_id/expiry) validated before accepting a peer (e.g., via pre-shared keys or control-plane validation before adding peer). - Logging: server logs to `Engine/Logs/reverse_tunnel.log` (or renamed consistently). [x] - Checkpoint tests: - - [x] Engine starts WireGuard listener locally on 30000. - - [x] Only engine IP reachable; client-to-client blocked. - - [x] Peers without valid token/key are rejected. + - [ ] Engine starts WireGuard listener locally on 30000. + - [ ] Only engine IP reachable; client-to-client blocked. + - [ ] Peers without valid token/key are rejected. ### 3) Agent VPN Client & Lifecycle — Milestone: Agent VPN Client & Lifecycle (Windows) - Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise). @@ -83,60 +83,64 @@ At each milestone: pause, run the listed checks, talk to the operator, and commi - [ ] Idle timeout fires at ~15 minutes of inactivity. ### 4) API & Service Orchestration — Milestone: API & Service Orchestration (Windows) -- Replace legacy tunnel APIs with: - - `POST /api/tunnel/connect` → tunnel_id, token, WG client config (keys, endpoint, allowed IPs), virtual IP, idle_seconds (900). - - `GET /api/tunnel/status` → up/down, virtual IP, connected operators. - - `DELETE /api/tunnel/disconnect` → immediate teardown and lease release. -- Engine orchestrator: - - Manages single tunnel per agent; tracks tunnel_id, virtual IP, token expiry. - - Emits start/stop signals to agent (rename events as needed). - - Cleans peer/routing state on stop. -- Token issuance: short-lived, binds agent_id/tunnel_id/port/expiry; validated before adding peer. -- Remove domain limits; remove channel/protocol handler registry for tunnels. +- Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise). +- [x] Replace legacy tunnel APIs with: + - [x] `POST /api/tunnel/connect` → tunnel_id, token, WG client config (keys, endpoint, allowed IPs), virtual IP, idle_seconds (900). + - [x] `GET /api/tunnel/status` → up/down, virtual IP, connected operators. + - [x] `DELETE /api/tunnel/disconnect` → immediate teardown and lease release. +- [x] Engine orchestrator: + - [x] Manages single tunnel per agent; tracks tunnel_id, virtual IP, token expiry. + - [x] Emits start/stop signals to agent (rename events as needed). + - [x] Cleans peer/routing state on stop. +- [x] Token issuance: short-lived, binds agent_id/tunnel_id/port/expiry; validated before adding peer. +- [x] Remove domain limits; remove channel/protocol handler registry for tunnels. - Checkpoint tests: - - API happy path: connect → status → disconnect. - - Reject stale/second connect for same agent while active. + - [ ] API happy path: connect → status → disconnect. + - [ ] Reject stale/second connect for same agent while active. ### 5) UI Advanced Config & Operator Flow (PowerShell MVP) — Milestone: UI Advanced Config & Operator Flow (Windows, PowerShell MVP) -- In `Data/Engine/web-interface/src/Devices/Device_Details.jsx`, add “Advanced Config” tab: - - “Reverse VPN Tunnel - Allowed Ports” with toggles per protocol. - - Defaults by OS (Windows: RDP/WinRM/PS; All: VNC/WebRTC; allow operator overrides). -- PowerShell MVP: - - Reuse `Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx` as the base UI. - - Rewire to new APIs and virtual IP flow. - - Keep live web terminal behavior (WebSocket or equivalent) so operator input streams to remote PowerShell and outputs stream back in real time over the VPN tunnel. - - Ensure tunnel is up via `/api/tunnel/connect/status` before opening the terminal; call `/api/tunnel/disconnect` on exit/tab close. +- Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise). +- [x] In `Data/Engine/web-interface/src/Devices/Device_Details.jsx`, add “Advanced Config” tab: + - [x] “Reverse VPN Tunnel - Allowed Ports” with toggles per protocol. + - [x] Defaults by OS (Windows: RDP/WinRM/PS; All: VNC/WebRTC; allow operator overrides). +- [x] PowerShell MVP: + - [x] Reuse `Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx` as the base UI. + - [x] Rewire to new APIs and virtual IP flow. + - [x] Keep live web terminal behavior (WebSocket or equivalent) so operator input streams to remote PowerShell and outputs stream back in real time over the VPN tunnel. + - [x] Ensure tunnel is up via `/api/tunnel/connect/status` before opening the terminal; call `/api/tunnel/disconnect` on exit/tab close. - Later protocols (RDP/SSH/etc.) can follow once MVP is proven, but do not block on them for this milestone. - Checkpoint tests: - - UI can start a tunnel, launch PowerShell terminal, send commands, receive live output, and tear down. - - Toggles change ACL behavior (engine→agent reachability) as expected. + - [ ] UI can start a tunnel, launch PowerShell terminal, send commands, receive live output, and tear down. + - [ ] Toggles change ACL behavior (engine→agent reachability) as expected. ### 6) Legacy Tunnel Removal & Cleanup — Milestone: Legacy Tunnel Removal & Cleanup (Windows) -- Remove/retire: - - Engine `reverse_tunnel_orchestrator` and domain handlers under `Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/`. - - Agent `role_ReverseTunnel.py` and protocol handlers. - - WebUI components tied to the old Socket.IO tunnel namespace. -- Update docs and references to point to the new WireGuard VPN flow; keep change log entries. -- Ensure no lingering domain limits/config knobs remain. +- Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise). +- [x] Remove/retire: + - [x] Engine `reverse_tunnel_orchestrator` and domain handlers under `Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/`. + - [x] Agent `role_ReverseTunnel.py` and protocol handlers. + - [x] WebUI components tied to the old Socket.IO tunnel namespace. +- [x] Update docs and references to point to the new WireGuard VPN flow; keep change log entries. +- [x] Ensure no lingering domain limits/config knobs remain. - Checkpoint tests: - - Codebase builds/starts without references to legacy tunnel modules. - - UI no longer calls old APIs or Socket.IO tunnel namespace. + - [ ] Codebase builds/starts without references to legacy tunnel modules. + - [ ] UI no longer calls old APIs or Socket.IO tunnel namespace. ### 7) End-to-End Validation — Milestone: End-to-End Validation (Windows) +- Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise). - Functional: - - Windows agent: WireGuard connect on port 30000; PowerShell MVP fully live in the web terminal; RDP/WinRM reachable over tunnel as configured. - - Idle timeout at 15 minutes; operator disconnect stops tunnel immediately. + - [ ] Windows agent: WireGuard connect on port 30000; PowerShell MVP fully live in the web terminal; RDP/WinRM reachable over tunnel as configured. + - [ ] Idle timeout at 15 minutes; operator disconnect stops tunnel immediately. - Security: - - Client-to-client blocked. - - Only engine IP reachable; per-agent ACL enforces allowed ports. - - Token enforcement blocks stale/unauthorized sessions. + - [ ] Client-to-client blocked. + - [ ] Only engine IP reachable; per-agent ACL enforces allowed ports. + - [ ] Token enforcement blocks stale/unauthorized sessions. - Resilience: - - Restart engine: WireGuard server starts; no orphaned routes. - - Restart agent: adapter persists; tunnel stays down until requested. + - [ ] Restart engine: WireGuard server starts; no orphaned routes. + - [ ] Restart agent: adapter persists; tunnel stays down until requested. - Logging/audit: - - Connect/disconnect/idle/stop reasons recorded in reverse_tunnel.log (Engine/Agent) and Device Activity. + - [ ] Connect/disconnect/idle/stop reasons recorded in reverse_tunnel.log (Engine/Agent) and Device Activity. - Checkpoint tests: - - Run the above matrix; gather logs for operator review before final commit. + - [ ] Run the above matrix; gather logs for operator review before final commit. ## Linux (Deferred) — Do Not Implement Yet - When greenlit, mirror the structure above for Linux: