Overhaul of VPN Codebase

This commit is contained in:
2025-12-18 01:35:03 -07:00
parent 2f81061a1b
commit 6ceb59f717
56 changed files with 1786 additions and 4778 deletions

View File

@@ -1,2 +0,0 @@
"""Reverse tunnel protocol modules (placeholder package)."""

View File

@@ -1,190 +0,0 @@
"""PowerShell channel implementation for reverse tunnel (Agent side)."""
from __future__ import annotations
import asyncio
import sys
import subprocess
from typing import Any, Dict, Optional
# Message types mirrored from the tunnel framing (kept local to avoid import cycles).
MSG_DATA = 0x05
MSG_WINDOW_UPDATE = 0x06
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
# Close codes (mirrored from engine framing)
CLOSE_OK = 0
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class PowershellChannel:
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
self.role = role
self.tunnel = tunnel
self.channel_id = channel_id
self.metadata = metadata or {}
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
self._closed = False
self._reader_task = None
self._writer_task = None
self._stdin_queue: asyncio.Queue = asyncio.Queue()
self._proc: Optional[asyncio.subprocess.Process] = None
self._exit_code: Optional[int] = None
self._frame_cls = getattr(role, "_frame_cls", None)
# ------------------------------------------------------------------ Helpers
def _make_frame(self, msg_type: int, payload: bytes = b"", *, flags: int = 0):
frame_cls = self._frame_cls
if frame_cls is None:
return None
try:
return frame_cls(msg_type=msg_type, channel_id=self.channel_id, payload=payload or b"", flags=flags)
except Exception:
return None
async def _send_frame(self, frame) -> None:
if frame is None:
return
await self.role._send_frame(self.tunnel, frame)
async def _send_close(self, code: int, reason: str) -> None:
try:
close_frame = getattr(self.role, "close_frame")
if callable(close_frame):
await self._send_frame(close_frame(self.channel_id, code, reason))
return
except Exception:
pass
frame = self._make_frame(
MSG_CLOSE,
payload=f'{{"code":{code},"reason":"{reason}"}}'.encode("utf-8"),
)
await self._send_frame(frame)
def _powershell_argv(self) -> list:
preferred = self.metadata.get("shell") if isinstance(self.metadata, dict) else None
shell = preferred.strip() if isinstance(preferred, str) and preferred.strip() else "powershell.exe"
# Keep the process alive and read commands from stdin; -Command - tells PS to consume stdin.
return [shell, "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"]
# ------------------------------------------------------------------ Lifecycle
async def start(self) -> None:
if sys.platform.lower().startswith("win") is False:
self.role._log("reverse_tunnel ps start aborted: non-windows platform", error=True)
await self._send_close(CLOSE_PROTOCOL_ERROR, "windows_only")
return
argv = self._powershell_argv()
self.role._log(f"reverse_tunnel ps start channel={self.channel_id} argv={' '.join(argv)} mode=pipes")
# Pipes (no PTY).
try:
self._proc = await asyncio.create_subprocess_exec(
*argv,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
)
except Exception as exc:
self.role._log(f"reverse_tunnel ps channel spawn failed argv={' '.join(argv)}: {exc}", error=True)
await self._send_close(CLOSE_PROTOCOL_ERROR, "spawn_failed")
return
self._reader_task = self.loop.create_task(self._pump_proc_stdout())
self._writer_task = self.loop.create_task(self._pump_proc_stdin())
self.role._log(f"reverse_tunnel ps channel started (pipes) argv={' '.join(argv)}")
async def on_frame(self, frame) -> None:
if self._closed:
return
if frame.msg_type == MSG_DATA:
if frame.payload:
try:
self._stdin_queue.put_nowait(frame.payload)
except Exception:
await self._stdin_queue.put(frame.payload)
elif frame.msg_type == MSG_CONTROL:
await self._handle_control(frame.payload)
elif frame.msg_type == MSG_CLOSE:
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
elif frame.msg_type == MSG_WINDOW_UPDATE:
# Reserved for back-pressure; ignore for now.
return
async def _handle_control(self, payload: bytes) -> None:
# No-op for pipe mode; resize is not supported here.
return
# -------------------- Pipe fallback pumps --------------------
async def _pump_proc_stdout(self) -> None:
try:
while self._proc and not self._closed:
chunk = await self._proc.stdout.read(4096)
if not chunk:
break
frame = self._make_frame(MSG_DATA, payload=bytes(chunk))
await self._send_frame(frame)
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps pipe stdout pump error", error=True)
finally:
if self._proc and not self._closed:
try:
self._exit_code = await self._proc.wait()
except Exception:
pass
await self.stop(reason="stdout_closed")
async def _pump_proc_stdin(self) -> None:
try:
while self._proc and not self._closed:
data = await self._stdin_queue.get()
if self._closed or not self._proc or not self._proc.stdin:
break
try:
self._proc.stdin.write(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
await self._proc.stdin.drain()
except Exception:
self.role._log("reverse_tunnel ps pipe stdin pump error", error=True)
break
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps pipe stdin pump error", error=True)
finally:
await self.stop(reason="stdin_closed")
async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None:
if self._closed:
return
self._closed = True
if self._proc is not None:
try:
self._proc.terminate()
except Exception:
pass
current = asyncio.current_task()
if self._reader_task and self._reader_task is not current:
try:
self._reader_task.cancel()
except Exception:
pass
if self._writer_task and self._writer_task is not current:
try:
self._writer_task.cancel()
except Exception:
pass
# Include exit code in the close reason for debugging.
exit_suffix = f" (exit={self._exit_code})" if self._exit_code is not None else ""
close_reason = (reason or "powershell_exit") + exit_suffix
# Always send CLOSE before socket teardown so engine/UI see the reason.
try:
await self._send_close(code, close_reason)
except Exception:
self.role._log("reverse_tunnel ps close send failed", error=True)
self.role._log(
f"reverse_tunnel ps channel stopped channel={self.channel_id} reason={close_reason}"
)

View File

@@ -1,3 +0,0 @@
"""Namespace package for reverse tunnel domains (Agent side)."""
__all__ = ["remote_interactive_shell", "remote_management", "remote_video"]

View File

@@ -1,49 +0,0 @@
"""Placeholder Bash channel (Agent side)."""
from __future__ import annotations
import asyncio
from typing import Any, Dict, Optional
MSG_DATA = 0x05
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class BashChannel:
"""Stub Bash handler that immediately reports unsupported."""
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
self.role = role
self.tunnel = tunnel
self.channel_id = channel_id
self.metadata = metadata or {}
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
self._closed = False
async def start(self) -> None:
# Until Bash support is implemented, close the channel to free resources.
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="bash_unsupported")
async def on_frame(self, frame) -> None:
if self._closed:
return
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
# Ignore payloads but acknowledge by stopping the channel to avoid leaks.
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="bash_unsupported")
elif frame.msg_type == MSG_CLOSE:
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
if self._closed:
return
self._closed = True
try:
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "bash_closed"))
except Exception:
pass
self.role._log(f"reverse_tunnel bash channel stopped channel={self.channel_id} reason={reason or 'closed'}")
__all__ = ["BashChannel"]

View File

@@ -1,35 +0,0 @@
"""Expose the PowerShell channel under the domain path, with file-based import fallback."""
from __future__ import annotations
import importlib.util
from pathlib import Path
powershell_module = None
# Attempt package-relative import first
try: # pragma: no cover - best effort
from ....ReverseTunnel import tunnel_Powershell as powershell_module # type: ignore
except Exception:
powershell_module = None
# Fallback: load directly from file path to survive non-package runtimes
if powershell_module is None:
try:
base = Path(__file__).resolve().parents[3] / "ReverseTunnel" / "tunnel_Powershell.py"
spec = importlib.util.spec_from_file_location("tunnel_Powershell", base)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
powershell_module = module
except Exception:
powershell_module = None
if powershell_module and hasattr(powershell_module, "PowershellChannel"):
PowershellChannel = powershell_module.PowershellChannel # type: ignore
else: # pragma: no cover - safety guard
class PowershellChannel: # type: ignore
def __init__(self, *args, **kwargs):
raise ImportError("PowerShell channel unavailable")
__all__ = ["PowershellChannel"]

View File

@@ -1,6 +0,0 @@
"""Protocol handlers for interactive shell tunnels (Agent side)."""
from .Powershell import PowershellChannel
from .Bash import BashChannel
__all__ = ["PowershellChannel", "BashChannel"]

View File

@@ -1,3 +0,0 @@
"""Interactive shell domain (PowerShell/Bash) handlers."""
__all__ = ["tunnel", "Protocols"]

View File

@@ -1,5 +0,0 @@
"""Placeholder module for remote interactive shell tunnel domain (Agent side)."""
DOMAIN_NAME = "remote-interactive-shell"
__all__ = ["DOMAIN_NAME"]

View File

@@ -1,47 +0,0 @@
"""Placeholder SSH channel (Agent side)."""
from __future__ import annotations
import asyncio
from typing import Any, Dict, Optional
MSG_DATA = 0x05
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class SSHChannel:
"""Stub SSH handler that marks the channel unsupported for now."""
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
self.role = role
self.tunnel = tunnel
self.channel_id = channel_id
self.metadata = metadata or {}
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
self._closed = False
async def start(self) -> None:
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="ssh_unsupported")
async def on_frame(self, frame) -> None:
if self._closed:
return
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="ssh_unsupported")
elif frame.msg_type == MSG_CLOSE:
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
if self._closed:
return
self._closed = True
try:
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "ssh_closed"))
except Exception:
pass
self.role._log(f"reverse_tunnel ssh channel stopped channel={self.channel_id} reason={reason or 'closed'}")
__all__ = ["SSHChannel"]

View File

@@ -1,47 +0,0 @@
"""Placeholder WinRM channel (Agent side)."""
from __future__ import annotations
import asyncio
from typing import Any, Dict, Optional
MSG_DATA = 0x05
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class WinRMChannel:
"""Stub WinRM handler that marks the channel unsupported for now."""
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
self.role = role
self.tunnel = tunnel
self.channel_id = channel_id
self.metadata = metadata or {}
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
self._closed = False
async def start(self) -> None:
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="winrm_unsupported")
async def on_frame(self, frame) -> None:
if self._closed:
return
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="winrm_unsupported")
elif frame.msg_type == MSG_CLOSE:
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
if self._closed:
return
self._closed = True
try:
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "winrm_closed"))
except Exception:
pass
self.role._log(f"reverse_tunnel winrm channel stopped channel={self.channel_id} reason={reason or 'closed'}")
__all__ = ["WinRMChannel"]

View File

@@ -1,6 +0,0 @@
"""Protocol handlers for remote management tunnels (Agent side)."""
from .SSH import SSHChannel
from .WinRM import WinRMChannel
__all__ = ["SSHChannel", "WinRMChannel"]

View File

@@ -1,3 +0,0 @@
"""Remote management domain (SSH/WinRM) handlers."""
__all__ = ["tunnel", "Protocols"]

View File

@@ -1,5 +0,0 @@
"""Placeholder module for remote management domain (Agent side)."""
DOMAIN_NAME = "remote-management"
__all__ = ["DOMAIN_NAME"]

View File

@@ -1,47 +0,0 @@
"""Placeholder RDP channel (Agent side)."""
from __future__ import annotations
import asyncio
from typing import Any, Dict, Optional
MSG_DATA = 0x05
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class RDPChannel:
"""Stub RDP handler that marks the channel unsupported for now."""
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
self.role = role
self.tunnel = tunnel
self.channel_id = channel_id
self.metadata = metadata or {}
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
self._closed = False
async def start(self) -> None:
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="rdp_unsupported")
async def on_frame(self, frame) -> None:
if self._closed:
return
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="rdp_unsupported")
elif frame.msg_type == MSG_CLOSE:
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
if self._closed:
return
self._closed = True
try:
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "rdp_closed"))
except Exception:
pass
self.role._log(f"reverse_tunnel rdp channel stopped channel={self.channel_id} reason={reason or 'closed'}")
__all__ = ["RDPChannel"]

View File

@@ -1,47 +0,0 @@
"""Placeholder VNC channel (Agent side)."""
from __future__ import annotations
import asyncio
from typing import Any, Dict, Optional
MSG_DATA = 0x05
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class VNCChannel:
"""Stub VNC handler that marks the channel unsupported for now."""
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
self.role = role
self.tunnel = tunnel
self.channel_id = channel_id
self.metadata = metadata or {}
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
self._closed = False
async def start(self) -> None:
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="vnc_unsupported")
async def on_frame(self, frame) -> None:
if self._closed:
return
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="vnc_unsupported")
elif frame.msg_type == MSG_CLOSE:
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
if self._closed:
return
self._closed = True
try:
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "vnc_closed"))
except Exception:
pass
self.role._log(f"reverse_tunnel vnc channel stopped channel={self.channel_id} reason={reason or 'closed'}")
__all__ = ["VNCChannel"]

View File

@@ -1,47 +0,0 @@
"""Placeholder WebRTC channel (Agent side)."""
from __future__ import annotations
import asyncio
from typing import Any, Dict, Optional
MSG_DATA = 0x05
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class WebRTCChannel:
"""Stub WebRTC handler that marks the channel unsupported for now."""
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
self.role = role
self.tunnel = tunnel
self.channel_id = channel_id
self.metadata = metadata or {}
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
self._closed = False
async def start(self) -> None:
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="webrtc_unsupported")
async def on_frame(self, frame) -> None:
if self._closed:
return
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="webrtc_unsupported")
elif frame.msg_type == MSG_CLOSE:
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
if self._closed:
return
self._closed = True
try:
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "webrtc_closed"))
except Exception:
pass
self.role._log(f"reverse_tunnel webrtc channel stopped channel={self.channel_id} reason={reason or 'closed'}")
__all__ = ["WebRTCChannel"]

View File

@@ -1,7 +0,0 @@
"""Protocol handlers for remote video tunnels (Agent side)."""
from .WebRTC import WebRTCChannel
from .RDP import RDPChannel
from .VNC import VNCChannel
__all__ = ["WebRTCChannel", "RDPChannel", "VNCChannel"]

View File

@@ -1,3 +0,0 @@
"""Remote video/desktop domain (RDP/VNC/WebRTC) handlers."""
__all__ = ["tunnel", "Protocols"]

View File

@@ -1,5 +0,0 @@
"""Placeholder module for remote video domain (Agent side)."""
DOMAIN_NAME = "remote-video"
__all__ = ["DOMAIN_NAME"]

View File

@@ -1,939 +0,0 @@
import asyncio
import base64
import importlib.util
import json
import os
import struct
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import urlparse
import aiohttp
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
# Capture import errors for protocol handlers so we can report why they are missing.
PS_IMPORT_ERROR: Optional[str] = None
BASH_IMPORT_ERROR: Optional[str] = None
tunnel_SSH = None
tunnel_WinRM = None
tunnel_VNC = None
tunnel_RDP = None
tunnel_WebRTC = None
tunnel_Powershell = None
tunnel_Bash = None
def _load_protocol_module(module_name: str, rel_parts: list[str]) -> tuple[Optional[object], Optional[str]]:
"""Load a protocol handler directly from a file path to survive non-package runtimes."""
base = Path(__file__).parent
path = base
for part in rel_parts:
path = path / part
if not path.exists():
return None, f"path_missing:{path}"
try:
spec = importlib.util.spec_from_file_location(module_name, path)
if not spec or not spec.loader:
return None, "spec_failed"
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module, None
except Exception as exc: # pragma: no cover - defensive
return None, repr(exc)
try:
from .Reverse_Tunnels.remote_interactive_shell.Protocols import Powershell as tunnel_Powershell # type: ignore
except Exception as exc: # pragma: no cover - best-effort logging only
PS_IMPORT_ERROR = repr(exc)
_module, _err = _load_protocol_module(
"tunnel_Powershell",
["Reverse_Tunnels", "remote_interactive_shell", "Protocols", "Powershell.py"],
)
if _module:
tunnel_Powershell = _module
PS_IMPORT_ERROR = None
else:
try:
from .ReverseTunnel import tunnel_Powershell # type: ignore # legacy fallback
PS_IMPORT_ERROR = None
except Exception as exc2: # pragma: no cover - diagnostic only
PS_IMPORT_ERROR = f"{PS_IMPORT_ERROR} | legacy_fallback={exc2!r} | file_load_failed={_err}"
try:
from .Reverse_Tunnels.remote_interactive_shell.Protocols import Bash as tunnel_Bash # type: ignore
except Exception as exc: # pragma: no cover - best-effort logging only
BASH_IMPORT_ERROR = repr(exc)
_module, _err = _load_protocol_module(
"tunnel_Bash",
["Reverse_Tunnels", "remote_interactive_shell", "Protocols", "Bash.py"],
)
if _module:
tunnel_Bash = _module
BASH_IMPORT_ERROR = None
else:
BASH_IMPORT_ERROR = f"{BASH_IMPORT_ERROR} | file_load_failed={_err}"
try:
from .Reverse_Tunnels.remote_management.Protocols import SSH as tunnel_SSH # type: ignore
except Exception:
_module, _err = _load_protocol_module(
"tunnel_SSH",
["Reverse_Tunnels", "remote_management", "Protocols", "SSH.py"],
)
tunnel_SSH = _module
try:
from .Reverse_Tunnels.remote_management.Protocols import WinRM as tunnel_WinRM # type: ignore
except Exception:
_module, _err = _load_protocol_module(
"tunnel_WinRM",
["Reverse_Tunnels", "remote_management", "Protocols", "WinRM.py"],
)
tunnel_WinRM = _module
try:
from .Reverse_Tunnels.remote_video.Protocols import VNC as tunnel_VNC # type: ignore
except Exception:
_module, _err = _load_protocol_module(
"tunnel_VNC",
["Reverse_Tunnels", "remote_video", "Protocols", "VNC.py"],
)
tunnel_VNC = _module
try:
from .Reverse_Tunnels.remote_video.Protocols import RDP as tunnel_RDP # type: ignore
except Exception:
_module, _err = _load_protocol_module(
"tunnel_RDP",
["Reverse_Tunnels", "remote_video", "Protocols", "RDP.py"],
)
tunnel_RDP = _module
try:
from .Reverse_Tunnels.remote_video.Protocols import WebRTC as tunnel_WebRTC # type: ignore
except Exception:
_module, _err = _load_protocol_module(
"tunnel_WebRTC",
["Reverse_Tunnels", "remote_video", "Protocols", "WebRTC.py"],
)
tunnel_WebRTC = _module
ROLE_NAME = "reverse_tunnel"
ROLE_CONTEXTS = ["interactive", "system"]
# Message types (keep in sync with Engine service)
MSG_CONNECT = 0x01
MSG_CONNECT_ACK = 0x02
MSG_CHANNEL_OPEN = 0x03
MSG_CHANNEL_ACK = 0x04
MSG_DATA = 0x05
MSG_WINDOW_UPDATE = 0x06
MSG_HEARTBEAT = 0x07
MSG_CLOSE = 0x08
MSG_CONTROL = 0x09
# Close codes
CLOSE_OK = 0
CLOSE_IDLE_TIMEOUT = 1
CLOSE_GRACE_EXPIRED = 2
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AUTH_FAILED = 4
CLOSE_SERVER_SHUTDOWN = 5
CLOSE_AGENT_SHUTDOWN = 6
CLOSE_DOMAIN_LIMIT = 7
CLOSE_UNEXPECTED_DISCONNECT = 8
FRAME_VERSION = 1
FRAME_HEADER_STRUCT = struct.Struct("<BBBBII") # version, msg_type, flags, reserved, channel_id, length
@dataclass
class TunnelFrame:
msg_type: int
channel_id: int
payload: bytes = field(default_factory=bytes)
flags: int = 0
version: int = FRAME_VERSION
reserved: int = 0
def encode(self) -> bytes:
payload = self.payload or b""
header = FRAME_HEADER_STRUCT.pack(
self.version,
self.msg_type,
self.flags,
self.reserved,
int(self.channel_id),
len(payload),
)
return header + payload
def decode_frame(raw: bytes) -> TunnelFrame:
if len(raw) < FRAME_HEADER_STRUCT.size:
raise ValueError("frame_too_small")
version, msg_type, flags, reserved, channel_id, length = FRAME_HEADER_STRUCT.unpack_from(raw, 0)
if version != FRAME_VERSION:
raise ValueError(f"unsupported_version:{version}")
if length < 0 or len(raw) < FRAME_HEADER_STRUCT.size + length:
raise ValueError("invalid_length")
payload = raw[FRAME_HEADER_STRUCT.size : FRAME_HEADER_STRUCT.size + length]
if len(payload) != length:
raise ValueError("length_mismatch")
return TunnelFrame(msg_type=msg_type, channel_id=channel_id, payload=payload, flags=flags, version=version, reserved=reserved)
def heartbeat_frame(channel_id: int = 0, *, is_ack: bool = False) -> TunnelFrame:
flags = 0x1 if is_ack else 0x0
return TunnelFrame(msg_type=MSG_HEARTBEAT, channel_id=channel_id, flags=flags, payload=b"")
def close_frame(channel_id: int, code: int, reason: str = "") -> TunnelFrame:
payload = json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8")
return TunnelFrame(msg_type=MSG_CLOSE, channel_id=channel_id, payload=payload)
def _norm_text(value: Any) -> str:
if value is None:
return ""
try:
return str(value).strip()
except Exception:
return ""
def _is_literal_ip(value: str) -> bool:
try:
import ipaddress
ipaddress.ip_address(value.strip().strip("[]"))
return True
except Exception:
return False
@dataclass
class ActiveTunnel:
tunnel_id: str
domain: str
protocol: str
port: int
token: str
url: str
heartbeat_seconds: int
idle_seconds: int
grace_seconds: int
expires_at: Optional[int]
signing_key_hint: Optional[str] = None
session: Optional[aiohttp.ClientSession] = None
websocket: Optional[aiohttp.ClientWebSocketResponse] = None
tasks: list = field(default_factory=list)
send_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
channels: Dict[int, Any] = field(default_factory=dict)
last_activity: float = field(default_factory=lambda: time.time())
connected: bool = False
stopping: bool = False
stop_reason: Optional[str] = None
stop_origin: Optional[str] = None
class BaseChannel:
"""Placeholder channel handler; protocol-specific handlers plug in later."""
def __init__(self, role: "Role", tunnel: ActiveTunnel, channel_id: int, metadata: Optional[dict]):
self.role = role
self.tunnel = tunnel
self.channel_id = channel_id
self.metadata = metadata or {}
async def start(self) -> None:
# Nothing to prime for placeholder channels.
return
async def on_frame(self, frame: TunnelFrame) -> None:
# Drop frames until protocol module is provided.
return
async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None:
await self.role._send_frame(self.tunnel, close_frame(self.channel_id, code, reason))
class Role:
def __init__(self, ctx):
self.ctx = ctx
self.sio = ctx.sio
self.loop = ctx.loop or asyncio.get_event_loop()
self.hooks = ctx.hooks or {}
self._http_client_factory = self.hooks.get("http_client")
self._log_hook = self.hooks.get("log_agent")
self._active: Dict[str, ActiveTunnel] = {}
self._domain_claims: Dict[str, str] = {}
self._domain_limits: Dict[str, Optional[int]] = {
"remote-interactive-shell": 2,
"remote-management": 1,
"remote-video": 2,
# Legacy / protocol fallbacks
"ps": 2,
"rdp": 1,
"vnc": 1,
"webrtc": 2,
"ssh": None,
"winrm": None,
}
self._default_heartbeat = 20
self._protocol_handlers: Dict[str, Any] = {}
self._frame_cls = TunnelFrame
self.close_frame = close_frame
try:
if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"):
self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel
module_path = getattr(tunnel_Powershell, "__file__", None)
self._log(f"reverse_tunnel ps handler registered (PowershellChannel) module={module_path}")
else:
hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
exists_hint = f" exists={module_path.exists()}"
self._log(
f"reverse_tunnel ps handler NOT registered (missing module/class){hint}{exists_hint}",
error=True,
)
except Exception as exc:
self._log(f"reverse_tunnel ps handler registration failed: {exc}", error=True)
try:
if tunnel_Bash and hasattr(tunnel_Bash, "BashChannel"):
self._protocol_handlers["bash"] = tunnel_Bash.BashChannel
module_path = getattr(tunnel_Bash, "__file__", None)
self._log(f"reverse_tunnel bash handler registered (BashChannel) module={module_path}")
elif BASH_IMPORT_ERROR:
self._log(f"reverse_tunnel bash handler NOT registered (missing module/class) import_error={BASH_IMPORT_ERROR}", error=True)
except Exception as exc:
self._log(f"reverse_tunnel bash handler registration failed: {exc}", error=True)
try:
if tunnel_SSH and hasattr(tunnel_SSH, "SSHChannel"):
self._protocol_handlers["ssh"] = tunnel_SSH.SSHChannel
if tunnel_WinRM and hasattr(tunnel_WinRM, "WinRMChannel"):
self._protocol_handlers["winrm"] = tunnel_WinRM.WinRMChannel
if tunnel_VNC and hasattr(tunnel_VNC, "VNCChannel"):
self._protocol_handlers["vnc"] = tunnel_VNC.VNCChannel
if tunnel_RDP and hasattr(tunnel_RDP, "RDPChannel"):
self._protocol_handlers["rdp"] = tunnel_RDP.RDPChannel
if tunnel_WebRTC and hasattr(tunnel_WebRTC, "WebRTCChannel"):
self._protocol_handlers["webrtc"] = tunnel_WebRTC.WebRTCChannel
except Exception as exc:
self._log(f"reverse_tunnel protocol handler registration failed: {exc}", error=True)
# ------------------------------------------------------------------ Logging
def _log(self, message: str, *, error: bool = False) -> None:
fname = "reverse_tunnel.log"
try:
if callable(self._log_hook):
self._log_hook(message, fname=fname)
if error:
self._log_hook(message, fname="agent.error.log")
except Exception:
pass
# ------------------------------------------------------------------ Event wiring
def register_events(self):
@self.sio.on("reverse_tunnel_start")
async def _reverse_tunnel_start(payload):
await self._handle_tunnel_start(payload)
@self.sio.on("reverse_tunnel_stop")
async def _reverse_tunnel_stop(payload):
tid = ""
if isinstance(payload, dict):
tid = _norm_text(payload.get("tunnel_id"))
await self._stop_tunnel(tid, code=CLOSE_AGENT_SHUTDOWN, reason="server_stop")
# ------------------------------------------------------------------ Token helpers
def _decode_token_payload(self, token: str, *, signing_key_hint: Optional[str] = None) -> Dict[str, Any]:
if not token:
raise ValueError("token_missing")
def _b64decode(segment: str) -> bytes:
padding = "=" * (-len(segment) % 4)
return base64.urlsafe_b64decode(segment + padding)
parts = token.split(".")
payload_segment = parts[0]
payload_bytes = _b64decode(payload_segment)
try:
payload = json.loads(payload_bytes.decode("utf-8"))
except Exception as exc:
raise ValueError("token_decode_error") from exc
if len(parts) == 2:
candidates = []
hint = _norm_text(signing_key_hint)
if hint:
candidates.append(hint)
client = self._http_client()
if client and hasattr(client, "load_server_signing_key"):
try:
stored = client.load_server_signing_key()
except Exception:
stored = None
if isinstance(stored, str) and stored.strip():
candidates.append(stored.strip())
signature = _b64decode(parts[1])
verified = False
for candidate in candidates:
try:
key_bytes = base64.b64decode(candidate, validate=True)
public_key = serialization.load_der_public_key(key_bytes)
except Exception:
continue
if not isinstance(public_key, ed25519.Ed25519PublicKey):
continue
try:
public_key.verify(signature, payload_bytes)
verified = True
if client and hasattr(client, "store_server_signing_key"):
try:
client.store_server_signing_key(candidate)
except Exception:
pass
break
except Exception:
continue
if not verified:
raise ValueError("token_signature_invalid")
return payload
def _validate_token(self, token: str, *, expected_agent: str, expected_domain: str, expected_protocol: str, expected_tunnel: str, signing_key_hint: Optional[str]) -> Dict[str, Any]:
payload = self._decode_token_payload(token, signing_key_hint=signing_key_hint)
def _matches(expected: str, actual: Any) -> bool:
return _norm_text(expected).lower() == _norm_text(actual).lower()
if expected_agent and not _matches(expected_agent, payload.get("agent_id")):
raise ValueError("token_agent_mismatch")
if expected_tunnel and not _matches(expected_tunnel, payload.get("tunnel_id")):
raise ValueError("token_id_mismatch")
if expected_domain and not _matches(expected_domain, payload.get("domain")):
raise ValueError("token_domain_mismatch")
if expected_protocol and not _matches(expected_protocol, payload.get("protocol")):
raise ValueError("token_protocol_mismatch")
expires_at = payload.get("expires_at")
try:
expires_ts = int(expires_at) if expires_at is not None else None
except Exception:
expires_ts = None
if expires_ts is not None and expires_ts < int(time.time()):
raise ValueError("token_expired")
return payload
# ------------------------------------------------------------------ Utility
def _http_client(self):
try:
if callable(self._http_client_factory):
return self._http_client_factory()
except Exception:
return None
return None
def _port_allowed(self, port: int) -> bool:
return isinstance(port, int) and 1 <= port <= 65535
def _domain_allowed(self, domain: str) -> bool:
limit = self._domain_limits.get(domain)
if limit is None:
return True
active = [tid for tid, t in self._active.items() if t.domain == domain and not t.stopping]
pending = [tid for dom, tid in self._domain_claims.items() if dom == domain and tid not in active]
count = len(set(active + pending))
return count < limit
def _build_ws_url(self, port: int) -> str:
client = self._http_client()
if client:
try:
client.refresh_base_url()
except Exception:
pass
base_url = getattr(client, "base_url", "") or "https://localhost:5000"
else:
base_url = "https://localhost:5000"
parsed = urlparse(base_url)
host = parsed.hostname or "localhost"
scheme = "wss" if (parsed.scheme or "").lower() == "https" else "ws"
return f"{scheme}://{host}:{port}/"
def _ssl_context(self, host: str):
client = self._http_client()
verify = getattr(getattr(client, "session", None), "verify", True)
if verify is False:
return False
if isinstance(verify, str) and os.path.isfile(verify) and client and hasattr(client, "key_store"):
try:
ctx = client.key_store.build_ssl_context()
if ctx and _is_literal_ip(host):
ctx.check_hostname = False
return ctx
except Exception:
return None
return None
async def _emit_status(self, payload: Dict[str, Any]) -> None:
try:
await self.sio.emit("reverse_tunnel_status", payload)
except Exception:
pass
def _mark_activity(self, tunnel: ActiveTunnel) -> None:
tunnel.last_activity = time.time()
# ------------------------------------------------------------------ Event handlers
async def _handle_tunnel_start(self, payload: Any) -> None:
if not isinstance(payload, dict):
self._log("reverse_tunnel_start ignored: payload not a dict", error=True)
return
tunnel_id = _norm_text(payload.get("tunnel_id"))
token = _norm_text(payload.get("token"))
port = payload.get("port") or payload.get("assigned_port")
protocol = _norm_text(payload.get("protocol") or "ps").lower() or "ps"
domain = _norm_text(payload.get("domain") or protocol).lower() or protocol
heartbeat_seconds = int(payload.get("heartbeat_seconds") or self._default_heartbeat or 20)
idle_seconds = int(payload.get("idle_seconds") or 3600)
grace_seconds = int(payload.get("grace_seconds") or 3600)
signing_key_hint = _norm_text(payload.get("signing_key"))
agent_hint = _norm_text(payload.get("agent_id"))
# Ignore broadcasts targeting other agents (Socket.IO fanout sends to both contexts).
if agent_hint and agent_hint.lower() != _norm_text(self.ctx.agent_id).lower():
return
if not token:
self._log("reverse_tunnel_start rejected: missing token", error=True)
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_missing"})
return
if not tunnel_id:
tunnel_id = _norm_text(payload.get("lease_id")) # fallback alias
try:
claims = self._validate_token(
token,
expected_agent=self.ctx.agent_id,
expected_domain=domain,
expected_protocol=protocol,
expected_tunnel=tunnel_id,
signing_key_hint=signing_key_hint,
)
except Exception as exc:
if str(exc) == "token_agent_mismatch":
# Broadcast hit the wrong agent context; ignore quietly.
return
self._log(f"reverse_tunnel_start rejected: token validation failed ({exc})", error=True)
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_invalid"})
return
tunnel_id = _norm_text(claims.get("tunnel_id") or tunnel_id)
if not tunnel_id:
self._log("reverse_tunnel_start rejected: tunnel_id missing after token parse", error=True)
return
try:
port = int(port or claims.get("assigned_port") or 0)
except Exception:
port = 0
if not self._port_allowed(port):
self._log(f"reverse_tunnel_start rejected: invalid port {port}", error=True)
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "invalid_port"})
return
domain = _norm_text(claims.get("domain") or domain).lower()
protocol = _norm_text(claims.get("protocol") or protocol).lower()
expires_at = claims.get("expires_at")
if tunnel_id in self._active:
self._log(f"reverse_tunnel_start ignored: tunnel already active tunnel_id={tunnel_id}")
return
if not self._domain_allowed(domain):
self._log(f"reverse_tunnel_start rejected: domain limit for {domain}", error=True)
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "domain_limit"})
return
url = self._build_ws_url(port)
parsed = urlparse(url)
heartbeat_seconds = max(5, min(heartbeat_seconds, 120))
idle_seconds = max(30, idle_seconds)
grace_seconds = max(60, grace_seconds)
tunnel = ActiveTunnel(
tunnel_id=tunnel_id,
domain=domain,
protocol=protocol,
port=port,
token=token,
url=url,
heartbeat_seconds=heartbeat_seconds,
idle_seconds=idle_seconds,
grace_seconds=grace_seconds,
expires_at=int(expires_at) if expires_at is not None else None,
signing_key_hint=signing_key_hint or None,
)
self._active[tunnel_id] = tunnel
self._domain_claims[domain] = tunnel_id
self._log(f"reverse_tunnel_start accepted tunnel_id={tunnel_id} domain={domain} protocol={protocol} url={url}")
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "connecting", "url": url})
task = self.loop.create_task(self._run_tunnel(tunnel, host=parsed.hostname or "localhost"))
tunnel.tasks.append(task)
# ------------------------------------------------------------------ Core tunnel handling
async def _run_tunnel(self, tunnel: ActiveTunnel, *, host: str) -> None:
ssl_ctx = self._ssl_context(host)
timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None)
self._log(
f"reverse_tunnel dialing ws url={tunnel.url} tunnel_id={tunnel.tunnel_id} "
f"agent_id={self.ctx.agent_id} ssl={'yes' if ssl_ctx else 'no'}"
)
try:
tunnel.session = aiohttp.ClientSession(timeout=timeout)
tunnel.websocket = await tunnel.session.ws_connect(
tunnel.url,
ssl=ssl_ctx,
heartbeat=None,
max_msg_size=0,
timeout=timeout,
)
self._mark_activity(tunnel)
self._log(f"reverse_tunnel connected ws tunnel_id={tunnel.tunnel_id} peer={getattr(tunnel.websocket, 'remote', None)}")
await tunnel.websocket.send_bytes(
TunnelFrame(
msg_type=MSG_CONNECT,
channel_id=0,
payload=json.dumps(
{
"agent_id": self.ctx.agent_id,
"tunnel_id": tunnel.tunnel_id,
"token": tunnel.token,
"protocol": tunnel.protocol,
"domain": tunnel.domain,
"version": FRAME_VERSION,
},
separators=(",", ":"),
).encode("utf-8"),
).encode()
)
self._log(f"reverse_tunnel CONNECT sent tunnel_id={tunnel.tunnel_id}")
sender = self.loop.create_task(self._pump_sender(tunnel))
receiver = self.loop.create_task(self._pump_receiver(tunnel))
heartbeats = self.loop.create_task(self._heartbeat_loop(tunnel))
watchdog = self.loop.create_task(self._watchdog(tunnel))
tunnel.tasks.extend([sender, receiver, heartbeats, watchdog])
task_labels = {
sender: "sender",
receiver: "receiver",
heartbeats: "heartbeat",
watchdog: "watchdog",
}
done, pending = await asyncio.wait(task_labels.keys(), return_when=asyncio.FIRST_COMPLETED)
for finished in done:
label = task_labels.get(finished) or "unknown"
exc_text = ""
try:
exc_obj = finished.exception()
except asyncio.CancelledError:
exc_obj = None
exc_text = " (cancelled)"
except Exception as exc: # pragma: no cover - defensive logging
exc_obj = exc
if exc_obj:
exc_text = f" (exc={exc_obj!r})"
if not tunnel.stop_reason:
tunnel.stop_reason = f"{label}_stopped{exc_text}"
if not tunnel.stop_origin:
tunnel.stop_origin = label
self._log(
f"reverse_tunnel task completed tunnel_id={tunnel.tunnel_id} task={label} stop_reason={tunnel.stop_reason}{exc_text}"
)
if pending:
try:
self._log(
"reverse_tunnel pending tasks after first completion tunnel_id=%s pending=%s",
# Represent pending tasks by label for debugging.
)
except Exception:
pass
except Exception as exc:
self._log(f"reverse_tunnel connection failed tunnel_id={tunnel.tunnel_id}: {exc}", error=True)
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "connect_failed"})
finally:
try:
ws = tunnel.websocket
self._log(
f"reverse_tunnel ws closing tunnel_id={tunnel.tunnel_id} "
f"code={getattr(ws, 'close_code', None)} reason={getattr(ws, 'close_reason', None)}"
)
except Exception:
pass
await self._shutdown_tunnel(tunnel)
async def _pump_sender(self, tunnel: ActiveTunnel) -> None:
try:
while tunnel.websocket and not tunnel.websocket.closed:
frame: TunnelFrame = await tunnel.send_queue.get()
try:
await tunnel.websocket.send_bytes(frame.encode())
self._mark_activity(tunnel)
self._log(
f"reverse_tunnel send frame tunnel_id={tunnel.tunnel_id} "
f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
)
except Exception:
if not tunnel.stop_reason:
tunnel.stop_reason = "sender_error"
break
except asyncio.CancelledError:
pass
except Exception:
if not tunnel.stop_reason:
tunnel.stop_reason = "sender_failed"
self._log(f"reverse_tunnel sender failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel sender stopped tunnel_id={tunnel.tunnel_id}")
async def _pump_receiver(self, tunnel: ActiveTunnel) -> None:
ws = tunnel.websocket
if ws is None:
return
try:
async for msg in ws:
if msg.type == aiohttp.WSMsgType.BINARY:
try:
frame = decode_frame(msg.data)
except Exception:
self._log(f"reverse_tunnel frame decode failed tunnel_id={tunnel.tunnel_id}", error=True)
continue
self._mark_activity(tunnel)
await self._handle_frame(tunnel, frame)
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE):
self._log(
f"reverse_tunnel websocket closed tunnel_id={tunnel.tunnel_id} "
f"code={ws.close_code} reason={ws.close_reason}"
)
tunnel.stop_reason = ws.close_reason or "ws_closed"
break
except asyncio.CancelledError:
pass
except Exception:
if not tunnel.stop_reason:
tunnel.stop_reason = "receiver_failed"
self._log(f"reverse_tunnel receiver failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel receiver stopped tunnel_id={tunnel.tunnel_id}")
# If no stop_reason was set, emit a CLOSE so engine/UI see a reason.
if not tunnel.stop_reason:
try:
await self._send_frame(tunnel, close_frame(0, CLOSE_UNEXPECTED_DISCONNECT, "receiver_stop"))
tunnel.stop_reason = "receiver_stop"
except Exception:
pass
async def _heartbeat_loop(self, tunnel: ActiveTunnel) -> None:
try:
while tunnel.websocket and not tunnel.websocket.closed:
await asyncio.sleep(tunnel.heartbeat_seconds)
await self._send_frame(tunnel, heartbeat_frame())
self._log(f"reverse_tunnel heartbeat sent tunnel_id={tunnel.tunnel_id}")
except asyncio.CancelledError:
pass
except Exception:
if not tunnel.stop_reason:
tunnel.stop_reason = "heartbeat_failed"
self._log(f"reverse_tunnel heartbeat failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel heartbeat loop stopped tunnel_id={tunnel.tunnel_id}")
async def _watchdog(self, tunnel: ActiveTunnel) -> None:
try:
while tunnel.websocket and not tunnel.websocket.closed:
await asyncio.sleep(10)
now = time.time()
if tunnel.idle_seconds and (now - tunnel.last_activity) >= tunnel.idle_seconds:
await self._send_frame(tunnel, close_frame(0, CLOSE_IDLE_TIMEOUT, "idle_timeout"))
tunnel.stop_reason = tunnel.stop_reason or "idle_timeout"
self._log(f"reverse_tunnel watchdog idle_timeout tunnel_id={tunnel.tunnel_id}")
break
if tunnel.expires_at and (now - tunnel.expires_at) >= tunnel.grace_seconds:
await self._send_frame(tunnel, close_frame(0, CLOSE_GRACE_EXPIRED, "grace_expired"))
tunnel.stop_reason = tunnel.stop_reason or "grace_expired"
self._log(f"reverse_tunnel watchdog grace_expired tunnel_id={tunnel.tunnel_id}")
break
except asyncio.CancelledError:
pass
except Exception:
if not tunnel.stop_reason:
tunnel.stop_reason = "watchdog_failed"
self._log(f"reverse_tunnel watchdog failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel watchdog stopped tunnel_id={tunnel.tunnel_id}")
async def _handle_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
self._log(
f"reverse_tunnel recv frame tunnel_id={tunnel.tunnel_id} "
f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
)
if frame.msg_type == MSG_HEARTBEAT:
if frame.flags & 0x1:
self._log(f"reverse_tunnel heartbeat ack tunnel_id={tunnel.tunnel_id}")
return
await self._send_frame(tunnel, heartbeat_frame(channel_id=frame.channel_id, is_ack=True))
return
if frame.msg_type == MSG_CONNECT_ACK:
tunnel.connected = True
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "connected"})
self._log(f"reverse_tunnel CONNECT_ACK tunnel_id={tunnel.tunnel_id}")
return
if frame.msg_type == MSG_CHANNEL_OPEN:
await self._handle_channel_open(tunnel, frame)
return
if frame.msg_type == MSG_CLOSE:
try:
reason = json.loads(frame.payload.decode("utf-8"))
except Exception:
reason = {"code": CLOSE_UNEXPECTED_DISCONNECT, "reason": "close"}
tunnel.stop_reason = reason.get("reason") if isinstance(reason, dict) else None
await self._shutdown_tunnel(tunnel, send_close=False)
return
handler = tunnel.channels.get(frame.channel_id)
if handler:
try:
await handler.on_frame(frame)
except Exception:
self._log(f"reverse_tunnel channel handler failed tunnel_id={tunnel.tunnel_id} channel={frame.channel_id}", error=True)
else:
if frame.msg_type in (MSG_DATA, MSG_CONTROL, MSG_WINDOW_UPDATE):
await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "unknown_channel"))
async def _handle_channel_open(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
try:
payload = json.loads(frame.payload.decode("utf-8"))
except Exception:
payload = {}
protocol = _norm_text(payload.get("protocol") or tunnel.protocol)
metadata = payload.get("metadata") if isinstance(payload, dict) else {}
if frame.channel_id in tunnel.channels:
await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "channel_exists"))
return
if protocol.lower() == "ps" and "ps" not in self._protocol_handlers:
hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
exists_hint = f" exists={module_path.exists()}"
self._log(
f"reverse_tunnel ps handler missing; falling back to BaseChannel{hint}{exists_hint}",
error=True,
)
handler_cls = self._protocol_handlers.get(protocol.lower()) or BaseChannel
try:
handler = handler_cls(self, tunnel, frame.channel_id, metadata)
except Exception:
self._log(f"reverse_tunnel channel handler fallback to BaseChannel protocol={protocol}", error=True)
handler = BaseChannel(self, tunnel, frame.channel_id, metadata)
tunnel.channels[frame.channel_id] = handler
await handler.start()
await self._send_frame(
tunnel,
TunnelFrame(
msg_type=MSG_CHANNEL_ACK,
channel_id=frame.channel_id,
payload=json.dumps({"status": "ok", "protocol": protocol}, separators=(",", ":")).encode("utf-8"),
),
)
self._log(
f"reverse_tunnel channel_opened tunnel_id={tunnel.tunnel_id} channel={frame.channel_id} "
f"protocol={protocol} handler={handler.__class__.__name__} metadata={metadata}"
)
async def _send_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
if tunnel.stopping and getattr(frame, "msg_type", None) != MSG_CLOSE:
return
try:
tunnel.send_queue.put_nowait(frame)
except Exception:
await tunnel.send_queue.put(frame)
async def _stop_tunnel(self, tunnel_id: str, *, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "requested") -> None:
tunnel = self._active.get(tunnel_id)
if not tunnel:
return
if not tunnel.stop_origin:
tunnel.stop_origin = "stop_tunnel"
self._log(f"reverse_tunnel stop_tunnel requested tunnel_id={tunnel_id} code={code} reason={reason}")
if not tunnel.stop_reason:
tunnel.stop_reason = reason or "requested"
await self._send_frame(tunnel, close_frame(0, code, reason))
await self._shutdown_tunnel(tunnel, send_close=False)
async def _shutdown_tunnel(self, tunnel: ActiveTunnel, *, send_close: bool = True) -> None:
if tunnel.stopping:
return
tunnel.stopping = True
reason_text = tunnel.stop_reason or "closed"
if not tunnel.stop_reason:
tunnel.stop_reason = reason_text
if not tunnel.stop_origin:
tunnel.stop_origin = "shutdown"
self._log(
f"reverse_tunnel shutdown start tunnel_id={tunnel.tunnel_id} stop_reason={tunnel.stop_reason} "
f"stop_origin={tunnel.stop_origin} ws_closed={getattr(tunnel.websocket, 'closed', None)}"
)
# Stop all channels first so CLOSE frames (with reasons) are sent upstream.
for handler in list(tunnel.channels.values()):
try:
await handler.stop(code=CLOSE_UNEXPECTED_DISCONNECT, reason=reason_text or "tunnel_shutdown")
except Exception:
pass
if send_close:
close_payload = close_frame(0, CLOSE_AGENT_SHUTDOWN, reason_text or "agent_shutdown")
try:
await self._send_frame(tunnel, close_payload)
# Give the sender loop a brief window to flush the CLOSE upstream.
await asyncio.sleep(0.05)
except Exception:
pass
# Fallback: if sender task died, try sending directly on the websocket.
try:
if tunnel.websocket and not tunnel.websocket.closed:
await tunnel.websocket.send_bytes(close_payload.encode())
except Exception:
pass
for task in list(tunnel.tasks):
try:
task.cancel()
except Exception:
pass
if tunnel.websocket is not None:
try:
message = (reason_text or "agent_shutdown").encode("utf-8", "ignore")[:120]
await tunnel.websocket.close(message=message)
except Exception:
pass
if tunnel.session is not None:
try:
await tunnel.session.close()
except Exception:
pass
self._active.pop(tunnel.tunnel_id, None)
if tunnel.domain in self._domain_claims and self._domain_claims.get(tunnel.domain) == tunnel.tunnel_id:
self._domain_claims.pop(tunnel.domain, None)
self._log(f"reverse_tunnel stopped tunnel_id={tunnel.tunnel_id} reason={tunnel.stop_reason or 'closed'}")
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "closed", "reason": tunnel.stop_reason or "closed"})
# ------------------------------------------------------------------ Lifecycle
def stop_all(self):
for tunnel_id in list(self._active.keys()):
try:
self.loop.create_task(self._stop_tunnel(tunnel_id, code=CLOSE_AGENT_SHUTDOWN, reason="agent_shutdown"))
except Exception:
pass

View File

@@ -0,0 +1,167 @@
# ======================================================
# Data\Agent\Roles\role_VpnShell.py
# Description: PowerShell TCP server for VPN shell access (Engine connects over WireGuard /32).
#
# API Endpoints (if applicable): None
# ======================================================
"""VPN PowerShell server for the WireGuard tunnel."""
from __future__ import annotations
import base64
import json
import socket
import subprocess
import threading
import time
from pathlib import Path
from typing import Any, Optional
import os
ROLE_NAME = "VpnShell"
ROLE_CONTEXTS = ["system"]
def _log_path() -> Path:
root = Path(__file__).resolve().parents[2] / "Logs"
root.mkdir(parents=True, exist_ok=True)
return root / "reverse_tunnel.log"
def _write_log(message: str) -> None:
ts = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
try:
_log_path().open("a", encoding="utf-8").write(f"[{ts}] [vpn-shell] {message}\n")
except Exception:
pass
def _b64encode(data: bytes) -> str:
return base64.b64encode(data).decode("ascii").strip()
def _b64decode(value: str) -> bytes:
return base64.b64decode(value.encode("ascii"))
def _resolve_shell_port() -> int:
raw = os.environ.get("BOREALIS_WIREGUARD_SHELL_PORT")
try:
value = int(raw) if raw is not None else 47001
except Exception:
value = 47001
if value < 1 or value > 65535:
return 47001
return value
class ShellSession:
def __init__(self, conn: socket.socket, address: tuple[str, int]) -> None:
self.conn = conn
self.address = address
self.proc: Optional[subprocess.Popen] = None
self._stop = threading.Event()
def start(self) -> None:
self.proc = subprocess.Popen(
["powershell.exe", "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
bufsize=0,
)
threading.Thread(target=self._reader_loop, daemon=True).start()
self._writer_loop()
def _reader_loop(self) -> None:
if not self.proc or not self.proc.stdout:
return
try:
while not self._stop.is_set():
chunk = self.proc.stdout.readline()
if not chunk:
break
payload = json.dumps({"type": "stdout", "data": _b64encode(chunk)})
self.conn.sendall(payload.encode("utf-8") + b"\n")
except Exception:
pass
def _writer_loop(self) -> None:
buffer = b""
try:
while not self._stop.is_set():
data = self.conn.recv(4096)
if not data:
break
buffer += data
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if not line:
continue
try:
msg = json.loads(line.decode("utf-8"))
except Exception:
continue
if msg.get("type") == "stdin":
payload = msg.get("data") or ""
if self.proc and self.proc.stdin:
try:
self.proc.stdin.write(_b64decode(str(payload)))
self.proc.stdin.flush()
except Exception:
pass
if msg.get("type") == "close":
self._stop.set()
break
finally:
self.close()
def close(self) -> None:
self._stop.set()
try:
self.conn.close()
except Exception:
pass
if self.proc:
try:
self.proc.terminate()
except Exception:
pass
class ShellServer:
def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None) -> None:
self.host = host
self.port = port or _resolve_shell_port()
self._thread = threading.Thread(target=self._serve, daemon=True)
self._thread.start()
_write_log(f"VPN shell server listening on {self.host}:{self.port}")
def _serve(self) -> None:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.bind((self.host, self.port))
server.listen(5)
while True:
conn, addr = server.accept()
remote_ip = addr[0]
if not remote_ip.startswith("10.255."):
_write_log(f"Rejected shell connection from {remote_ip}")
conn.close()
continue
_write_log(f"Accepted shell connection from {remote_ip}")
session = ShellSession(conn, addr)
threading.Thread(target=session.start, daemon=True).start()
class Role:
def __init__(self, ctx) -> None:
self.ctx = ctx
self.server = ShellServer()
def register_events(self) -> None:
return
def stop_all(self) -> None:
return

View File

@@ -9,13 +9,15 @@
This role prepares the WireGuard client config, manages a single active This role prepares the WireGuard client config, manages a single active
session, enforces idle teardown, and logs lifecycle events to session, enforces idle teardown, and logs lifecycle events to
Agent/Logs/reverse_tunnel.log. It does not yet bind to engine signals; higher Agent/Logs/reverse_tunnel.log. It binds to Engine Socket.IO events
layers should call start_session/stop_session with the issued config/token. (`vpn_tunnel_start`, `vpn_tunnel_stop`, `vpn_tunnel_activity`) to start/stop
the client session with the issued config/token.
""" """
from __future__ import annotations from __future__ import annotations
import base64 import base64
import json
import os import os
import subprocess import subprocess
import threading import threading
@@ -26,6 +28,7 @@ from typing import Any, Dict, Optional
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519 from cryptography.hazmat.primitives.asymmetric import x25519
from signature_utils import verify_and_store_script_signature
ROLE_NAME = "WireGuardTunnel" ROLE_NAME = "WireGuardTunnel"
ROLE_CONTEXTS = ["system"] ROLE_CONTEXTS = ["system"]
@@ -95,6 +98,8 @@ class SessionConfig:
allowed_ports: str allowed_ports: str
idle_seconds: int = 900 idle_seconds: int = 900
preshared_key: Optional[str] = None preshared_key: Optional[str] = None
client_private_key: Optional[str] = None
client_public_key: Optional[str] = None
class WireGuardClient: class WireGuardClient:
@@ -122,17 +127,35 @@ class WireGuardClient:
return candidate return candidate
return "wireguard.exe" return "wireguard.exe"
def _validate_token(self, token: Dict[str, Any]) -> None: def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None:
required = ("agent_id", "tunnel_id", "expires_at") payload = dict(token or {})
signature = payload.pop("signature", None)
signing_key = payload.pop("signing_key", None)
sig_alg = payload.pop("sig_alg", None)
required = ("agent_id", "tunnel_id", "expires_at", "port")
missing = [field for field in required if field not in token or token[field] in ("", None)] missing = [field for field in required if field not in token or token[field] in ("", None)]
if missing: if missing:
raise ValueError(f"Missing token fields: {', '.join(missing)}") raise ValueError(f"Missing token fields: {', '.join(missing)}")
try: try:
exp = float(token["expires_at"]) exp = float(payload["expires_at"])
except Exception: except Exception:
raise ValueError("Invalid token expiry") raise ValueError("Invalid token expiry")
if exp <= time.time(): if exp <= time.time():
raise ValueError("Token expired") raise ValueError("Token expired")
try:
port = int(payload["port"])
except Exception:
raise ValueError("Invalid token port")
if port < 1 or port > 65535:
raise ValueError("Invalid token port")
if signature:
if sig_alg and str(sig_alg).lower() not in ("ed25519", "eddsa"):
raise ValueError("Unsupported token signature algorithm")
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
if not verify_and_store_script_signature(signing_client, payload_bytes, str(signature), signing_key):
raise ValueError("Token signature invalid")
def _run(self, args: list[str]) -> tuple[int, str, str]: def _run(self, args: list[str]) -> tuple[int, str, str]:
try: try:
@@ -142,9 +165,10 @@ class WireGuardClient:
return 1, "", str(exc) return 1, "", str(exc)
def _render_config(self, session: SessionConfig) -> str: def _render_config(self, session: SessionConfig) -> str:
private_key = session.client_private_key or self._client_keys["private"]
lines = [ lines = [
"[Interface]", "[Interface]",
f"PrivateKey = {self._client_keys['private']}", f"PrivateKey = {private_key}",
f"Address = {session.virtual_ip}", f"Address = {session.virtual_ip}",
"", "",
"[Peer]", "[Peer]",
@@ -172,13 +196,13 @@ class WireGuardClient:
t.start() t.start()
self._idle_thread = t self._idle_thread = t
def start_session(self, session: SessionConfig) -> None: def start_session(self, session: SessionConfig, *, signing_client: Optional[Any] = None) -> None:
if self.session: if self.session:
_write_log("Rejecting start_session: existing session already active.") _write_log("Rejecting start_session: existing session already active.")
return return
try: try:
self._validate_token(session.token) self._validate_token(session.token, signing_client=signing_client)
except Exception as exc: except Exception as exc:
_write_log(f"Refusing to start WireGuard session: {exc}") _write_log(f"Refusing to start WireGuard session: {exc}")
return return
@@ -243,6 +267,7 @@ class Role:
self.client = client self.client = client
hooks = getattr(ctx, "hooks", {}) or {} hooks = getattr(ctx, "hooks", {}) or {}
self._log_hook = hooks.get("log_agent") self._log_hook = hooks.get("log_agent")
self._http_client_factory = hooks.get("http_client")
def _log(self, message: str, *, error: bool = False) -> None: def _log(self, message: str, *, error: bool = False) -> None:
if callable(self._log_hook): if callable(self._log_hook):
@@ -254,6 +279,14 @@ class Role:
pass pass
_write_log(message) _write_log(message)
def _http_client(self) -> Optional[Any]:
try:
if callable(self._http_client_factory):
return self._http_client_factory()
except Exception:
return None
return None
def _build_session(self, payload: Any) -> Optional[SessionConfig]: def _build_session(self, payload: Any) -> Optional[SessionConfig]:
if not isinstance(payload, dict): if not isinstance(payload, dict):
self._log("WireGuard start payload missing/invalid.", error=True) self._log("WireGuard start payload missing/invalid.", error=True)
@@ -299,6 +332,8 @@ class Role:
allowed_ports=allowed_ports, allowed_ports=allowed_ports,
idle_seconds=idle_seconds, idle_seconds=idle_seconds,
preshared_key=payload.get("preshared_key"), preshared_key=payload.get("preshared_key"),
client_private_key=payload.get("client_private_key"),
client_public_key=payload.get("client_public_key"),
) )
def register_events(self) -> None: def register_events(self) -> None:
@@ -310,7 +345,7 @@ class Role:
if not session: if not session:
return return
self._log("WireGuard start request received.") self._log("WireGuard start request received.")
self.client.start_session(session) self.client.start_session(session, signing_client=self._http_client())
@sio.on("vpn_tunnel_stop") @sio.on("vpn_tunnel_stop")
async def _vpn_tunnel_stop(payload): async def _vpn_tunnel_stop(payload):

View File

@@ -1,90 +0,0 @@
# ======================================================
# Data\Engine\Unit_Tests\test_reverse_tunnel.py
# Description: Validates reverse tunnel lease API basics (allocation, token contents, and domain limit).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/request
# ======================================================
from __future__ import annotations
import base64
import json
import pytest
from .conftest import EngineTestHarness
def _client_with_admin_session(harness: EngineTestHarness):
client = harness.app.test_client()
with client.session_transaction() as sess:
sess["username"] = "admin"
sess["role"] = "Admin"
return client
def _decode_token_segment(token: str) -> dict:
"""Decode the unsigned payload segment from the tunnel token."""
if not token:
return {}
segment = token.split(".")[0]
padding = "=" * (-len(segment) % 4)
raw = base64.urlsafe_b64decode(segment + padding)
try:
return json.loads(raw.decode("utf-8"))
except Exception:
return {}
@pytest.mark.parametrize("agent_id", ["test-device-agent"])
def test_tunnel_request_happy_path(engine_harness: EngineTestHarness, agent_id: str) -> None:
client = _client_with_admin_session(engine_harness)
resp = client.post(
"/api/tunnel/request",
json={"agent_id": agent_id, "protocol": "ps", "domain": "ps"},
)
assert resp.status_code == 200
payload = resp.get_json()
assert payload["agent_id"] == agent_id
assert payload["protocol"] == "ps"
assert payload["domain"] == "ps"
assert isinstance(payload["port"], int) and payload["port"] >= 30000
assert payload.get("token")
claims = _decode_token_segment(payload["token"])
assert claims.get("agent_id") == agent_id
assert claims.get("protocol") == "ps"
assert claims.get("domain") == "ps"
assert claims.get("tunnel_id") == payload["tunnel_id"]
assert claims.get("assigned_port") == payload["port"]
def test_tunnel_request_domain_limit(engine_harness: EngineTestHarness) -> None:
client = _client_with_admin_session(engine_harness)
first = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert first.status_code == 200
second = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert second.status_code == 409
data = second.get_json()
assert data.get("error") == "domain_limit"
def test_tunnel_request_includes_timeouts(engine_harness: EngineTestHarness) -> None:
client = _client_with_admin_session(engine_harness)
resp = client.post(
"/api/tunnel/request",
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
)
assert resp.status_code == 200
payload = resp.get_json()
assert payload.get("idle_seconds") and payload["idle_seconds"] > 0
assert payload.get("grace_seconds") and payload["grace_seconds"] > 0
assert payload.get("expires_at") and int(payload["expires_at"]) > 0

View File

@@ -1,101 +0,0 @@
# ======================================================
# Data\Engine\Unit_Tests\test_reverse_tunnel_integration.py
# Description: Integration test that exercises a full reverse tunnel PowerShell round-trip
# against a running Engine + Agent (requires live services).
#
# Requirements:
# - Environment variables must be set to point at a live Engine + Agent:
# TUNNEL_TEST_HOST (e.g., https://localhost:5000)
# TUNNEL_TEST_AGENT_ID (agent_id/agent_guid for the target device)
# TUNNEL_TEST_BEARER (Authorization bearer token for an admin/operator)
# - A live Agent must be reachable and allowed to establish the reverse tunnel.
# - TLS verification can be controlled via TUNNEL_TEST_VERIFY ("false" to disable).
#
# API Endpoints (if applicable):
# - POST /api/tunnel/request
# - Socket.IO namespace /tunnel (join, ps_open, ps_send, ps_poll)
# ======================================================
from __future__ import annotations
import os
import time
import pytest
import requests
import socketio
HOST = os.environ.get("TUNNEL_TEST_HOST", "").strip()
AGENT_ID = os.environ.get("TUNNEL_TEST_AGENT_ID", "").strip()
BEARER = os.environ.get("TUNNEL_TEST_BEARER", "").strip()
VERIFY_ENV = os.environ.get("TUNNEL_TEST_VERIFY", "").strip().lower()
VERIFY = False if VERIFY_ENV in {"false", "0", "no"} else True
SKIP_MSG = (
"Live tunnel test skipped (set TUNNEL_TEST_HOST, TUNNEL_TEST_AGENT_ID, TUNNEL_TEST_BEARER to run)"
)
def _require_env() -> None:
if not HOST or not AGENT_ID or not BEARER:
pytest.skip(SKIP_MSG)
def _make_session() -> requests.Session:
sess = requests.Session()
sess.verify = VERIFY
sess.headers.update({"Authorization": f"Bearer {BEARER}"})
return sess
def test_reverse_tunnel_powershell_roundtrip() -> None:
_require_env()
sess = _make_session()
# 1) Request a tunnel lease
resp = sess.post(
f"{HOST}/api/tunnel/request",
json={"agent_id": AGENT_ID, "protocol": "ps", "domain": "remote-interactive-shell"},
)
assert resp.status_code == 200, f"lease request failed: {resp.status_code} {resp.text}"
lease = resp.json()
tunnel_id = lease["tunnel_id"]
# 2) Connect to Socket.IO /tunnel namespace
sio = socketio.Client()
sio.connect(
HOST,
namespaces=["/tunnel"],
headers={"Authorization": f"Bearer {BEARER}"},
transports=["websocket"],
wait_timeout=10,
)
# 3) Join tunnel and open PS channel
join_resp = sio.call("join", {"tunnel_id": tunnel_id}, namespace="/tunnel", timeout=10)
assert join_resp.get("status") == "ok", f"join failed: {join_resp}"
open_resp = sio.call("ps_open", {"cols": 120, "rows": 32}, namespace="/tunnel", timeout=10)
assert not open_resp.get("error"), f"ps_open failed: {open_resp}"
# 4) Send a command
send_resp = sio.call("ps_send", {"data": 'Write-Host "Hello World"\r\n'}, namespace="/tunnel", timeout=10)
assert not send_resp.get("error"), f"ps_send failed: {send_resp}"
# 5) Poll for output
output_text = ""
deadline = time.time() + 15
while time.time() < deadline:
poll_resp = sio.call("ps_poll", {}, namespace="/tunnel", timeout=10)
if poll_resp.get("error"):
pytest.fail(f"ps_poll failed: {poll_resp}")
lines = poll_resp.get("output") or []
output_text += "".join(lines)
if "Hello World" in output_text:
break
time.sleep(0.5)
sio.disconnect()
assert "Hello World" in output_text, f"expected command output not found; got: {output_text!r}"

View File

@@ -77,17 +77,12 @@ LOG_ROOT = PROJECT_ROOT / "Engine" / "Logs"
LOG_FILE_PATH = LOG_ROOT / "engine.log" LOG_FILE_PATH = LOG_ROOT / "engine.log"
ERROR_LOG_FILE_PATH = LOG_ROOT / "error.log" ERROR_LOG_FILE_PATH = LOG_ROOT / "error.log"
API_LOG_FILE_PATH = LOG_ROOT / "api.log" API_LOG_FILE_PATH = LOG_ROOT / "api.log"
REVERSE_TUNNEL_LOG_FILE_PATH = LOG_ROOT / "reverse_tunnel.log" VPN_TUNNEL_LOG_FILE_PATH = LOG_ROOT / "reverse_tunnel.log"
DEFAULT_TUNNEL_FIXED_PORT = 8443
DEFAULT_TUNNEL_PORT_RANGE = (30000, 40000)
DEFAULT_TUNNEL_IDLE_TIMEOUT_SECONDS = 3600
DEFAULT_TUNNEL_GRACE_TIMEOUT_SECONDS = 3600
DEFAULT_TUNNEL_HEARTBEAT_INTERVAL_SECONDS = 20
DEFAULT_WIREGUARD_PORT = 30000 DEFAULT_WIREGUARD_PORT = 30000
DEFAULT_WIREGUARD_ENGINE_VIRTUAL_IP = "10.255.0.1/32" DEFAULT_WIREGUARD_ENGINE_VIRTUAL_IP = "10.255.0.1/32"
DEFAULT_WIREGUARD_PEER_NETWORK = "10.255.0.0/24" DEFAULT_WIREGUARD_PEER_NETWORK = "10.255.0.0/24"
DEFAULT_WIREGUARD_ACL_WINDOWS = (3389, 5985, 5986, 5900, 3478) DEFAULT_WIREGUARD_SHELL_PORT = 47001
DEFAULT_WIREGUARD_ACL_WINDOWS = (3389, 5985, 5986, 5900, 3478, DEFAULT_WIREGUARD_SHELL_PORT)
VPN_SERVER_CERT_ROOT = PROJECT_ROOT / "Engine" / "Certificates" / "VPN_Server" VPN_SERVER_CERT_ROOT = PROJECT_ROOT / "Engine" / "Certificates" / "VPN_Server"
@@ -282,18 +277,14 @@ class EngineSettings:
error_log_file: str error_log_file: str
api_log_file: str api_log_file: str
api_groups: Tuple[str, ...] api_groups: Tuple[str, ...]
reverse_tunnel_fixed_port: int vpn_tunnel_log_file: str
reverse_tunnel_port_range: Tuple[int, int]
reverse_tunnel_idle_timeout_seconds: int
reverse_tunnel_grace_timeout_seconds: int
reverse_tunnel_heartbeat_seconds: int
reverse_tunnel_log_file: str
wireguard_port: int wireguard_port: int
wireguard_engine_virtual_ip: str wireguard_engine_virtual_ip: str
wireguard_peer_network: str wireguard_peer_network: str
wireguard_server_private_key_path: str wireguard_server_private_key_path: str
wireguard_server_public_key_path: str wireguard_server_public_key_path: str
wireguard_acl_allowlist_windows: Tuple[int, ...] wireguard_acl_allowlist_windows: Tuple[int, ...]
wireguard_shell_port: int
raw: MutableMapping[str, Any] = field(default_factory=dict) raw: MutableMapping[str, Any] = field(default_factory=dict)
def to_flask_config(self) -> MutableMapping[str, Any]: def to_flask_config(self) -> MutableMapping[str, Any]:
@@ -390,10 +381,14 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
api_log_file = str(runtime_config.get("API_LOG_FILE") or API_LOG_FILE_PATH) api_log_file = str(runtime_config.get("API_LOG_FILE") or API_LOG_FILE_PATH)
_ensure_parent(Path(api_log_file)) _ensure_parent(Path(api_log_file))
reverse_tunnel_log_file = str( vpn_tunnel_log_file = str(
runtime_config.get("REVERSE_TUNNEL_LOG_FILE") or REVERSE_TUNNEL_LOG_FILE_PATH runtime_config.get("VPN_TUNNEL_LOG_FILE")
or runtime_config.get("WIREGUARD_LOG_FILE")
or os.environ.get("BOREALIS_VPN_TUNNEL_LOG_FILE")
or os.environ.get("BOREALIS_WIREGUARD_LOG_FILE")
or VPN_TUNNEL_LOG_FILE_PATH
) )
_ensure_parent(Path(reverse_tunnel_log_file)) _ensure_parent(Path(vpn_tunnel_log_file))
wireguard_port = _parse_int( wireguard_port = _parse_int(
runtime_config.get("WIREGUARD_PORT") or os.environ.get("BOREALIS_WIREGUARD_PORT"), runtime_config.get("WIREGUARD_PORT") or os.environ.get("BOREALIS_WIREGUARD_PORT"),
@@ -416,6 +411,13 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
or os.environ.get("BOREALIS_WIREGUARD_WINDOWS_ALLOWLIST"), or os.environ.get("BOREALIS_WIREGUARD_WINDOWS_ALLOWLIST"),
default=DEFAULT_WIREGUARD_ACL_WINDOWS, default=DEFAULT_WIREGUARD_ACL_WINDOWS,
) )
wireguard_shell_port = _parse_int(
runtime_config.get("WIREGUARD_SHELL_PORT")
or os.environ.get("BOREALIS_WIREGUARD_SHELL_PORT"),
default=DEFAULT_WIREGUARD_SHELL_PORT,
minimum=1,
maximum=65535,
)
wireguard_key_root = Path( wireguard_key_root = Path(
runtime_config.get("WIREGUARD_KEY_ROOT") runtime_config.get("WIREGUARD_KEY_ROOT")
or os.environ.get("BOREALIS_WIREGUARD_KEY_ROOT") or os.environ.get("BOREALIS_WIREGUARD_KEY_ROOT")
@@ -440,35 +442,6 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
"scheduled_jobs", "scheduled_jobs",
) )
tunnel_fixed_port = _parse_int(
runtime_config.get("TUNNEL_FIXED_PORT") or os.environ.get("BOREALIS_TUNNEL_FIXED_PORT"),
default=DEFAULT_TUNNEL_FIXED_PORT,
minimum=1,
maximum=65535,
)
tunnel_port_range = _parse_port_range(
runtime_config.get("TUNNEL_PORT_RANGE") or os.environ.get("BOREALIS_TUNNEL_PORT_RANGE"),
default=DEFAULT_TUNNEL_PORT_RANGE,
)
tunnel_idle_timeout_seconds = _parse_int(
runtime_config.get("TUNNEL_IDLE_TIMEOUT_SECONDS")
or os.environ.get("BOREALIS_TUNNEL_IDLE_TIMEOUT_SECONDS"),
default=DEFAULT_TUNNEL_IDLE_TIMEOUT_SECONDS,
minimum=60,
)
tunnel_grace_timeout_seconds = _parse_int(
runtime_config.get("TUNNEL_GRACE_TIMEOUT_SECONDS")
or os.environ.get("BOREALIS_TUNNEL_GRACE_TIMEOUT_SECONDS"),
default=DEFAULT_TUNNEL_GRACE_TIMEOUT_SECONDS,
minimum=60,
)
tunnel_heartbeat_seconds = _parse_int(
runtime_config.get("TUNNEL_HEARTBEAT_SECONDS")
or os.environ.get("BOREALIS_TUNNEL_HEARTBEAT_SECONDS"),
default=DEFAULT_TUNNEL_HEARTBEAT_INTERVAL_SECONDS,
minimum=5,
)
settings = EngineSettings( settings = EngineSettings(
database_path=database_path, database_path=database_path,
static_folder=static_folder, static_folder=static_folder,
@@ -484,18 +457,14 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
error_log_file=str(error_log_file), error_log_file=str(error_log_file),
api_log_file=str(api_log_file), api_log_file=str(api_log_file),
api_groups=api_groups, api_groups=api_groups,
reverse_tunnel_fixed_port=tunnel_fixed_port, vpn_tunnel_log_file=vpn_tunnel_log_file,
reverse_tunnel_port_range=tunnel_port_range,
reverse_tunnel_idle_timeout_seconds=tunnel_idle_timeout_seconds,
reverse_tunnel_grace_timeout_seconds=tunnel_grace_timeout_seconds,
reverse_tunnel_heartbeat_seconds=tunnel_heartbeat_seconds,
reverse_tunnel_log_file=reverse_tunnel_log_file,
wireguard_port=wireguard_port, wireguard_port=wireguard_port,
wireguard_engine_virtual_ip=wireguard_engine_virtual_ip, wireguard_engine_virtual_ip=wireguard_engine_virtual_ip,
wireguard_peer_network=wireguard_peer_network, wireguard_peer_network=wireguard_peer_network,
wireguard_server_private_key_path=wireguard_server_private_key_path, wireguard_server_private_key_path=wireguard_server_private_key_path,
wireguard_server_public_key_path=wireguard_server_public_key_path, wireguard_server_public_key_path=wireguard_server_public_key_path,
wireguard_acl_allowlist_windows=wireguard_acl_allowlist_windows, wireguard_acl_allowlist_windows=wireguard_acl_allowlist_windows,
wireguard_shell_port=wireguard_shell_port,
raw=runtime_config, raw=runtime_config,
) )
return settings return settings

View File

@@ -2,6 +2,8 @@
# Data\Engine\database_migrations.py # Data\Engine\database_migrations.py
# Description: Provides schema evolution helpers for the Engine sqlite # Description: Provides schema evolution helpers for the Engine sqlite
# database without importing the legacy ``Modules`` package. # database without importing the legacy ``Modules`` package.
#
# API Endpoints (if applicable): None
# ====================================================== # ======================================================
"""Engine database schema migration helpers.""" """Engine database schema migration helpers."""
@@ -24,6 +26,7 @@ def apply_all(conn: sqlite3.Connection) -> None:
_ensure_devices_table(conn) _ensure_devices_table(conn)
_ensure_device_aux_tables(conn) _ensure_device_aux_tables(conn)
_ensure_device_vpn_config_table(conn)
_ensure_refresh_token_table(conn) _ensure_refresh_token_table(conn)
_ensure_install_code_table(conn) _ensure_install_code_table(conn)
_ensure_install_code_persistence_table(conn) _ensure_install_code_persistence_table(conn)
@@ -112,6 +115,20 @@ def _ensure_device_aux_tables(conn: sqlite3.Connection) -> None:
) )
def _ensure_device_vpn_config_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor()
cur.execute(
"""
CREATE TABLE IF NOT EXISTS device_vpn_config (
agent_id TEXT PRIMARY KEY,
allowed_ports TEXT,
updated_at TEXT,
updated_by TEXT
)
"""
)
def _ensure_refresh_token_table(conn: sqlite3.Connection) -> None: def _ensure_refresh_token_table(conn: sqlite3.Connection) -> None:
cur = conn.cursor() cur = conn.cursor()
cur.execute( cur.execute(

View File

@@ -120,18 +120,14 @@ class EngineContext:
config: Mapping[str, Any] config: Mapping[str, Any]
api_groups: Sequence[str] api_groups: Sequence[str]
api_log_path: str api_log_path: str
reverse_tunnel_fixed_port: int vpn_tunnel_log_path: str
reverse_tunnel_port_range: Tuple[int, int]
reverse_tunnel_idle_timeout_seconds: int
reverse_tunnel_grace_timeout_seconds: int
reverse_tunnel_heartbeat_seconds: int
reverse_tunnel_log_path: str
wireguard_port: int wireguard_port: int
wireguard_engine_virtual_ip: str wireguard_engine_virtual_ip: str
wireguard_peer_network: str wireguard_peer_network: str
wireguard_server_private_key_path: str wireguard_server_private_key_path: str
wireguard_server_public_key_path: str wireguard_server_public_key_path: str
wireguard_acl_allowlist_windows: Tuple[int, ...] wireguard_acl_allowlist_windows: Tuple[int, ...]
wireguard_shell_port: int
wireguard_server_manager: Optional[Any] = None wireguard_server_manager: Optional[Any] = None
assembly_cache: Optional[Any] = None assembly_cache: Optional[Any] = None
@@ -151,18 +147,14 @@ def _build_engine_context(settings: EngineSettings, logger: logging.Logger) -> E
config=settings.as_dict(), config=settings.as_dict(),
api_groups=settings.api_groups, api_groups=settings.api_groups,
api_log_path=settings.api_log_file, api_log_path=settings.api_log_file,
reverse_tunnel_fixed_port=settings.reverse_tunnel_fixed_port, vpn_tunnel_log_path=settings.vpn_tunnel_log_file,
reverse_tunnel_port_range=settings.reverse_tunnel_port_range,
reverse_tunnel_idle_timeout_seconds=settings.reverse_tunnel_idle_timeout_seconds,
reverse_tunnel_grace_timeout_seconds=settings.reverse_tunnel_grace_timeout_seconds,
reverse_tunnel_heartbeat_seconds=settings.reverse_tunnel_heartbeat_seconds,
reverse_tunnel_log_path=settings.reverse_tunnel_log_file,
wireguard_port=settings.wireguard_port, wireguard_port=settings.wireguard_port,
wireguard_engine_virtual_ip=settings.wireguard_engine_virtual_ip, wireguard_engine_virtual_ip=settings.wireguard_engine_virtual_ip,
wireguard_peer_network=settings.wireguard_peer_network, wireguard_peer_network=settings.wireguard_peer_network,
wireguard_server_private_key_path=settings.wireguard_server_private_key_path, wireguard_server_private_key_path=settings.wireguard_server_private_key_path,
wireguard_server_public_key_path=settings.wireguard_server_public_key_path, wireguard_server_public_key_path=settings.wireguard_server_public_key_path,
wireguard_acl_allowlist_windows=settings.wireguard_acl_allowlist_windows, wireguard_acl_allowlist_windows=settings.wireguard_acl_allowlist_windows,
wireguard_shell_port=settings.wireguard_shell_port,
assembly_cache=None, assembly_cache=None,
) )
@@ -249,7 +241,7 @@ def create_app(config: Optional[Mapping[str, Any]] = None) -> Tuple[Flask, Socke
private_key_path=Path(context.wireguard_server_private_key_path), private_key_path=Path(context.wireguard_server_private_key_path),
public_key_path=Path(context.wireguard_server_public_key_path), public_key_path=Path(context.wireguard_server_public_key_path),
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows), acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
log_path=Path(context.reverse_tunnel_log_path), log_path=Path(context.vpn_tunnel_log_path),
) )
context.wireguard_server_manager = WireGuardServerManager(wg_config) context.wireguard_server_manager = WireGuardServerManager(wg_config)
except Exception: except Exception:
@@ -325,7 +317,7 @@ def register_engine_api(app: Flask, *, config: Optional[Mapping[str, Any]] = Non
private_key_path=Path(context.wireguard_server_private_key_path), private_key_path=Path(context.wireguard_server_private_key_path),
public_key_path=Path(context.wireguard_server_public_key_path), public_key_path=Path(context.wireguard_server_public_key_path),
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows), acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
log_path=Path(context.reverse_tunnel_log_path), log_path=Path(context.vpn_tunnel_log_path),
) )
context.wireguard_server_manager = WireGuardServerManager(wg_config) context.wireguard_server_manager = WireGuardServerManager(wg_config)
except Exception: except Exception:

View File

@@ -9,6 +9,8 @@
# - GET /api/devices/<guid> (Token Authenticated) - Retrieves a single device record by GUID, including summary fields. # - GET /api/devices/<guid> (Token Authenticated) - Retrieves a single device record by GUID, including summary fields.
# - GET /api/device/details/<hostname> (Token Authenticated) - Returns full device details keyed by hostname. # - GET /api/device/details/<hostname> (Token Authenticated) - Returns full device details keyed by hostname.
# - POST /api/device/description/<hostname> (Token Authenticated) - Updates the human-readable description for a device. # - POST /api/device/description/<hostname> (Token Authenticated) - Updates the human-readable description for a device.
# - GET /api/device/vpn_config/<agent_id> (Token Authenticated) - Returns per-device VPN allowed port settings.
# - PUT /api/device/vpn_config/<agent_id> (Token Authenticated) - Updates per-device VPN allowed port settings.
# - GET /api/device_list_views (Token Authenticated) - Lists saved device table view definitions. # - GET /api/device_list_views (Token Authenticated) - Lists saved device table view definitions.
# - GET /api/device_list_views/<int:view_id> (Token Authenticated) - Retrieves a specific saved device table view definition. # - GET /api/device_list_views/<int:view_id> (Token Authenticated) - Retrieves a specific saved device table view definition.
# - POST /api/device_list_views (Token Authenticated) - Creates a custom device list view for the signed-in operator. # - POST /api/device_list_views (Token Authenticated) - Creates a custom device list view for the signed-in operator.
@@ -426,6 +428,131 @@ class DeviceManagementService:
return None return None
return None return None
def _parse_ports(self, raw: Any) -> List[int]:
ports: List[int] = []
if isinstance(raw, str):
parts = [part.strip() for part in raw.split(",") if part.strip()]
elif isinstance(raw, list):
parts = raw
else:
parts = []
for part in parts:
try:
value = int(part)
except Exception:
continue
if 1 <= value <= 65535:
ports.append(value)
return list(dict.fromkeys(ports))
def _default_vpn_ports(self, os_name: Optional[str]) -> List[int]:
ports = list(self.adapters.context.wireguard_acl_allowlist_windows or [])
os_text = (os_name or "").strip().lower()
if os_text and "windows" not in os_text:
baseline = {5900, 3478}
filtered = [p for p in ports if p in baseline]
return filtered or ports
return ports
def get_vpn_config(self, agent_id: str) -> Tuple[Dict[str, Any], int]:
agent_id = (agent_id or "").strip()
if not agent_id:
return {"error": "agent_id_required"}, 400
default_ports: List[int] = []
shell_port = int(self.adapters.context.wireguard_shell_port)
try:
conn = self._db_conn()
cur = conn.cursor()
os_name = ""
try:
cur.execute(
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
(agent_id,),
)
row = cur.fetchone()
if row and row[0]:
os_name = str(row[0])
except Exception:
os_name = ""
default_ports = self._default_vpn_ports(os_name)
cur.execute(
"SELECT allowed_ports, updated_at, updated_by FROM device_vpn_config WHERE agent_id=?",
(agent_id,),
)
row = cur.fetchone()
if not row:
return {
"agent_id": agent_id,
"allowed_ports": default_ports,
"default_ports": default_ports,
"shell_port": shell_port,
"source": "default",
}, 200
raw_ports = row[0] or ""
ports = []
try:
ports = json.loads(raw_ports) if raw_ports else []
except Exception:
ports = []
return {
"agent_id": agent_id,
"allowed_ports": ports or default_ports,
"default_ports": default_ports,
"shell_port": shell_port,
"updated_at": row[1],
"updated_by": row[2],
"source": "custom" if ports else "default",
}, 200
except Exception as exc:
self.logger.debug("Failed to load vpn config", exc_info=True)
return {"error": "internal_error"}, 500
finally:
try:
conn.close()
except Exception:
pass
def set_vpn_config(self, agent_id: str, payload: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
agent_id = (agent_id or "").strip()
if not agent_id:
return {"error": "agent_id_required"}, 400
ports = self._parse_ports(payload.get("allowed_ports"))
if not ports:
return {"error": "allowed_ports_required"}, 400
user = self._current_user() or {}
updated_by = user.get("username") or ""
updated_at = datetime.now(timezone.utc).isoformat()
try:
conn = self._db_conn()
cur = conn.cursor()
cur.execute(
"""
INSERT INTO device_vpn_config(agent_id, allowed_ports, updated_at, updated_by)
VALUES (?, ?, ?, ?)
ON CONFLICT(agent_id) DO UPDATE SET
allowed_ports=excluded.allowed_ports,
updated_at=excluded.updated_at,
updated_by=excluded.updated_by
""",
(agent_id, json.dumps(ports), updated_at, updated_by),
)
conn.commit()
return {
"agent_id": agent_id,
"allowed_ports": ports,
"updated_at": updated_at,
"updated_by": updated_by,
"source": "custom",
}, 200
except Exception:
self.logger.debug("Failed to save vpn config", exc_info=True)
return {"error": "internal_error"}, 500
finally:
try:
conn.close()
except Exception:
pass
def _require_login(self) -> Optional[Tuple[Dict[str, Any], int]]: def _require_login(self) -> Optional[Tuple[Dict[str, Any], int]]:
if not self._current_user(): if not self._current_user():
return {"error": "unauthorized"}, 401 return {"error": "unauthorized"}, 401
@@ -1793,6 +1920,19 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None:
payload, status = service.set_device_description(hostname, description) payload, status = service.set_device_description(hostname, description)
return jsonify(payload), status return jsonify(payload), status
@blueprint.route("/api/device/vpn_config/<agent_id>", methods=["GET", "PUT"])
def _vpn_config(agent_id: str):
requirement = service._require_login()
if requirement:
payload, status = requirement
return jsonify(payload), status
if request.method == "GET":
payload, status = service.get_vpn_config(agent_id)
else:
body = request.get_json(silent=True) or {}
payload, status = service.set_vpn_config(agent_id, body)
return jsonify(payload), status
@blueprint.route("/api/device_list_views", methods=["GET"]) @blueprint.route("/api/device_list_views", methods=["GET"])
def _list_views(): def _list_views():
requirement = service._require_login() requirement = service._require_login()

View File

@@ -1,12 +1,14 @@
# ====================================================== # ======================================================
# Data\Engine\services\API\devices\tunnel.py # Data\Engine\services\API\devices\tunnel.py
# Description: Negotiation endpoint for reverse tunnel leases (operator-initiated; dormant until tunnel listener is wired). # Description: WireGuard VPN tunnel API (connect/status/disconnect).
# #
# API Endpoints (if applicable): # API Endpoints (if applicable):
# - POST /api/tunnel/request (Token Authenticated) - Allocates a reverse tunnel lease for the requested agent/protocol. # - POST /api/tunnel/connect (Token Authenticated) - Issues VPN session material for an agent.
# - GET /api/tunnel/status (Token Authenticated) - Returns VPN status for an agent.
# - DELETE /api/tunnel/disconnect (Token Authenticated) - Tears down VPN session for an agent.
# ====================================================== # ======================================================
"""Reverse tunnel negotiation API (Engine side).""" """WireGuard VPN tunnel API (Engine side)."""
from __future__ import annotations from __future__ import annotations
import os import os
@@ -15,15 +17,13 @@ from typing import Any, Dict, Optional, Tuple
from flask import Blueprint, jsonify, request, session from flask import Blueprint, jsonify, request, session
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
from ...WebSocket.Agent.reverse_tunnel_orchestrator import ReverseTunnelService from ...VPN import VpnTunnelService
if False: # pragma: no cover - import cycle hint for type checkers if False: # pragma: no cover - import cycle hint for type checkers
from .. import EngineServiceAdapters from .. import EngineServiceAdapters
def _current_user(app) -> Optional[Dict[str, str]]: def _current_user(app) -> Optional[Dict[str, str]]:
"""Resolve operator identity from session or signed token."""
username = session.get("username") username = session.get("username")
role = session.get("role") or "User" role = session.get("role") or "User"
if username: if username:
@@ -58,18 +58,22 @@ def _require_login(app) -> Optional[Tuple[Dict[str, Any], int]]:
return None return None
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> ReverseTunnelService: def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService:
service = getattr(adapters.context, "reverse_tunnel_service", None) or getattr(adapters, "_reverse_tunnel_service", None) service = getattr(adapters.context, "vpn_tunnel_service", None) or getattr(adapters, "_vpn_tunnel_service", None)
if service is None: if service is None:
service = ReverseTunnelService( manager = getattr(adapters.context, "wireguard_server_manager", None)
adapters.context, if manager is None:
signer=getattr(adapters, "script_signer", None), raise RuntimeError("wireguard_manager_unavailable")
service = VpnTunnelService(
context=adapters.context,
wireguard_manager=manager,
db_conn_factory=adapters.db_conn_factory, db_conn_factory=adapters.db_conn_factory,
socketio=getattr(adapters.context, "socketio", None), socketio=getattr(adapters.context, "socketio", None),
service_log=adapters.service_log,
signer=getattr(adapters, "script_signer", None),
) )
service.start() setattr(adapters, "_vpn_tunnel_service", service)
setattr(adapters, "_reverse_tunnel_service", service) setattr(adapters.context, "vpn_tunnel_service", service)
setattr(adapters.context, "reverse_tunnel_service", service)
return service return service
@@ -83,14 +87,11 @@ def _normalize_text(value: Any) -> str:
def register_tunnel(app, adapters: "EngineServiceAdapters") -> None: def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
"""Register reverse tunnel negotiation endpoints.""" blueprint = Blueprint("vpn_tunnel", __name__)
logger = adapters.context.logger.getChild("vpn_tunnel.api")
blueprint = Blueprint("reverse_tunnel", __name__) @blueprint.route("/api/tunnel/connect", methods=["POST"])
service_log = adapters.service_log def connect_tunnel():
logger = adapters.context.logger.getChild("tunnel.api")
@blueprint.route("/api/tunnel/request", methods=["POST"])
def request_tunnel():
requirement = _require_login(app) requirement = _require_login(app)
if requirement: if requirement:
payload, status = requirement payload, status = requirement
@@ -101,69 +102,67 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
body = request.get_json(silent=True) or {} body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id")) agent_id = _normalize_text(body.get("agent_id"))
protocol = _normalize_text(body.get("protocol") or "ps").lower() or "ps"
domain = _normalize_text(body.get("domain") or protocol).lower() or protocol
if protocol == "ps" and domain == "ps":
domain = "remote-interactive-shell"
if not agent_id: if not agent_id:
return jsonify({"error": "agent_id_required"}), 400 return jsonify({"error": "agent_id_required"}), 400
tunnel_service = _get_tunnel_service(adapters)
try: try:
lease = tunnel_service.request_lease( tunnel_service = _get_tunnel_service(adapters)
agent_id=agent_id, payload = tunnel_service.connect(agent_id=agent_id, operator_id=operator_id)
protocol=protocol,
domain=domain,
operator_id=operator_id,
)
except RuntimeError as exc: except RuntimeError as exc:
message = str(exc) logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc)
if message.startswith("domain_limit:"): return jsonify({"error": "connect_failed"}), 500
domain_name = message.split(":", 1)[-1] if ":" in message else domain
return jsonify({"error": "domain_limit", "domain": domain_name}), 409
if message == "port_pool_exhausted":
return jsonify({"error": "port_pool_exhausted"}), 503
logger.warning("tunnel lease request failed for agent_id=%s: %s", agent_id, message)
return jsonify({"error": "lease_allocation_failed"}), 500
summary = tunnel_service.lease_summary(lease) return jsonify(payload), 200
summary["fixed_port"] = tunnel_service.fixed_port
summary["heartbeat_seconds"] = tunnel_service.heartbeat_seconds
service_log( @blueprint.route("/api/tunnel/status", methods=["GET"])
"reverse_tunnel", def tunnel_status():
f"lease created tunnel_id={lease.tunnel_id} agent_id={lease.agent_id} domain={lease.domain} protocol={lease.protocol}",
)
return jsonify(summary), 200
@blueprint.route("/api/tunnel/<tunnel_id>", methods=["DELETE"])
def stop_tunnel(tunnel_id: str):
requirement = _require_login(app) requirement = _require_login(app)
if requirement: if requirement:
payload, status = requirement payload, status = requirement
return jsonify(payload), status return jsonify(payload), status
tunnel_id_norm = _normalize_text(tunnel_id) agent_id = _normalize_text(request.args.get("agent_id") or "")
if not tunnel_id_norm: if not agent_id:
return jsonify({"error": "tunnel_id_required"}), 400 return jsonify({"error": "agent_id_required"}), 400
tunnel_service = _get_tunnel_service(adapters)
payload = tunnel_service.status(agent_id)
if not payload:
return jsonify({"status": "down", "agent_id": agent_id}), 200
payload["status"] = "up"
bump = _normalize_text(request.args.get("bump") or "")
if bump:
tunnel_service.bump_activity(agent_id)
return jsonify(payload), 200
@blueprint.route("/api/tunnel/connect/status", methods=["GET"])
def tunnel_connect_status():
return tunnel_status()
@blueprint.route("/api/tunnel/disconnect", methods=["DELETE"])
def disconnect_tunnel():
requirement = _require_login(app)
if requirement:
payload, status = requirement
return jsonify(payload), status
body = request.get_json(silent=True) or {} body = request.get_json(silent=True) or {}
agent_id = _normalize_text(body.get("agent_id"))
tunnel_id = _normalize_text(body.get("tunnel_id"))
reason = _normalize_text(body.get("reason") or "operator_stop") reason = _normalize_text(body.get("reason") or "operator_stop")
tunnel_service = _get_tunnel_service(adapters) tunnel_service = _get_tunnel_service(adapters)
stopped = False stopped = False
try: if tunnel_id:
stopped = tunnel_service.stop_tunnel(tunnel_id_norm, reason=reason) stopped = tunnel_service.disconnect_by_tunnel(tunnel_id, reason=reason)
except Exception as exc: # pragma: no cover - defensive guard elif agent_id:
logger.debug("stop_tunnel failed tunnel_id=%s: %s", tunnel_id_norm, exc, exc_info=True) stopped = tunnel_service.disconnect(agent_id, reason=reason)
else:
return jsonify({"error": "agent_id_required"}), 400
if not stopped: if not stopped:
return jsonify({"error": "not_found"}), 404 return jsonify({"error": "not_found"}), 404
service_log( return jsonify({"status": "stopped", "reason": reason}), 200
"reverse_tunnel",
f"lease stopped tunnel_id={tunnel_id_norm} reason={reason or '-'}",
)
return jsonify({"status": "stopped", "tunnel_id": tunnel_id_norm}), 200
app.register_blueprint(blueprint) app.register_blueprint(blueprint)

View File

@@ -8,4 +8,4 @@
"""VPN service helpers for the Engine runtime.""" """VPN service helpers for the Engine runtime."""
from .wireguard_server import WireGuardServerConfig, WireGuardServerManager # noqa: F401 from .wireguard_server import WireGuardServerConfig, WireGuardServerManager # noqa: F401
from .vpn_tunnel_service import VpnTunnelService # noqa: F401

View File

@@ -0,0 +1,473 @@
# ======================================================
# Data\Engine\services\VPN\vpn_tunnel_service.py
# Description: WireGuard tunnel orchestration (single tunnel per agent, token issuance, idle handling).
#
# API Endpoints (if applicable): None
# ======================================================
"""WireGuard tunnel orchestration helpers for the Engine runtime."""
from __future__ import annotations
import base64
import ipaddress
import json
import threading
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from .wireguard_server import WireGuardServerManager
@dataclass
class VpnSession:
tunnel_id: str
agent_id: str
virtual_ip: str
token: Dict[str, Any]
client_public_key: str
client_private_key: str
allowed_ports: Tuple[int, ...]
created_at: float
expires_at: float
last_activity: float
operator_ids: set[str] = field(default_factory=set)
firewall_rules: List[str] = field(default_factory=list)
activity_id: Optional[int] = None
hostname: Optional[str] = None
class VpnTunnelService:
def __init__(
self,
*,
context: Any,
wireguard_manager: WireGuardServerManager,
db_conn_factory,
socketio,
service_log,
signer: Optional[Any] = None,
idle_seconds: int = 900,
) -> None:
self.context = context
self.wg = wireguard_manager
self.db_conn_factory = db_conn_factory
self.socketio = socketio
self.service_log = service_log
self.signer = signer
self.logger = context.logger.getChild("vpn_tunnel")
self.activity_logger = self.wg.logger.getChild("device_activity")
self.idle_seconds = max(60, int(idle_seconds))
self._lock = threading.Lock()
self._sessions_by_agent: Dict[str, VpnSession] = {}
self._sessions_by_tunnel: Dict[str, VpnSession] = {}
self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
self._idle_thread.start()
def _idle_loop(self) -> None:
while True:
time.sleep(10)
now = time.time()
expired: List[VpnSession] = []
with self._lock:
for session in list(self._sessions_by_agent.values()):
if session.last_activity + self.idle_seconds <= now:
expired.append(session)
for session in expired:
self.disconnect(session.agent_id, reason="idle_timeout")
def _allocate_virtual_ip(self, agent_id: str) -> str:
existing = self._sessions_by_agent.get(agent_id)
if existing:
return existing.virtual_ip
used = {s.virtual_ip for s in self._sessions_by_agent.values()}
for host in self._peer_network.hosts():
if host == self._engine_ip.ip:
continue
candidate = f"{host}/32"
if candidate not in used:
return candidate
raise RuntimeError("vpn_ip_pool_exhausted")
def _load_allowed_ports(self, agent_id: str) -> Tuple[int, ...]:
default = tuple(self.context.wireguard_acl_allowlist_windows or ())
try:
conn = self.db_conn_factory()
cur = conn.cursor()
try:
cur.execute(
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
(agent_id,),
)
row = cur.fetchone()
os_name = str(row[0]).lower() if row and row[0] else ""
except Exception:
os_name = ""
if os_name and "windows" not in os_name:
baseline = {5900, 3478}
filtered = [p for p in default if p in baseline]
if filtered:
default = tuple(filtered)
cur.execute(
"SELECT allowed_ports FROM device_vpn_config WHERE agent_id=?",
(agent_id,),
)
row = cur.fetchone()
if not row:
return default
raw = row[0] or ""
ports = json.loads(raw) if raw else []
ports = [int(p) for p in ports if isinstance(p, (int, float, str))]
ports = [p for p in ports if 1 <= p <= 65535]
return tuple(dict.fromkeys(ports)) or default
except Exception:
return default
finally:
try:
conn.close()
except Exception:
pass
def _generate_client_keys(self) -> Tuple[str, str]:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x25519
key = x25519.X25519PrivateKey.generate()
priv = base64.b64encode(
key.private_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PrivateFormat.Raw,
encryption_algorithm=serialization.NoEncryption(),
)
).decode("ascii").strip()
pub = base64.b64encode(
key.public_key().public_bytes(
encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw,
)
).decode("ascii").strip()
return priv, pub
def _issue_token(self, agent_id: str, tunnel_id: str, expires_at: float) -> Dict[str, Any]:
payload = {
"agent_id": agent_id,
"tunnel_id": tunnel_id,
"port": self.context.wireguard_port,
"expires_at": expires_at,
"issued_at": time.time(),
}
if not self.signer:
return dict(payload)
token = dict(payload)
try:
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
signature = self.signer.sign(payload_bytes)
token["signature"] = base64.b64encode(signature).decode("ascii")
if hasattr(self.signer, "public_base64_spki"):
token["signing_key"] = self.signer.public_base64_spki()
token["sig_alg"] = "ed25519"
except Exception:
self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
return token
def _refresh_listener(self) -> None:
peers: List[Mapping[str, object]] = []
for session in self._sessions_by_agent.values():
peer = self.wg.build_peer_profile(
session.agent_id,
session.virtual_ip,
allowed_ports=session.allowed_ports,
)
peer = dict(peer)
peer["public_key"] = session.client_public_key
peers.append(peer)
if not peers:
self.wg.stop_listener()
return
self.wg.start_listener(peers)
def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]:
now = time.time()
with self._lock:
existing = self._sessions_by_agent.get(agent_id)
if existing:
if operator_id:
existing.operator_ids.add(operator_id)
existing.last_activity = now
return self._session_payload(existing)
tunnel_id = uuid.uuid4().hex
virtual_ip = self._allocate_virtual_ip(agent_id)
allowed_ports = self._load_allowed_ports(agent_id)
client_private, client_public = self._generate_client_keys()
token = self._issue_token(agent_id, tunnel_id, now + 300)
self.wg.require_orchestration_token(token)
session = VpnSession(
tunnel_id=tunnel_id,
agent_id=agent_id,
virtual_ip=virtual_ip,
token=token,
client_public_key=client_public,
client_private_key=client_private,
allowed_ports=allowed_ports,
created_at=now,
expires_at=now + 300,
last_activity=now,
)
if operator_id:
session.operator_ids.add(operator_id)
self._sessions_by_agent[agent_id] = session
self._sessions_by_tunnel[tunnel_id] = session
try:
self._refresh_listener()
peer = self.wg.build_peer_profile(
agent_id,
virtual_ip,
allowed_ports=allowed_ports,
)
rule_names = self.wg.apply_firewall_rules(peer)
session.firewall_rules = rule_names
except Exception:
with self._lock:
self._sessions_by_agent.pop(agent_id, None)
self._sessions_by_tunnel.pop(tunnel_id, None)
try:
self._refresh_listener()
except Exception:
self.logger.debug("Failed to refresh WireGuard listener after connect rollback.", exc_info=True)
raise
payload = self._session_payload(session)
self._emit_start(payload)
self._log_device_activity(session, event="start")
return payload
def status(self, agent_id: str) -> Optional[Mapping[str, Any]]:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
if not session:
return None
return self._session_payload(session, include_token=False)
def bump_activity(self, agent_id: str) -> None:
with self._lock:
session = self._sessions_by_agent.get(agent_id)
if not session:
return
session.last_activity = time.time()
try:
if self.socketio:
self.socketio.emit("vpn_tunnel_activity", {"agent_id": agent_id}, namespace="/")
except Exception:
self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True)
def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool:
with self._lock:
session = self._sessions_by_agent.pop(agent_id, None)
if not session:
return False
self._sessions_by_tunnel.pop(session.tunnel_id, None)
try:
self.wg.remove_firewall_rules(session.firewall_rules)
except Exception:
self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True)
self._refresh_listener()
self._emit_stop(session, reason)
self._log_device_activity(session, event="stop", reason=reason)
return True
def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool:
with self._lock:
session = self._sessions_by_tunnel.get(tunnel_id)
if not session:
return False
return self.disconnect(session.agent_id, reason=reason)
def _emit_start(self, payload: Mapping[str, Any]) -> None:
if not self.socketio:
return
try:
self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
except Exception:
self.logger.debug("vpn_tunnel_start emit failed", exc_info=True)
def _emit_stop(self, session: VpnSession, reason: str) -> None:
if not self.socketio:
return
try:
self.socketio.emit(
"vpn_tunnel_stop",
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
namespace="/",
)
except Exception:
self.logger.debug("vpn_tunnel_stop emit failed", exc_info=True)
def _log_device_activity(self, session: VpnSession, *, event: str, reason: Optional[str] = None) -> None:
if self.db_conn_factory is None:
self.activity_logger.info(
"device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s",
event,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
reason or "-",
)
return
conn = None
try:
conn = self.db_conn_factory()
cur = conn.cursor()
hostname = session.hostname
if not hostname:
try:
cur.execute(
"SELECT hostname FROM devices WHERE agent_id = ? ORDER BY last_seen DESC LIMIT 1",
(session.agent_id,),
)
row = cur.fetchone()
if row and row[0]:
hostname = str(row[0]).strip()
session.hostname = hostname
except Exception:
hostname = None
if not hostname:
self.activity_logger.info(
"device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s hostname=unknown",
event,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
reason or "-",
)
return
now_ts = int(time.time())
script_name = "Reverse VPN Tunnel (WireGuard)"
if event == "start":
cur.execute(
"""
INSERT INTO activity_history(hostname, script_path, script_name, script_type, ran_at, status, stdout, stderr)
VALUES(?,?,?,?,?,?,?,?)
""",
(
hostname,
session.tunnel_id,
script_name,
"vpn_tunnel",
now_ts,
"Running",
"",
"",
),
)
session.activity_id = cur.lastrowid
conn.commit()
if self.socketio:
try:
self.socketio.emit(
"device_activity_changed",
{
"hostname": hostname,
"activity_id": session.activity_id,
"change": "created",
"source": "vpn_tunnel",
},
)
except Exception:
pass
self.activity_logger.info(
"device_activity_start hostname=%s agent_id=%s tunnel_id=%s operator=%s activity_id=%s",
hostname,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
session.activity_id or "-",
)
return
if session.activity_id:
status = "Completed" if event == "stop" else "Closed"
cur.execute(
"""
UPDATE activity_history
SET status=?,
stderr=COALESCE(stderr, '') || ?
WHERE id=?
""",
(
status,
f"\nreason: {reason}" if reason else "",
session.activity_id,
),
)
conn.commit()
if self.socketio:
try:
self.socketio.emit(
"device_activity_changed",
{
"hostname": hostname,
"activity_id": session.activity_id,
"change": "updated",
"source": "vpn_tunnel",
},
)
except Exception:
pass
self.activity_logger.info(
"device_activity event=%s hostname=%s agent_id=%s tunnel_id=%s operator=%s reason=%s activity_id=%s",
event,
hostname,
session.agent_id,
session.tunnel_id,
",".join(sorted(filter(None, session.operator_ids))) or "-",
reason or "-",
session.activity_id or "-",
)
except Exception:
self.activity_logger.debug(
"device_activity logging failed for tunnel_id=%s",
session.tunnel_id,
exc_info=True,
)
finally:
if conn is not None:
try:
conn.close()
except Exception:
pass
def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
payload: Dict[str, Any] = {
"tunnel_id": session.tunnel_id,
"agent_id": session.agent_id,
"virtual_ip": session.virtual_ip,
"engine_virtual_ip": str(self._engine_ip.ip),
"allowed_ips": f"{self._engine_ip.ip}/32",
"endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}",
"server_public_key": self.wg.server_public_key,
"client_public_key": session.client_public_key,
"client_private_key": session.client_private_key,
"idle_seconds": self.idle_seconds,
"allowed_ports": list(session.allowed_ports),
"connected_operators": len([o for o in session.operator_ids if o]),
}
if include_token:
payload["token"] = session.token
return payload

View File

@@ -70,7 +70,7 @@ class WireGuardServerManager:
self.logger = _build_logger(config.log_path) self.logger = _build_logger(config.log_path)
self._ensure_cert_dir() self._ensure_cert_dir()
self.server_private_key, self.server_public_key = self._ensure_server_keys() self.server_private_key, self.server_public_key = self._ensure_server_keys()
self._service_name = "BorealisWireGuard" self._service_name = "borealis-wg"
self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine" self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine"
def _ensure_cert_dir(self) -> None: def _ensure_cert_dir(self) -> None:
@@ -157,7 +157,7 @@ class WireGuardServerManager:
if not token: if not token:
raise ValueError("Missing orchestration token for WireGuard peer") raise ValueError("Missing orchestration token for WireGuard peer")
required_fields = ("agent_id", "tunnel_id", "expires_at") required_fields = ("agent_id", "tunnel_id", "expires_at", "port")
missing = [field for field in required_fields if field not in token or token[field] in (None, "")] missing = [field for field in required_fields if field not in token or token[field] in (None, "")]
if missing: if missing:
raise ValueError(f"Invalid orchestration token; missing {', '.join(missing)}") raise ValueError(f"Invalid orchestration token; missing {', '.join(missing)}")
@@ -167,6 +167,13 @@ class WireGuardServerManager:
except Exception: except Exception:
raise ValueError("Invalid orchestration token expiry") raise ValueError("Invalid orchestration token expiry")
try:
port = int(token["port"])
except Exception:
raise ValueError("Invalid orchestration token port")
if port != int(self.config.port):
raise ValueError("Orchestration token port mismatch")
now = time.time() now = time.time()
if expires_at <= now: if expires_at <= now:
raise ValueError("Orchestration token expired") raise ValueError("Orchestration token expired")
@@ -253,12 +260,14 @@ class WireGuardServerManager:
"host_only": True, "host_only": True,
} }
def apply_firewall_rules(self, peer: Mapping[str, object]) -> None: def apply_firewall_rules(self, peer: Mapping[str, object]) -> List[str]:
"""Apply outbound firewall allow rules for the agent's virtual IP/ports (Windows netsh).""" """Apply outbound firewall allow rules for the agent's virtual IP/ports (Windows netsh)."""
rules = self.build_firewall_rules(peer) rules = self.build_firewall_rules(peer)
rule_names: List[str] = []
for idx, rule in enumerate(rules): for idx, rule in enumerate(rules):
name = f"Borealis-WG-Agent-{peer.get('agent_id','')}-{idx}" name = f"Borealis-WG-Agent-{peer.get('agent_id','')}-{idx}"
protocol = str(rule.get("protocol") or "TCP").upper()
args = [ args = [
"netsh", "netsh",
"advfirewall", "advfirewall",
@@ -269,7 +278,7 @@ class WireGuardServerManager:
"dir=out", "dir=out",
"action=allow", "action=allow",
f"remoteip={rule.get('remote_address','')}", f"remoteip={rule.get('remote_address','')}",
f"protocol=TCP", f"protocol={protocol}",
f"localport={rule.get('local_port','')}", f"localport={rule.get('local_port','')}",
] ]
code, out, err = self._run_command(args) code, out, err = self._run_command(args)
@@ -277,6 +286,19 @@ class WireGuardServerManager:
self.logger.warning("Failed to apply firewall rule %s code=%s err=%s", name, code, err) self.logger.warning("Failed to apply firewall rule %s code=%s err=%s", name, code, err)
else: else:
self.logger.info("Applied firewall rule %s", name) self.logger.info("Applied firewall rule %s", name)
rule_names.append(name)
return rule_names
def remove_firewall_rules(self, rule_names: Sequence[str]) -> None:
for name in rule_names:
if not name:
continue
args = ["netsh", "advfirewall", "firewall", "delete", "rule", f"name={name}"]
code, out, err = self._run_command(args)
if code != 0:
self.logger.warning("Failed to remove firewall rule %s code=%s err=%s", name, code, err)
else:
self.logger.info("Removed firewall rule %s", name)
def start_listener(self, peers: Sequence[Mapping[str, object]]) -> None: def start_listener(self, peers: Sequence[Mapping[str, object]]) -> None:
"""Render a temporary WireGuard config and start the service.""" """Render a temporary WireGuard config and start the service."""
@@ -291,6 +313,9 @@ class WireGuardServerManager:
config_path.write_text(rendered, encoding="utf-8") config_path.write_text(rendered, encoding="utf-8")
self.logger.info("Rendered WireGuard config to %s", config_path) self.logger.info("Rendered WireGuard config to %s", config_path)
# Ensure old service is removed before re-installing.
self.stop_listener()
args = ["wireguard.exe", "/installtunnelservice", str(config_path)] args = ["wireguard.exe", "/installtunnelservice", str(config_path)]
code, out, err = self._run_command(args) code, out, err = self._run_command(args)
if code != 0: if code != 0:
@@ -301,7 +326,7 @@ class WireGuardServerManager:
def stop_listener(self) -> None: def stop_listener(self) -> None:
"""Stop and remove the WireGuard tunnel service.""" """Stop and remove the WireGuard tunnel service."""
args = ["wireguard.exe", "/uninstalltunnelservice", "borealis-wg"] args = ["wireguard.exe", "/uninstalltunnelservice", self._service_name]
code, out, err = self._run_command(args) code, out, err = self._run_command(args)
if code != 0: if code != 0:
self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err) self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)
@@ -323,15 +348,17 @@ class WireGuardServerManager:
port_list = [] port_list = []
for port in port_list: for port in port_list:
rules.append( for protocol in ("TCP", "UDP"):
{ rules.append(
"direction": "outbound", {
"remote_address": ip, "direction": "outbound",
"local_port": port, "remote_address": ip,
"action": "allow", "local_port": port,
"description": f"WireGuard engine->agent allow port {port}", "protocol": protocol,
} "action": "allow",
) "description": f"WireGuard engine->agent allow port {port}/{protocol}",
}
)
self.logger.info( self.logger.info(
"Prepared firewall rule plan for agent=%s rules=%s", "Prepared firewall rule plan for agent=%s rules=%s",

View File

@@ -1,3 +0,0 @@
"""Namespace package for reverse tunnel domain handlers (Engine side)."""
__all__ = ["remote_interactive_shell", "remote_management", "remote_video"]

View File

@@ -1,78 +0,0 @@
"""Placeholder Bash channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class BashChannelServer:
"""Stub Bash handler until the agent-side channel is implemented."""
protocol_name = "bash"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"bash.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
# No-op placeholder; output collection for future Bash support.
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"bash channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
# Placeholder: no agent-side Bash yet.
self.logger.info("bash placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
# Placeholder: not implemented.
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["BashChannelServer"]

View File

@@ -1,139 +0,0 @@
"""Engine-side PowerShell tunnel channel helper (remote interactive shell domain)."""
from __future__ import annotations
import json
from collections import deque
from typing import Any, Deque, Dict, List, Optional
# Mirror framing constants to avoid circular imports.
MSG_CHANNEL_OPEN = 0x03
MSG_CHANNEL_ACK = 0x04
MSG_DATA = 0x05
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
CLOSE_OK = 0
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class PowershellChannelServer:
"""Coordinate PowerShell channel frames over a TunnelBridge."""
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"ps.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output: Deque[str] = deque()
self._close_reason: Optional[str] = None
self._close_code: Optional[int] = None
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
# ------------------------------------------------------------------ Agent frame handling
def handle_agent_frame(self, frame) -> None:
if frame.channel_id != self.channel_id:
return
if frame.msg_type == MSG_CHANNEL_ACK:
self._ack_received = True
self.logger.info("ps channel acked tunnel_id=%s", self.bridge.lease.tunnel_id)
return
if frame.msg_type == MSG_DATA:
try:
text = frame.payload.decode("utf-8", errors="replace")
except Exception:
text = ""
if text:
self._append_output(text)
return
if frame.msg_type == MSG_CLOSE:
try:
payload = json.loads(frame.payload.decode("utf-8"))
except Exception:
payload = {}
self._closed = True
self._close_code = payload.get("code") if isinstance(payload, dict) else None
self._close_reason = payload.get("reason") if isinstance(payload, dict) else None
self.logger.info(
"ps channel closed tunnel_id=%s code=%s reason=%s",
self.bridge.lease.tunnel_id,
self._close_code,
self._close_reason or "-",
)
# ------------------------------------------------------------------ Operator actions
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
if self._open_sent:
return
payload = json.dumps(
{"protocol": "ps", "metadata": {"cols": cols, "rows": rows}},
separators=(",", ":"),
).encode("utf-8")
frame = self._frame_cls(msg_type=MSG_CHANNEL_OPEN, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
self._open_sent = True
self.logger.info(
"ps channel open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
if self._closed:
return
payload = data.encode("utf-8", errors="replace")
frame = self._frame_cls(msg_type=MSG_DATA, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
def send_resize(self, cols: int, rows: int) -> None:
if self._closed:
return
payload = json.dumps({"cols": cols, "rows": rows}, separators=(",", ":")).encode("utf-8")
frame = self._frame_cls(msg_type=MSG_CONTROL, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
def close(self, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
frame = self._close_frame_fn(self.channel_id, code, reason)
else:
frame = self._frame_cls(
msg_type=MSG_CLOSE,
channel_id=self.channel_id,
payload=json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8"),
)
self.bridge.operator_to_agent(frame)
# ------------------------------------------------------------------ Output polling
def drain_output(self) -> List[str]:
items: List[str] = []
while self._output:
items.append(self._output.popleft())
return items
def _append_output(self, text: str) -> None:
self._output.append(text)
# Cap buffer to avoid unbounded memory growth.
while len(self._output) > 500:
self._output.popleft()
# ------------------------------------------------------------------ Status helpers
def status(self) -> Dict[str, Any]:
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": self._close_reason,
"close_code": self._close_code,
}
__all__ = ["PowershellChannelServer"]

View File

@@ -1,6 +0,0 @@
"""Protocol handlers for remote interactive shell tunnels (Engine side)."""
from .Powershell import PowershellChannelServer
from .Bash import BashChannelServer
__all__ = ["PowershellChannelServer", "BashChannelServer"]

View File

@@ -1,3 +0,0 @@
"""Domain handlers for remote interactive shells (PowerShell/Bash)."""
__all__ = ["Protocols"]

View File

@@ -1,73 +0,0 @@
"""Placeholder SSH channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class SSHChannelServer:
protocol_name = "ssh"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"ssh.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"ssh channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("ssh placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["SSHChannelServer"]

View File

@@ -1,73 +0,0 @@
"""Placeholder WinRM channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class WinRMChannelServer:
protocol_name = "winrm"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"winrm.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"winrm channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("winrm placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["WinRMChannelServer"]

View File

@@ -1,6 +0,0 @@
"""Protocol handlers for remote management tunnels (Engine side)."""
from .SSH import SSHChannelServer
from .WinRM import WinRMChannelServer
__all__ = ["SSHChannelServer", "WinRMChannelServer"]

View File

@@ -1,3 +0,0 @@
"""Domain handlers for remote management tunnels (SSH/WinRM)."""
__all__ = ["Protocols"]

View File

@@ -1,73 +0,0 @@
"""Placeholder RDP channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class RDPChannelServer:
protocol_name = "rdp"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"rdp.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"rdp channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("rdp placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["RDPChannelServer"]

View File

@@ -1,73 +0,0 @@
"""Placeholder VNC channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class VNCChannelServer:
protocol_name = "vnc"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"vnc.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"vnc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("vnc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["VNCChannelServer"]

View File

@@ -1,73 +0,0 @@
"""Placeholder WebRTC channel server (Engine side)."""
from __future__ import annotations
from collections import deque
class WebRTCChannelServer:
protocol_name = "webrtc"
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"webrtc.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output = deque()
self._frame_cls = frame_cls
self._close_frame_fn = close_frame_fn
def handle_agent_frame(self, frame) -> None:
try:
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
self._ack_received = True
except Exception:
return
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
self._open_sent = True
self.logger.info(
"webrtc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
self.logger.info("webrtc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
def send_resize(self, cols: int, rows: int) -> None:
return
def close(self, code: int = 6, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
if callable(self._close_frame_fn):
try:
frame = self._close_frame_fn(self.channel_id, code, reason)
self.bridge.operator_to_agent(frame)
except Exception:
pass
def drain_output(self):
items = []
while self._output:
items.append(self._output.popleft())
return items
def status(self):
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": None,
"close_code": None,
}
__all__ = ["WebRTCChannelServer"]

View File

@@ -1,7 +0,0 @@
"""Protocol handlers for remote video tunnels (Engine side)."""
from .WebRTC import WebRTCChannelServer
from .RDP import RDPChannelServer
from .VNC import VNCChannelServer
__all__ = ["WebRTCChannelServer", "RDPChannelServer", "VNCChannelServer"]

View File

@@ -1,3 +0,0 @@
"""Domain handlers for remote video/desktop tunnels (RDP/VNC/WebRTC)."""
__all__ = ["Protocols"]

View File

@@ -1,10 +0,0 @@
# ======================================================
# Data\Engine\services\WebSocket\Agent\__init__.py
# Description: Package marker for Agent-facing WebSocket services (reverse tunnel scaffolding).
#
# API Endpoints (if applicable): None
# ======================================================
"""Agent-facing WebSocket services for the Engine runtime."""
__all__ = []

View File

@@ -1,6 +1,6 @@
# ====================================================== # ======================================================
# Data\Engine\services\WebSocket\__init__.py # Data\Engine\services\WebSocket\__init__.py
# Description: Socket.IO handlers for Engine runtime quick job updates and realtime notifications. # Description: Socket.IO handlers for Engine runtime quick job updates and VPN shell bridging.
# #
# API Endpoints (if applicable): None # API Endpoints (if applicable): None
# ====================================================== # ======================================================
@@ -8,24 +8,20 @@
"""WebSocket service registration for the Borealis Engine runtime.""" """WebSocket service registration for the Borealis Engine runtime."""
from __future__ import annotations from __future__ import annotations
import base64
import sqlite3 import sqlite3
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Dict, Optional
from flask import session, request from flask import request
from flask_socketio import SocketIO from flask_socketio import SocketIO
from ...database import initialise_engine_database from ...database import initialise_engine_database
from ...security import signing
from ...server import EngineContext from ...server import EngineContext
from .Agent.reverse_tunnel_orchestrator import ( from ..VPN import VpnTunnelService
ReverseTunnelService, from .vpn_shell import VpnShellBridge
TunnelBridge,
decode_frame,
TunnelFrame,
)
def _now_ts() -> int: def _now_ts() -> int:
@@ -70,20 +66,31 @@ class EngineRealtimeAdapters:
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None: def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
"""Register Socket.IO event handlers for the Engine runtime.""" """Register Socket.IO event handlers for the Engine runtime."""
from ..API import _make_db_conn_factory, _make_service_logger # Local import to avoid circular import at module load
adapters = EngineRealtimeAdapters(context) adapters = EngineRealtimeAdapters(context)
logger = context.logger.getChild("realtime.quick_jobs") logger = context.logger.getChild("realtime.quick_jobs")
tunnel_service = getattr(context, "reverse_tunnel_service", None) shell_bridge = VpnShellBridge(socket_server, context)
if tunnel_service is None:
tunnel_service = ReverseTunnelService( def _get_tunnel_service() -> Optional[VpnTunnelService]:
context, service = getattr(context, "vpn_tunnel_service", None)
signer=None, if service is not None:
return service
manager = getattr(context, "wireguard_server_manager", None)
if manager is None:
return None
try:
signer = signing.load_signer()
except Exception:
signer = None
service = VpnTunnelService(
context=context,
wireguard_manager=manager,
db_conn_factory=adapters.db_conn_factory, db_conn_factory=adapters.db_conn_factory,
socketio=socket_server, socketio=socket_server,
service_log=adapters.service_log,
signer=signer,
) )
tunnel_service.start() setattr(context, "vpn_tunnel_service", service)
setattr(context, "reverse_tunnel_service", tunnel_service) return service
@socket_server.on("quick_job_result") @socket_server.on("quick_job_result")
def _handle_quick_job_result(data: Any) -> None: def _handle_quick_job_result(data: Any) -> None:
@@ -246,252 +253,45 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
exc, exc,
) )
@socket_server.on("tunnel_bridge_attach") @socket_server.on("vpn_shell_open")
def _tunnel_bridge_attach(data: Any) -> Any: def _vpn_shell_open(data: Any) -> Dict[str, Any]:
"""Placeholder operator bridge attach handler (no data channel yet).""" agent_id = ""
if isinstance(data, dict):
agent_id = str(data.get("agent_id") or "").strip()
elif isinstance(data, str):
agent_id = data.strip()
if not agent_id:
return {"error": "agent_id_required"}
if not isinstance(data, dict): service = _get_tunnel_service()
return {"error": "invalid_payload"} if service is None:
return {"error": "vpn_service_unavailable"}
if not service.status(agent_id):
return {"error": "tunnel_down"}
tunnel_id = str(data.get("tunnel_id") or "").strip() session = shell_bridge.open_session(request.sid, agent_id)
operator_id = str(data.get("operator_id") or "").strip() or None if session is None:
if not tunnel_id: return {"error": "shell_connect_failed"}
return {"error": "tunnel_id_required"} service.bump_activity(agent_id)
return {"status": "ok"}
try: @socket_server.on("vpn_shell_send")
tunnel_service.operator_attach(tunnel_id, operator_id) def _vpn_shell_send(data: Any) -> Dict[str, Any]:
except ValueError as exc: payload = None
return {"error": str(exc)} if isinstance(data, dict):
except Exception as exc: # pragma: no cover - defensive guard payload = data.get("data")
logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "bridge_attach_failed"}
return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}
def _encode_frame(frame: TunnelFrame) -> str:
return base64.b64encode(frame.encode()).decode("ascii")
def _decode_frame_payload(raw: Any) -> TunnelFrame:
if isinstance(raw, str):
try:
raw_bytes = base64.b64decode(raw)
except Exception:
raise ValueError("invalid_frame")
elif isinstance(raw, (bytes, bytearray)):
raw_bytes = bytes(raw)
else: else:
raise ValueError("invalid_frame") payload = data
return decode_frame(raw_bytes) if payload is None:
@socket_server.on("tunnel_operator_send")
def _tunnel_operator_send(data: Any) -> Any:
"""Operator -> agent frame enqueue (placeholder queue)."""
if not isinstance(data, dict):
return {"error": "invalid_payload"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
frame_raw = data.get("frame")
if not tunnel_id or frame_raw is None:
return {"error": "tunnel_id_and_frame_required"}
try:
frame = _decode_frame_payload(frame_raw)
except Exception as exc:
return {"error": str(exc)}
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
bridge.operator_to_agent(frame)
return {"status": "ok"}
@socket_server.on("tunnel_operator_poll")
def _tunnel_operator_poll(data: Any) -> Any:
"""Operator polls queued frames from agent."""
tunnel_id = ""
if isinstance(data, dict):
tunnel_id = str(data.get("tunnel_id") or "").strip()
if not tunnel_id:
return {"error": "tunnel_id_required"}
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
frames = []
while True:
frame = bridge.next_for_operator()
if frame is None:
break
frames.append(_encode_frame(frame))
return {"frames": frames}
# WebUI operator bridge namespace for browser clients
tunnel_namespace = "/tunnel"
_operator_sessions: Dict[str, str] = {}
def _current_operator() -> Optional[str]:
username = session.get("username")
if username:
return str(username)
auth_header = (request.headers.get("Authorization") or "").strip()
token = None
if auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
if not token:
token = request.cookies.get("borealis_auth")
return token or None
@socket_server.on("join", namespace=tunnel_namespace)
def _ws_tunnel_join(data: Any) -> Any:
if not isinstance(data, dict):
return {"error": "invalid_payload"}
operator_id = _current_operator()
if not operator_id:
return {"error": "unauthorized"}
tunnel_id = str(data.get("tunnel_id") or "").strip()
if not tunnel_id:
return {"error": "tunnel_id_required"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
try:
tunnel_service.operator_attach(tunnel_id, operator_id)
except Exception as exc:
logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "attach_failed"}
sid = request.sid
_operator_sessions[sid] = tunnel_id
return {"status": "ok", "tunnel_id": tunnel_id}
@socket_server.on("send", namespace=tunnel_namespace)
def _ws_tunnel_send(data: Any) -> Any:
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return {"error": "not_joined"}
if not isinstance(data, dict):
return {"error": "invalid_payload"}
frame_raw = data.get("frame")
if frame_raw is None:
return {"error": "frame_required"}
try:
frame = _decode_frame_payload(frame_raw)
except Exception:
return {"error": "invalid_frame"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
bridge.operator_to_agent(frame)
return {"status": "ok"}
@socket_server.on("poll", namespace=tunnel_namespace)
def _ws_tunnel_poll() -> Any:
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return {"error": "not_joined"}
bridge = tunnel_service.get_bridge(tunnel_id)
if bridge is None:
return {"error": "unknown_tunnel"}
frames = []
while True:
frame = bridge.next_for_operator()
if frame is None:
break
frames.append(_encode_frame(frame))
return {"frames": frames}
def _require_ps_server():
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return None, None, {"error": "not_joined"}
server = tunnel_service.ensure_protocol_server(tunnel_id)
if server is None or not hasattr(server, "open_channel"):
return None, tunnel_id, {"error": "ps_unsupported"}
return server, tunnel_id, None
@socket_server.on("ps_open", namespace=tunnel_namespace)
def _ws_ps_open(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = 120
rows = 32
if isinstance(data, dict):
try:
cols = int(data.get("cols", cols))
rows = int(data.get("rows", rows))
except Exception:
pass
cols = max(20, min(cols, 300))
rows = max(10, min(rows, 200))
try:
server.open_channel(cols=cols, rows=rows)
except Exception as exc:
logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_open_failed"}
return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}
@socket_server.on("ps_send", namespace=tunnel_namespace)
def _ws_ps_send(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
if data is None:
return {"error": "payload_required"} return {"error": "payload_required"}
text = data shell_bridge.send(request.sid, str(payload))
if isinstance(data, dict):
text = data.get("data")
if text is None:
return {"error": "payload_required"}
try:
server.send_input(str(text))
except Exception as exc:
logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_send_failed"}
return {"status": "ok"} return {"status": "ok"}
@socket_server.on("ps_resize", namespace=tunnel_namespace) @socket_server.on("vpn_shell_close")
def _ws_ps_resize(data: Any) -> Any: def _vpn_shell_close() -> Dict[str, Any]:
server, tunnel_id, error = _require_ps_server() shell_bridge.close(request.sid)
if server is None: return {"status": "ok"}
return error
cols = None
rows = None
if isinstance(data, dict):
cols = data.get("cols")
rows = data.get("rows")
try:
cols_int = int(cols) if cols is not None else 120
rows_int = int(rows) if rows is not None else 32
cols_int = max(20, min(cols_int, 300))
rows_int = max(10, min(rows_int, 200))
server.send_resize(cols_int, rows_int)
return {"status": "ok", "cols": cols_int, "rows": rows_int}
except Exception as exc:
logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_resize_failed"}
@socket_server.on("ps_poll", namespace=tunnel_namespace) @socket_server.on("disconnect")
def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused def _ws_disconnect() -> None:
server, tunnel_id, error = _require_ps_server() shell_bridge.close(request.sid)
if server is None:
return error
try:
output = server.drain_output()
status = server.status()
return {"output": output, "status": status}
except Exception as exc:
logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_poll_failed"}
@socket_server.on("disconnect", namespace=tunnel_namespace)
def _ws_tunnel_disconnect():
sid = request.sid
tunnel_id = _operator_sessions.pop(sid, None)
if tunnel_id and tunnel_id not in _operator_sessions.values():
try:
tunnel_service.stop_tunnel(tunnel_id, reason="operator_socket_disconnect")
except Exception as exc:
logger.debug("ws_tunnel_disconnect stop_tunnel failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)

View File

@@ -0,0 +1,127 @@
# ======================================================
# Data\Engine\services\WebSocket\vpn_shell.py
# Description: Socket.IO handlers bridging UI shell to agent TCP server over WireGuard.
#
# API Endpoints (if applicable): None
# ======================================================
"""WireGuard VPN PowerShell bridge (Engine side)."""
from __future__ import annotations
import base64
import json
import socket
import threading
from dataclasses import dataclass
from typing import Any, Dict, Optional
def _b64encode(data: bytes) -> str:
return base64.b64encode(data).decode("ascii").strip()
def _b64decode(value: str) -> bytes:
return base64.b64decode(value.encode("ascii"))
@dataclass
class ShellSession:
sid: str
agent_id: str
socketio: Any
tcp: socket.socket
_reader: Optional[threading.Thread] = None
def start_reader(self) -> None:
t = threading.Thread(target=self._read_loop, daemon=True)
t.start()
self._reader = t
def _read_loop(self) -> None:
buffer = b""
try:
while True:
data = self.tcp.recv(4096)
if not data:
break
buffer += data
while b"\n" in buffer:
line, buffer = buffer.split(b"\n", 1)
if not line:
continue
try:
msg = json.loads(line.decode("utf-8"))
except Exception:
continue
if msg.get("type") == "stdout":
payload = msg.get("data") or ""
try:
decoded = _b64decode(str(payload)).decode("utf-8", errors="replace")
except Exception:
decoded = ""
self.socketio.emit("vpn_shell_output", {"data": decoded}, to=self.sid)
finally:
self.socketio.emit("vpn_shell_closed", {"agent_id": self.agent_id}, to=self.sid)
try:
self.tcp.close()
except Exception:
pass
def send(self, payload: str) -> None:
data = json.dumps({"type": "stdin", "data": _b64encode(payload.encode("utf-8"))})
self.tcp.sendall(data.encode("utf-8") + b"\n")
def close(self) -> None:
try:
data = json.dumps({"type": "close"})
self.tcp.sendall(data.encode("utf-8") + b"\n")
except Exception:
pass
try:
self.tcp.close()
except Exception:
pass
class VpnShellBridge:
def __init__(self, socketio, context) -> None:
self.socketio = socketio
self.context = context
self._sessions: Dict[str, ShellSession] = {}
self.logger = context.logger.getChild("vpn_shell")
def open_session(self, sid: str, agent_id: str) -> Optional[ShellSession]:
service = getattr(self.context, "vpn_tunnel_service", None)
if service is None:
return None
status = service.status(agent_id)
if not status:
return None
host = str(status.get("virtual_ip") or "").split("/")[0]
port = int(self.context.wireguard_shell_port)
try:
tcp = socket.create_connection((host, port), timeout=5)
except Exception:
self.logger.debug("Failed to connect vpn shell to %s:%s", host, port, exc_info=True)
return None
session = ShellSession(sid=sid, agent_id=agent_id, socketio=self.socketio, tcp=tcp)
self._sessions[sid] = session
session.start_reader()
return session
def send(self, sid: str, payload: str) -> None:
session = self._sessions.get(sid)
if not session:
return
session.send(payload)
service = getattr(self.context, "vpn_tunnel_service", None)
if service:
service.bump_activity(session.agent_id)
def close(self, sid: str) -> None:
session = self._sessions.pop(sid, None)
if not session:
return
session.close()

View File

@@ -8,13 +8,17 @@ import {
Tab, Tab,
Typography, Typography,
Button, Button,
Switch,
Chip,
Divider,
Menu, Menu,
MenuItem, MenuItem,
TextField, TextField,
Dialog, Dialog,
DialogTitle, DialogTitle,
DialogContent, DialogContent,
DialogActions DialogActions,
LinearProgress
} from "@mui/material"; } from "@mui/material";
import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined"; import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined";
import StorageRoundedIcon from "@mui/icons-material/StorageRounded"; import StorageRoundedIcon from "@mui/icons-material/StorageRounded";
@@ -23,6 +27,7 @@ import LanRoundedIcon from "@mui/icons-material/LanRounded";
import AppsRoundedIcon from "@mui/icons-material/AppsRounded"; import AppsRoundedIcon from "@mui/icons-material/AppsRounded";
import ListAltRoundedIcon from "@mui/icons-material/ListAltRounded"; import ListAltRoundedIcon from "@mui/icons-material/ListAltRounded";
import TerminalRoundedIcon from "@mui/icons-material/TerminalRounded"; import TerminalRoundedIcon from "@mui/icons-material/TerminalRounded";
import TuneRoundedIcon from "@mui/icons-material/TuneRounded";
import SpeedRoundedIcon from "@mui/icons-material/SpeedRounded"; import SpeedRoundedIcon from "@mui/icons-material/SpeedRounded";
import DeveloperBoardRoundedIcon from "@mui/icons-material/DeveloperBoardRounded"; import DeveloperBoardRoundedIcon from "@mui/icons-material/DeveloperBoardRounded";
import MoreHorizIcon from "@mui/icons-material/MoreHoriz"; import MoreHorizIcon from "@mui/icons-material/MoreHoriz";
@@ -69,14 +74,51 @@ const SECTION_HEIGHTS = {
network: 260, network: 260,
}; };
const buildVpnGroups = (shellPort) => {
const normalizedShell = Number(shellPort) || 47001;
return [
{
key: "shell",
label: "Borealis PowerShell",
description: "Web terminal access over the VPN tunnel.",
ports: [normalizedShell],
},
{
key: "rdp",
label: "RDP",
description: "Remote Desktop (TCP 3389).",
ports: [3389],
},
{
key: "winrm",
label: "WinRM",
description: "PowerShell/WinRM management (TCP 5985/5986).",
ports: [5985, 5986],
},
{
key: "vnc",
label: "VNC",
description: "Remote desktop streaming (TCP 5900).",
ports: [5900],
},
{
key: "webrtc",
label: "WebRTC",
description: "Real-time comms (UDP 3478).",
ports: [3478],
},
];
};
const TOP_TABS = [ const TOP_TABS = [
{ label: "Device Summary", icon: InfoOutlinedIcon }, { key: "summary", label: "Device Summary", icon: InfoOutlinedIcon },
{ label: "Storage", icon: StorageRoundedIcon }, { key: "storage", label: "Storage", icon: StorageRoundedIcon },
{ label: "Memory", icon: MemoryRoundedIcon }, { key: "memory", label: "Memory", icon: MemoryRoundedIcon },
{ label: "Network", icon: LanRoundedIcon }, { key: "network", label: "Network", icon: LanRoundedIcon },
{ label: "Installed Software", icon: AppsRoundedIcon }, { key: "software", label: "Installed Software", icon: AppsRoundedIcon },
{ label: "Activity History", icon: ListAltRoundedIcon }, { key: "activity", label: "Activity History", icon: ListAltRoundedIcon },
{ label: "Remote Shell", icon: TerminalRoundedIcon }, { key: "advanced", label: "Advanced Config", icon: TuneRoundedIcon },
{ key: "shell", label: "Remote Shell", icon: TerminalRoundedIcon },
]; ];
const myTheme = themeQuartz.withParams({ const myTheme = themeQuartz.withParams({
@@ -286,6 +328,15 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
const [menuAnchor, setMenuAnchor] = useState(null); const [menuAnchor, setMenuAnchor] = useState(null);
const [clearDialogOpen, setClearDialogOpen] = useState(false); const [clearDialogOpen, setClearDialogOpen] = useState(false);
const [assemblyNameMap, setAssemblyNameMap] = useState({}); const [assemblyNameMap, setAssemblyNameMap] = useState({});
const [vpnLoading, setVpnLoading] = useState(false);
const [vpnSaving, setVpnSaving] = useState(false);
const [vpnError, setVpnError] = useState("");
const [vpnSource, setVpnSource] = useState("default");
const [vpnToggles, setVpnToggles] = useState({});
const [vpnCustomPorts, setVpnCustomPorts] = useState([]);
const [vpnDefaultPorts, setVpnDefaultPorts] = useState([]);
const [vpnShellPort, setVpnShellPort] = useState(47001);
const [vpnLoadedFor, setVpnLoadedFor] = useState("");
// Snapshotted status for the lifetime of this page // Snapshotted status for the lifetime of this page
const [lockedStatus, setLockedStatus] = useState(() => { const [lockedStatus, setLockedStatus] = useState(() => {
// Prefer status provided by the device list row if available // Prefer status provided by the device list row if available
@@ -655,6 +706,104 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
}; };
}, [activityHostname, loadHistory]); }, [activityHostname, loadHistory]);
const applyVpnPorts = useCallback((ports, defaults, shellPort, source) => {
const normalized = Array.isArray(ports) ? ports : [];
const normalizedDefaults = Array.isArray(defaults) ? defaults : [];
const numericPorts = normalized
.map((p) => Number(p))
.filter((p) => Number.isFinite(p) && p > 0);
const numericDefaults = normalizedDefaults
.map((p) => Number(p))
.filter((p) => Number.isFinite(p) && p > 0);
const effectiveShell = Number(shellPort) || 47001;
const groups = buildVpnGroups(effectiveShell);
const knownPorts = new Set(groups.flatMap((group) => group.ports));
const allowedSet = new Set(numericPorts);
const nextToggles = {};
groups.forEach((group) => {
nextToggles[group.key] = group.ports.every((port) => allowedSet.has(port));
});
const customPorts = numericPorts.filter((port) => !knownPorts.has(port));
setVpnShellPort(effectiveShell);
setVpnDefaultPorts(numericDefaults);
setVpnCustomPorts(customPorts);
setVpnToggles(nextToggles);
setVpnSource(source || "default");
}, []);
const loadVpnConfig = useCallback(async () => {
if (!vpnAgentId) return;
setVpnLoading(true);
setVpnError("");
setVpnLoadedFor(vpnAgentId);
try {
const resp = await fetch(`/api/device/vpn_config/${encodeURIComponent(vpnAgentId)}`);
const data = await resp.json().catch(() => ({}));
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
const allowedPorts = Array.isArray(data?.allowed_ports) ? data.allowed_ports : [];
const defaultPorts = Array.isArray(data?.default_ports) ? data.default_ports : [];
const shellPort = data?.shell_port;
applyVpnPorts(allowedPorts.length ? allowedPorts : defaultPorts, defaultPorts, shellPort, data?.source);
setVpnLoadedFor(vpnAgentId);
} catch (err) {
setVpnError(String(err.message || err));
} finally {
setVpnLoading(false);
}
}, [applyVpnPorts, vpnAgentId]);
const saveVpnConfig = useCallback(async () => {
if (!vpnAgentId) return;
const ports = [];
vpnPortGroups.forEach((group) => {
if (vpnToggles[group.key]) {
ports.push(...group.ports);
}
});
vpnCustomPorts.forEach((port) => ports.push(port));
const uniquePorts = Array.from(new Set(ports)).filter((p) => p > 0);
if (!uniquePorts.length) {
setVpnError("Enable at least one port before saving.");
return;
}
setVpnSaving(true);
setVpnError("");
try {
const resp = await fetch(`/api/device/vpn_config/${encodeURIComponent(vpnAgentId)}`, {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ allowed_ports: uniquePorts }),
});
const data = await resp.json().catch(() => ({}));
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
const allowedPorts = Array.isArray(data?.allowed_ports) ? data.allowed_ports : uniquePorts;
const defaultPorts = Array.isArray(data?.default_ports) ? data.default_ports : vpnDefaultPorts;
applyVpnPorts(allowedPorts, defaultPorts, data?.shell_port || vpnShellPort, data?.source || "custom");
} catch (err) {
setVpnError(String(err.message || err));
} finally {
setVpnSaving(false);
}
}, [applyVpnPorts, vpnAgentId, vpnCustomPorts, vpnDefaultPorts, vpnPortGroups, vpnShellPort, vpnToggles]);
const resetVpnConfig = useCallback(() => {
if (!vpnDefaultPorts.length) {
setVpnError("No default ports available to reset.");
return;
}
setVpnError("");
applyVpnPorts(vpnDefaultPorts, vpnDefaultPorts, vpnShellPort, "default");
}, [applyVpnPorts, vpnDefaultPorts, vpnShellPort]);
useEffect(() => {
const advancedIndex = TOP_TABS.findIndex((item) => item.key === "advanced");
if (advancedIndex < 0) return;
if (tab !== advancedIndex) return;
if (!vpnAgentId) return;
if (vpnLoadedFor === vpnAgentId) return;
loadVpnConfig();
}, [loadVpnConfig, tab, vpnAgentId, vpnLoadedFor]);
// No explicit live recap tab; recaps are recorded into Activity History // No explicit live recap tab; recaps are recorded into Activity History
const clearHistory = async () => { const clearHistory = async () => {
@@ -739,6 +888,19 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
); );
const summary = details.summary || {}; const summary = details.summary || {};
const vpnAgentId = useMemo(() => {
return (
meta.agentId ||
summary.agent_id ||
agent?.agent_id ||
agent?.id ||
device?.agent_id ||
device?.agent_guid ||
device?.id ||
""
);
}, [agent?.agent_id, agent?.id, device?.agent_guid, device?.agent_id, device?.id, meta.agentId, summary.agent_id]);
const vpnPortGroups = useMemo(() => buildVpnGroups(vpnShellPort), [vpnShellPort]);
const tunnelDevice = useMemo( const tunnelDevice = useMemo(
() => ({ () => ({
...(device || {}), ...(device || {}),
@@ -876,7 +1038,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
const formatScriptType = useCallback((raw) => { const formatScriptType = useCallback((raw) => {
const value = String(raw || "").toLowerCase(); const value = String(raw || "").toLowerCase();
if (value === "ansible") return "Ansible Playbook"; if (value === "ansible") return "Ansible Playbook";
if (value === "reverse_tunnel") return "Reverse Tunnel"; if (value === "reverse_tunnel" || value === "vpn_tunnel") return "Reverse VPN Tunnel";
return "Script"; return "Script";
}, []); }, []);
@@ -1368,6 +1530,150 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
</Box> </Box>
); );
const handleVpnToggle = useCallback((key, checked) => {
setVpnToggles((prev) => ({ ...(prev || {}), [key]: checked }));
setVpnSource("custom");
}, []);
const renderAdvancedConfigTab = () => {
const sourceLabel = vpnSource === "custom" ? "Custom overrides" : "Defaults";
const showProgress = vpnLoading || vpnSaving;
return (
<Box sx={{ display: "flex", flexDirection: "column", gap: 2, flexGrow: 1, minHeight: 0 }}>
<Box
sx={{
borderRadius: 3,
border: `1px solid ${MAGIC_UI.panelBorder}`,
background:
"linear-gradient(160deg, rgba(8,12,24,0.94), rgba(10,16,30,0.9)), radial-gradient(circle at 20% 10%, rgba(125,211,252,0.08), transparent 40%)",
boxShadow: "0 25px 80px rgba(2,6,23,0.65)",
p: { xs: 2, md: 3 },
}}
>
{showProgress ? <LinearProgress color="info" sx={{ height: 3, mb: 2 }} /> : null}
<Stack direction={{ xs: "column", md: "row" }} spacing={2} alignItems={{ xs: "flex-start", md: "center" }}>
<Box sx={{ flexGrow: 1 }}>
<Typography variant="h6" sx={{ color: MAGIC_UI.textBright, fontWeight: 700 }}>
Reverse VPN Tunnel - Allowed Ports
</Typography>
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted, mt: 0.5 }}>
Toggle which services the Engine can reach over the WireGuard tunnel for this device.
</Typography>
</Box>
<Chip
label={sourceLabel}
sx={{
borderRadius: 999,
fontWeight: 600,
letterSpacing: 0.2,
color: vpnSource === "custom" ? MAGIC_UI.accentA : MAGIC_UI.textMuted,
border: `1px solid ${MAGIC_UI.panelBorder}`,
backgroundColor: "rgba(8,12,24,0.75)",
}}
/>
</Stack>
<Divider sx={{ my: 2, borderColor: "rgba(148,163,184,0.2)" }} />
<Stack spacing={1.5}>
{vpnPortGroups.map((group) => (
<Box
key={group.key}
sx={{
display: "flex",
alignItems: { xs: "flex-start", md: "center" },
justifyContent: "space-between",
gap: 2,
p: 2,
borderRadius: 2,
border: `1px solid ${MAGIC_UI.panelBorder}`,
background: "rgba(6,10,20,0.7)",
}}
>
<Box sx={{ flexGrow: 1 }}>
<Typography sx={{ color: MAGIC_UI.textBright, fontWeight: 600 }}>
{group.label}
</Typography>
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted, mt: 0.35 }}>
{group.description}
</Typography>
<Stack direction="row" spacing={0.75} sx={{ mt: 0.8, flexWrap: "wrap" }}>
{group.ports.map((port) => (
<Chip
key={`${group.key}-${port}`}
label={`TCP ${port}`}
size="small"
sx={{
borderRadius: 999,
backgroundColor: "rgba(15,23,42,0.65)",
color: MAGIC_UI.textMuted,
border: `1px solid rgba(148,163,184,0.25)`,
}}
/>
))}
</Stack>
</Box>
<Switch
checked={Boolean(vpnToggles[group.key])}
onChange={(event) => handleVpnToggle(group.key, event.target.checked)}
color="info"
disabled={vpnLoading || vpnSaving}
/>
</Box>
))}
</Stack>
{vpnCustomPorts.length ? (
<Box sx={{ mt: 2 }}>
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted }}>
Custom ports preserved: {vpnCustomPorts.join(", ")}
</Typography>
</Box>
) : null}
{vpnError ? (
<Typography variant="body2" sx={{ color: "#ff7b89", mt: 1 }}>
{vpnError}
</Typography>
) : null}
<Stack direction="row" spacing={1.25} sx={{ mt: 2 }}>
<Button
size="small"
disabled={!vpnAgentId || vpnSaving || vpnLoading}
onClick={saveVpnConfig}
sx={{
backgroundImage: "linear-gradient(135deg,#7dd3fc,#c084fc)",
color: "#0b1220",
borderRadius: 999,
textTransform: "none",
px: 2.4,
"&:hover": {
backgroundImage: "linear-gradient(135deg,#86e1ff,#d1a6ff)",
},
}}
>
Save Config
</Button>
<Button
size="small"
disabled={!vpnDefaultPorts.length || vpnSaving || vpnLoading}
onClick={resetVpnConfig}
sx={{
borderRadius: 999,
textTransform: "none",
px: 2.4,
color: MAGIC_UI.textBright,
border: `1px solid ${MAGIC_UI.panelBorder}`,
backgroundColor: "rgba(8,12,24,0.6)",
"&:hover": {
backgroundColor: "rgba(12,18,35,0.8)",
},
}}
>
Reset Defaults
</Button>
</Stack>
</Box>
</Box>
);
};
const memoryRows = useMemo( const memoryRows = useMemo(
() => () =>
(details.memory || []).map((m, idx) => ({ (details.memory || []).map((m, idx) => ({
@@ -1618,6 +1924,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
renderNetworkTab, renderNetworkTab,
renderSoftware, renderSoftware,
renderHistory, renderHistory,
renderAdvancedConfigTab,
renderRemoteShellTab, renderRemoteShellTab,
]; ];
const tabContent = (topTabRenderers[tab] || renderDeviceSummaryTab)(); const tabContent = (topTabRenderers[tab] || renderDeviceSummaryTab)();
@@ -1742,7 +2049,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
> >
{TOP_TABS.map((tabDef) => ( {TOP_TABS.map((tabDef) => (
<Tab <Tab
key={tabDef.label} key={tabDef.key || tabDef.label}
label={tabDef.label} label={tabDef.label}
icon={<tabDef.icon sx={{ fontSize: 18 }} />} icon={<tabDef.icon sx={{ fontSize: 18 }} />}
iconPosition="start" iconPosition="start"

View File

@@ -5,17 +5,17 @@ import {
Button, Button,
Stack, Stack,
TextField, TextField,
MenuItem,
IconButton, IconButton,
Tooltip, Tooltip,
LinearProgress, LinearProgress,
Chip,
} from "@mui/material"; } from "@mui/material";
import { import {
PlayArrowRounded as PlayIcon, PlayArrowRounded as PlayIcon,
StopRounded as StopIcon, StopRounded as StopIcon,
ContentCopy as CopyIcon, ContentCopy as CopyIcon,
RefreshRounded as RefreshIcon, RefreshRounded as RefreshIcon,
LanRounded as PortIcon, LanRounded as IpIcon,
LinkRounded as LinkIcon, LinkRounded as LinkIcon,
} from "@mui/icons-material"; } from "@mui/icons-material";
import { io } from "socket.io-client"; import { io } from "socket.io-client";
@@ -24,18 +24,7 @@ import "prismjs/components/prism-powershell";
import "prismjs/themes/prism-okaidia.css"; import "prismjs/themes/prism-okaidia.css";
import Editor from "react-simple-code-editor"; import Editor from "react-simple-code-editor";
// Console diagnostics for troubleshooting the connect/disconnect flow.
const debugLog = (...args) => {
try {
// eslint-disable-next-line no-console
console.error("[ReverseTunnel][PS]", ...args);
} catch {
// ignore
}
};
const MAGIC_UI = { const MAGIC_UI = {
panelBg: "rgba(7,11,24,0.92)",
panelBorder: "rgba(148, 163, 184, 0.35)", panelBorder: "rgba(148, 163, 184, 0.35)",
textMuted: "#94a3b8", textMuted: "#94a3b8",
textBright: "#e2e8f0", textBright: "#e2e8f0",
@@ -56,13 +45,25 @@ const gradientButtonSx = {
}, },
}; };
const FRAME_HEADER_BYTES = 12; // version, msg_type, flags, reserved, channel_id(u32), length(u32)
const MSG_CLOSE = 0x08;
const CLOSE_AGENT_SHUTDOWN = 6;
const fontFamilyMono = const fontFamilyMono =
'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace'; 'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace';
const emitAsync = (socket, event, payload, timeoutMs = 4000) =>
new Promise((resolve) => {
let settled = false;
const timer = setTimeout(() => {
if (settled) return;
settled = true;
resolve({ error: "timeout" });
}, timeoutMs);
socket.emit(event, payload, (resp) => {
if (settled) return;
settled = true;
clearTimeout(timer);
resolve(resp || {});
});
});
function normalizeText(value) { function normalizeText(value) {
if (value == null) return ""; if (value == null) return "";
try { try {
@@ -72,28 +73,6 @@ function normalizeText(value) {
} }
} }
function base64FromBytes(bytes) {
let binary = "";
bytes.forEach((b) => {
binary += String.fromCharCode(b);
});
return btoa(binary);
}
function buildCloseFrame(channelId = 1, code = CLOSE_AGENT_SHUTDOWN, reason = "operator_close") {
const payload = new TextEncoder().encode(JSON.stringify({ code, reason }));
const buffer = new ArrayBuffer(FRAME_HEADER_BYTES + payload.length);
const view = new DataView(buffer);
view.setUint8(0, 1); // version
view.setUint8(1, MSG_CLOSE);
view.setUint8(2, 0); // flags
view.setUint8(3, 0); // reserved
view.setUint32(4, channelId >>> 0, true);
view.setUint32(8, payload.length >>> 0, true);
new Uint8Array(buffer, FRAME_HEADER_BYTES).set(payload);
return base64FromBytes(new Uint8Array(buffer));
}
function highlightPs(code) { function highlightPs(code) {
try { try {
return Prism.highlight(code || "", Prism.languages.powershell, "powershell"); return Prism.highlight(code || "", Prism.languages.powershell, "powershell");
@@ -102,52 +81,18 @@ function highlightPs(code) {
} }
} }
const INITIAL_MILESTONES = {
tunnelReady: false,
operatorAttached: false,
shellEstablished: false,
};
const INITIAL_STATUS_CHAIN = ["Offline"];
export default function ReverseTunnelPowershell({ device }) { export default function ReverseTunnelPowershell({ device }) {
const [connectionType, setConnectionType] = useState("ps");
const [tunnel, setTunnel] = useState(null);
const [sessionState, setSessionState] = useState("idle"); const [sessionState, setSessionState] = useState("idle");
const [, setStatusMessage] = useState(""); const [shellState, setShellState] = useState("idle");
const [, setStatusSeverity] = useState("info"); const [tunnel, setTunnel] = useState(null);
const [output, setOutput] = useState(""); const [output, setOutput] = useState("");
const [input, setInput] = useState(""); const [input, setInput] = useState("");
const [statusMessage, setStatusMessage] = useState("");
const [copyFlash, setCopyFlash] = useState(false); const [copyFlash, setCopyFlash] = useState(false);
const [, setPolling] = useState(false); const [loading, setLoading] = useState(false);
const [psStatus, setPsStatus] = useState({});
const [milestones, setMilestones] = useState(() => ({ ...INITIAL_MILESTONES }));
const [tunnelSteps, setTunnelSteps] = useState(() => [...INITIAL_STATUS_CHAIN]);
const [websocketSteps, setWebsocketSteps] = useState(() => [...INITIAL_STATUS_CHAIN]);
const [shellSteps, setShellSteps] = useState(() => [...INITIAL_STATUS_CHAIN]);
const socketRef = useRef(null); const socketRef = useRef(null);
const pollTimerRef = useRef(null); const localSocketRef = useRef(false);
const resizeTimerRef = useRef(null);
const terminalRef = useRef(null); const terminalRef = useRef(null);
const joinRetryRef = useRef(null);
const joinAttemptsRef = useRef(0);
const tunnelRef = useRef(null);
const shellFlagsRef = useRef({ openSent: false, ack: false });
const DOMAIN_REMOTE_SHELL = "remote-interactive-shell";
useEffect(() => {
debugLog("component mount", { hostname: device?.hostname, agentId });
return () => debugLog("component unmount");
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
const hostname = useMemo(() => {
return (
normalizeText(device?.hostname) ||
normalizeText(device?.summary?.hostname) ||
normalizeText(device?.agent_hostname) ||
""
);
}, [device]);
const agentId = useMemo(() => { const agentId = useMemo(() => {
return ( return (
@@ -162,78 +107,18 @@ export default function ReverseTunnelPowershell({ device }) {
); );
}, [device]); }, [device]);
const appendStatus = useCallback((setter, label) => { const ensureSocket = useCallback(() => {
if (!label) return; if (socketRef.current) return socketRef.current;
setter((prev) => { const existing = typeof window !== "undefined" ? window.BorealisSocket : null;
const next = [...prev, label]; if (existing) {
const cap = 6; socketRef.current = existing;
return next.length > cap ? next.slice(next.length - cap) : next; localSocketRef.current = false;
}); return existing;
}, []);
const resetState = useCallback(() => {
debugLog("resetState invoked");
setTunnel(null);
setSessionState("idle");
setStatusMessage("");
setStatusSeverity("info");
setOutput("");
setInput("");
setPsStatus({});
setMilestones({ ...INITIAL_MILESTONES });
setTunnelSteps([...INITIAL_STATUS_CHAIN]);
setWebsocketSteps([...INITIAL_STATUS_CHAIN]);
setShellSteps([...INITIAL_STATUS_CHAIN]);
shellFlagsRef.current = { openSent: false, ack: false };
}, []);
useEffect(() => {
tunnelRef.current = tunnel?.tunnel_id || null;
}, [tunnel?.tunnel_id]);
const disconnectSocket = useCallback(() => {
const socket = socketRef.current;
if (socket) {
socket.off();
socket.disconnect();
} }
socketRef.current = null; const socket = io(window.location.origin, { transports: ["websocket"] });
}, []); socketRef.current = socket;
localSocketRef.current = true;
const stopPolling = useCallback(() => { return socket;
if (pollTimerRef.current) {
clearTimeout(pollTimerRef.current);
pollTimerRef.current = null;
}
setPolling(false);
}, []);
const stopTunnel = useCallback(async (reason = "operator_disconnect", tunnelIdOverride = null) => {
const tunnelId = tunnelIdOverride || tunnelRef.current;
if (!tunnelId) return;
try {
await fetch(`/api/tunnel/${tunnelId}`, {
method: "DELETE",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ reason }),
});
} catch (err) {
// best-effort; socket close frame acts as fallback
}
}, []);
useEffect(() => {
return () => {
debugLog("cleanup on unmount", { tunnelId: tunnelRef.current });
stopPolling();
disconnectSocket();
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
stopTunnel("component_unmount", tunnelRef.current);
};
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []); }, []);
const appendOutput = useCallback((text) => { const appendOutput = useCallback((text) => {
@@ -257,6 +142,137 @@ export default function ReverseTunnelPowershell({ device }) {
scrollToBottom(); scrollToBottom();
}, [output, scrollToBottom]); }, [output, scrollToBottom]);
const stopTunnel = useCallback(
async (reason = "operator_disconnect") => {
if (!agentId) return;
try {
await fetch("/api/tunnel/disconnect", {
method: "DELETE",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ agent_id: agentId, tunnel_id: tunnel?.tunnel_id, reason }),
});
} catch {
// best-effort
}
},
[agentId, tunnel?.tunnel_id]
);
const closeShell = useCallback(async () => {
const socket = ensureSocket();
await emitAsync(socket, "vpn_shell_close", {});
}, [ensureSocket]);
const handleDisconnect = useCallback(async () => {
setLoading(true);
setStatusMessage("");
try {
await closeShell();
await stopTunnel("operator_disconnect");
} finally {
setTunnel(null);
setShellState("closed");
setSessionState("idle");
setLoading(false);
}
}, [closeShell, stopTunnel]);
useEffect(() => {
const socket = ensureSocket();
const handleDisconnectEvent = () => {
if (sessionState === "connected") {
setShellState("closed");
setSessionState("idle");
setStatusMessage("Socket disconnected.");
}
};
const handleOutput = (payload) => {
appendOutput(payload?.data || "");
};
const handleClosed = () => {
setShellState("closed");
setSessionState("idle");
setStatusMessage("Shell closed.");
};
socket.on("disconnect", handleDisconnectEvent);
socket.on("vpn_shell_output", handleOutput);
socket.on("vpn_shell_closed", handleClosed);
return () => {
socket.off("disconnect", handleDisconnectEvent);
socket.off("vpn_shell_output", handleOutput);
socket.off("vpn_shell_closed", handleClosed);
if (localSocketRef.current) {
socket.disconnect();
}
};
}, [appendOutput, ensureSocket, sessionState]);
useEffect(() => {
return () => {
closeShell();
stopTunnel("component_unmount");
};
}, [closeShell, stopTunnel]);
const requestTunnel = useCallback(async () => {
if (!agentId) {
setStatusMessage("Agent ID is required to connect.");
return;
}
setLoading(true);
setStatusMessage("");
setSessionState("connecting");
setShellState("opening");
try {
const resp = await fetch("/api/tunnel/connect", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ agent_id: agentId }),
});
const data = await resp.json().catch(() => ({}));
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
const statusResp = await fetch(
`/api/tunnel/connect/status?agent_id=${encodeURIComponent(agentId)}&bump=1`
);
const statusData = await statusResp.json().catch(() => ({}));
if (!statusResp.ok || statusData?.status !== "up") {
throw new Error(statusData?.error || "Tunnel not ready");
}
setTunnel({ ...data, ...statusData });
const socket = ensureSocket();
const openResp = await emitAsync(socket, "vpn_shell_open", { agent_id: agentId }, 6000);
if (openResp?.error) {
throw new Error(openResp.error);
}
setSessionState("connected");
setShellState("connected");
} catch (err) {
setSessionState("error");
setShellState("closed");
setStatusMessage(String(err.message || err));
} finally {
setLoading(false);
}
}, [agentId, ensureSocket]);
const handleSend = useCallback(
async (text) => {
const socket = ensureSocket();
if (!socket || sessionState !== "connected") return;
const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`;
appendOutput(`\nPS> ${text}\n`);
setInput("");
const resp = await emitAsync(socket, "vpn_shell_send", { data: payload });
if (resp?.error) {
setStatusMessage("Send failed.");
}
},
[appendOutput, ensureSocket, sessionState]
);
const handleCopy = async () => { const handleCopy = async () => {
try { try {
await navigator.clipboard.writeText(output || ""); await navigator.clipboard.writeText(output || "");
@@ -267,329 +283,7 @@ export default function ReverseTunnelPowershell({ device }) {
} }
}; };
const measureTerminal = useCallback(() => { const isConnected = sessionState === "connected";
const el = terminalRef.current;
if (!el) return { cols: 120, rows: 32 };
const width = el.clientWidth || 960;
const height = el.clientHeight || 460;
const charWidth = 8.2;
const charHeight = 18;
const cols = Math.max(20, Math.min(Math.floor(width / charWidth), 300));
const rows = Math.max(10, Math.min(Math.floor(height / charHeight), 200));
return { cols, rows };
}, []);
const emitAsync = useCallback((socket, event, payload, timeoutMs = 4000) => {
return new Promise((resolve) => {
let settled = false;
const timer = setTimeout(() => {
if (settled) return;
settled = true;
resolve({ error: "timeout" });
}, timeoutMs);
socket.emit(event, payload, (resp) => {
if (settled) return;
settled = true;
clearTimeout(timer);
resolve(resp || {});
});
});
}, []);
const pollLoop = useCallback(
(socket, tunnelId) => {
if (!socket || !tunnelId) return;
debugLog("pollLoop tick", { tunnelId });
setPolling(true);
pollTimerRef.current = setTimeout(async () => {
const resp = await emitAsync(socket, "ps_poll", {});
if (resp?.error) {
debugLog("pollLoop error", resp);
stopPolling();
disconnectSocket();
setPsStatus({});
setTunnel(null);
setSessionState("error");
return;
}
if (Array.isArray(resp?.output) && resp.output.length) {
appendOutput(resp.output.join(""));
}
if (resp?.status) {
setPsStatus(resp.status);
debugLog("pollLoop status", resp.status);
if (resp.status.closed) {
setSessionState("closed");
setTunnel(null);
setMilestones({ ...INITIAL_MILESTONES });
appendStatus(setShellSteps, "Shell closed");
appendStatus(setTunnelSteps, "Stopped");
appendStatus(setWebsocketSteps, "Relay stopped");
shellFlagsRef.current = { openSent: false, ack: false };
stopPolling();
return;
}
if (resp.status.open_sent && !shellFlagsRef.current.openSent) {
appendStatus(setShellSteps, "Opening remote shell");
shellFlagsRef.current.openSent = true;
}
if (resp.status.ack && !shellFlagsRef.current.ack) {
setSessionState("connected");
setMilestones((prev) => ({ ...prev, shellEstablished: true }));
appendStatus(setShellSteps, "Remote shell established");
shellFlagsRef.current.ack = true;
}
}
pollLoop(socket, tunnelId);
}, 520);
},
[appendOutput, emitAsync, stopPolling, disconnectSocket, appendStatus]
);
const handleDisconnect = useCallback(
async (reason = "operator_disconnect") => {
debugLog("handleDisconnect begin", { reason, tunnelId: tunnel?.tunnel_id, psStatus, sessionState });
setPsStatus({});
const socket = socketRef.current;
const tunnelId = tunnel?.tunnel_id;
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
joinAttemptsRef.current = 0;
if (socket && tunnelId) {
const frame = buildCloseFrame(1, CLOSE_AGENT_SHUTDOWN, "operator_close");
debugLog("emit CLOSE", { tunnelId });
socket.emit("send", { frame });
}
await stopTunnel(reason);
debugLog("stopTunnel issued", { tunnelId });
stopPolling();
disconnectSocket();
setTunnel(null);
setSessionState("closed");
setMilestones({ ...INITIAL_MILESTONES });
appendStatus(setTunnelSteps, "Stopped");
appendStatus(setWebsocketSteps, "Relay closed");
appendStatus(setShellSteps, "Shell closed");
shellFlagsRef.current = { openSent: false, ack: false };
debugLog("handleDisconnect finished", { tunnelId });
},
[appendStatus, disconnectSocket, stopPolling, stopTunnel, tunnel?.tunnel_id]
);
const handleResize = useCallback(() => {
if (!socketRef.current || sessionState === "idle") return;
const dims = measureTerminal();
socketRef.current.emit("ps_resize", dims);
}, [measureTerminal, sessionState]);
useEffect(() => {
const observer =
typeof ResizeObserver !== "undefined"
? new ResizeObserver(() => {
if (resizeTimerRef.current) clearTimeout(resizeTimerRef.current);
resizeTimerRef.current = setTimeout(() => handleResize(), 200);
})
: null;
const el = terminalRef.current;
if (observer && el) observer.observe(el);
const onWinResize = () => handleResize();
window.addEventListener("resize", onWinResize);
return () => {
window.removeEventListener("resize", onWinResize);
if (observer && el) observer.unobserve(el);
};
}, [handleResize]);
const connectSocket = useCallback(
(lease, { isRetry = false } = {}) => {
if (!lease?.tunnel_id) return;
if (joinRetryRef.current) {
clearTimeout(joinRetryRef.current);
joinRetryRef.current = null;
}
if (!isRetry) {
joinAttemptsRef.current = 0;
}
disconnectSocket();
stopPolling();
setSessionState("waiting");
const socket = io(`${window.location.origin}/tunnel`, { transports: ["websocket", "polling"] });
socketRef.current = socket;
socket.on("connect_error", () => {
debugLog("socket connect_error");
setStatusSeverity("warning");
setStatusMessage("Tunnel namespace unavailable.");
setTunnel(null);
setSessionState("error");
appendStatus(setWebsocketSteps, "Relay connect error");
});
socket.on("disconnect", () => {
debugLog("socket disconnect", { tunnelId: tunnel?.tunnel_id });
stopPolling();
if (sessionState !== "closed") {
setSessionState("disconnected");
setStatusSeverity("warning");
setStatusMessage("Socket disconnected.");
setTunnel(null);
appendStatus(setWebsocketSteps, "Relay disconnected");
}
});
socket.on("connect", async () => {
debugLog("socket connect", { tunnelId: lease.tunnel_id });
setMilestones((prev) => ({ ...prev, operatorAttached: true }));
setStatusSeverity("info");
setStatusMessage("Joining tunnel...");
appendStatus(setWebsocketSteps, "Relay connected");
const joinResp = await emitAsync(socket, "join", { tunnel_id: lease.tunnel_id }, 5000);
if (joinResp?.error) {
const attempt = (joinAttemptsRef.current += 1);
const isTimeout = joinResp.error === "timeout";
if (joinResp.error === "unknown_tunnel") {
setSessionState("waiting_agent");
setStatusSeverity("info");
setStatusMessage("Waiting for agent to establish tunnel...");
appendStatus(setWebsocketSteps, "Waiting for agent");
} else if (isTimeout || joinResp.error === "attach_failed") {
setSessionState("waiting_agent");
setStatusSeverity("warning");
setStatusMessage("Tunnel join timed out. Retrying...");
appendStatus(setWebsocketSteps, `Join retry ${attempt}`);
} else {
debugLog("join error", joinResp);
setSessionState("error");
setStatusSeverity("error");
setStatusMessage(joinResp.error);
appendStatus(setWebsocketSteps, `Join failed: ${joinResp.error}`);
return;
}
if (attempt <= 5) {
joinRetryRef.current = setTimeout(() => connectSocket(lease, { isRetry: true }), 800);
} else {
setSessionState("error");
setTunnel(null);
setStatusSeverity("warning");
setStatusMessage("Operator could not attach to tunnel. Try Connect again.");
appendStatus(setWebsocketSteps, "Join failed after retries");
}
return;
}
appendStatus(setWebsocketSteps, "Relay joined");
const dims = measureTerminal();
debugLog("ps_open emit", { tunnelId: lease.tunnel_id, dims });
const openResp = await emitAsync(socket, "ps_open", dims, 5000);
if (openResp?.error && openResp.error === "ps_unsupported") {
// Suppress warming message; channel will settle once agent attaches.
}
if (!shellFlagsRef.current.openSent) {
appendStatus(setShellSteps, "Opening remote shell");
shellFlagsRef.current.openSent = true;
}
appendOutput("");
setSessionState("waiting_agent");
pollLoop(socket, lease.tunnel_id);
handleResize();
});
},
[appendOutput, appendStatus, disconnectSocket, emitAsync, handleResize, measureTerminal, pollLoop, sessionState, stopPolling]
);
const requestTunnel = useCallback(async () => {
if (tunnel && sessionState !== "closed" && sessionState !== "idle") {
setStatusSeverity("info");
setStatusMessage("");
connectSocket(tunnel);
return;
}
debugLog("requestTunnel", { agentId, connectionType });
if (!agentId) {
setStatusSeverity("warning");
setStatusMessage("Agent ID is required to request a tunnel.");
return;
}
if (connectionType !== "ps") {
setStatusSeverity("warning");
setStatusMessage("Only PowerShell is supported right now.");
return;
}
resetState();
setSessionState("requesting");
setStatusSeverity("info");
setStatusMessage("");
appendStatus(setTunnelSteps, "Requesting lease");
try {
const resp = await fetch("/api/tunnel/request", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ agent_id: agentId, protocol: "ps", domain: DOMAIN_REMOTE_SHELL }),
});
const data = await resp.json().catch(() => ({}));
if (!resp.ok) {
const err = data?.error || `HTTP ${resp.status}`;
setSessionState("error");
setStatusSeverity(err === "domain_limit" ? "warning" : "error");
setStatusMessage("");
return;
}
setMilestones((prev) => ({ ...prev, tunnelReady: true }));
setTunnel(data);
setStatusMessage("");
setSessionState("lease_issued");
appendStatus(
setTunnelSteps,
data?.tunnel_id
? `Lease issued (${data.tunnel_id.slice(0, 8)} @ Port ${data.port || "-"})`
: "Lease issued"
);
connectSocket(data);
} catch (e) {
setSessionState("error");
setStatusSeverity("error");
setStatusMessage("");
appendStatus(setTunnelSteps, "Lease request failed");
}
}, [DOMAIN_REMOTE_SHELL, agentId, appendStatus, connectSocket, connectionType, resetState]);
const handleSend = useCallback(
async (text) => {
const socket = socketRef.current;
if (!socket) return;
const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`;
appendOutput(`\nPS> ${text}\n`);
setInput("");
const resp = await emitAsync(socket, "ps_send", { data: payload });
if (resp?.error) {
setStatusSeverity("warning");
setStatusMessage("");
}
},
[appendOutput, emitAsync]
);
const isConnected = sessionState === "connected" || (psStatus?.ack && !psStatus?.closed);
const isClosed = sessionState === "closed" || psStatus?.closed;
const isBusy =
sessionState === "requesting" ||
sessionState === "waiting" ||
sessionState === "waiting_agent" ||
sessionState === "lease_issued";
const canStart = Boolean(agentId) && !isBusy;
useEffect(() => {
const handleUnload = () => {
stopTunnel("window_unload");
};
if (tunnel?.tunnel_id) {
window.addEventListener("beforeunload", handleUnload);
return () => window.removeEventListener("beforeunload", handleUnload);
}
return undefined;
}, [stopTunnel, tunnel?.tunnel_id]);
const sessionChips = [ const sessionChips = [
tunnel?.tunnel_id tunnel?.tunnel_id
? { ? {
@@ -598,58 +292,43 @@ export default function ReverseTunnelPowershell({ device }) {
icon: <LinkIcon sx={{ fontSize: 18 }} />, icon: <LinkIcon sx={{ fontSize: 18 }} />,
} }
: null, : null,
tunnel?.port tunnel?.virtual_ip
? { ? {
label: `Port ${tunnel.port}`, label: `IP ${String(tunnel.virtual_ip).split("/")[0]}`,
color: MAGIC_UI.accentA, color: MAGIC_UI.accentA,
icon: <PortIcon sx={{ fontSize: 18 }} />, icon: <IpIcon sx={{ fontSize: 18 }} />,
} }
: null, : null,
].filter(Boolean); ].filter(Boolean);
return ( return (
<Box sx={{ display: "flex", flexDirection: "column", gap: 1.5, flexGrow: 1, minHeight: 0 }}> <Box sx={{ display: "flex", flexDirection: "column", gap: 1.5, flexGrow: 1, minHeight: 0 }}>
<Box> <Stack direction={{ xs: "column", sm: "row" }} spacing={1.5} alignItems={{ xs: "flex-start", sm: "center" }}>
<Stack <Button
direction={{ xs: "column", sm: "row" }} size="small"
spacing={1.5} startIcon={isConnected ? <StopIcon /> : <PlayIcon />}
alignItems={{ xs: "flex-start", sm: "center" }} sx={gradientButtonSx}
justifyContent={{ xs: "flex-start", sm: "flex-end" }} disabled={loading || (!isConnected && !agentId)}
onClick={isConnected ? handleDisconnect : requestTunnel}
> >
<TextField {isConnected ? "Disconnect" : "Connect"}
select </Button>
label="Connection Protocol" <Stack direction="row" spacing={1}>
size="small" {sessionChips.map((chip) => (
value={connectionType} <Chip
onChange={(e) => setConnectionType(e.target.value)} key={chip.label}
sx={{ icon={chip.icon}
minWidth: 180, label={chip.label}
"& .MuiInputBase-root": { sx={{
backgroundColor: "rgba(12,18,35,0.85)", borderRadius: 999,
color: MAGIC_UI.textBright, color: chip.color,
borderRadius: 1.5, border: `1px solid ${MAGIC_UI.panelBorder}`,
}, backgroundColor: "rgba(8,12,24,0.65)",
"& fieldset": { borderColor: MAGIC_UI.panelBorder }, }}
"&:hover fieldset": { borderColor: MAGIC_UI.accentA }, />
}} ))}
>
<MenuItem value="ps">PowerShell</MenuItem>
</TextField>
<Tooltip title={isConnected ? "Disconnect session" : "Connect to agent"}>
<span>
<Button
size="small"
startIcon={isConnected ? <StopIcon /> : <PlayIcon />}
sx={gradientButtonSx}
disabled={!isConnected && !canStart}
onClick={isConnected ? handleDisconnect : requestTunnel}
>
{isConnected ? "Disconnect" : "Connect"}
</Button>
</span>
</Tooltip>
</Stack> </Stack>
</Box> </Stack>
<Box <Box
sx={{ sx={{
@@ -665,7 +344,7 @@ export default function ReverseTunnelPowershell({ device }) {
overflow: "hidden", overflow: "hidden",
}} }}
> >
{isBusy ? <LinearProgress color="info" sx={{ height: 3 }} /> : null} {loading ? <LinearProgress color="info" sx={{ height: 3 }} /> : null}
<Box <Box
ref={terminalRef} ref={terminalRef}
sx={{ sx={{
@@ -728,11 +407,7 @@ export default function ReverseTunnelPowershell({ device }) {
size="small" size="small"
value={input} value={input}
disabled={!isConnected} disabled={!isConnected}
placeholder={ placeholder={isConnected ? "Enter PowerShell command and press Enter" : "Connect to start sending commands"}
isConnected
? "Enter PowerShell command and press Enter"
: "Connect to start sending commands"
}
onChange={(e) => setInput(e.target.value)} onChange={(e) => setInput(e.target.value)}
onKeyDown={(e) => { onKeyDown={(e) => {
if (e.key === "Enter" && !e.shiftKey) { if (e.key === "Enter" && !e.shiftKey) {
@@ -753,43 +428,19 @@ export default function ReverseTunnelPowershell({ device }) {
/> />
</Box> </Box>
</Box> </Box>
<Stack spacing={0.3} sx={{ mt: 1.25 }}>
<Typography <Stack spacing={0.3} sx={{ mt: 1 }}>
variant="body2" <Typography variant="body2" sx={{ color: MAGIC_UI.textMuted }}>
sx={{ Tunnel: {sessionState === "connected" ? "Active" : sessionState}
color: milestones.tunnelReady ? MAGIC_UI.accentC : MAGIC_UI.textMuted,
fontWeight: 700,
}}
>
Tunnel:{" "}
<Typography component="span" variant="body2" sx={{ color: MAGIC_UI.textMuted, fontWeight: 500 }}>
{tunnelSteps.join(" > ")}
</Typography>
</Typography> </Typography>
<Typography <Typography variant="body2" sx={{ color: MAGIC_UI.textMuted }}>
variant="body2" Shell: {shellState === "connected" ? "Ready" : shellState}
sx={{
color: milestones.operatorAttached ? MAGIC_UI.accentC : MAGIC_UI.textMuted,
fontWeight: 700,
}}
>
Websocket:{" "}
<Typography component="span" variant="body2" sx={{ color: MAGIC_UI.textMuted, fontWeight: 500 }}>
{websocketSteps.join(" > ")}
</Typography>
</Typography> </Typography>
<Typography {statusMessage ? (
variant="body2" <Typography variant="body2" sx={{ color: "#ff7b89" }}>
sx={{ {statusMessage}
color: milestones.shellEstablished ? MAGIC_UI.accentC : MAGIC_UI.textMuted,
fontWeight: 700,
}}
>
Remote Shell:{" "}
<Typography component="span" variant="body2" sx={{ color: MAGIC_UI.textMuted, fontWeight: 500 }}>
{shellSteps.join(" > ")}
</Typography> </Typography>
</Typography> ) : null}
</Stack> </Stack>
</Box> </Box>
); );

View File

@@ -20,8 +20,9 @@ Use this doc for agent-only work (Borealis agent runtime under `Data/Agent` →
- Validates script payloads with backend-issued Ed25519 signatures before execution. - Validates script payloads with backend-issued Ed25519 signatures before execution.
- Outbound-only; API/WebSocket calls flow through `AgentHttpClient.ensure_authenticated` for proactive refresh. Logs bootstrap, enrollment, token refresh, and signature events in `Agent/Logs/`. - Outbound-only; API/WebSocket calls flow through `AgentHttpClient.ensure_authenticated` for proactive refresh. Logs bootstrap, enrollment, token refresh, and signature events in `Agent/Logs/`.
## Reverse Tunnels ## Reverse VPN Tunnels
- Design, orchestration, domains, limits, and lifecycle are documented in `Docs/Codex/REVERSE_TUNNELS.md`. Agent role implementation lives in `Data/Agent/Roles/role_ReverseTunnel.py` with per-domain protocol handlers under `Data/Agent/Roles/Reverse_Tunnels/`. - WireGuard reverse VPN design and lifecycle live in `Docs/Codex/REVERSE_TUNNELS.md` and `Docs/Codex/Reverse_VPN_Tunnel_Deployment.md`.
- Agent roles: `Data/Agent/Roles/role_WireGuardTunnel.py` (tunnel lifecycle) and `Data/Agent/Roles/role_VpnShell.py` (VPN PowerShell TCP server).
## Execution Contexts & Roles ## Execution Contexts & Roles
- Auto-discovers roles from `Data/Agent/Roles/`; no loader changes needed. - Auto-discovers roles from `Data/Agent/Roles/`; no loader changes needed.

View File

@@ -23,9 +23,10 @@ Use this doc for Engine work (successor to the legacy server). For shared guidan
- Enrollment: operator approvals, conflict detection, auditor recording, pruning of expired codes/refresh tokens. - Enrollment: operator approvals, conflict detection, auditor recording, pruning of expired codes/refresh tokens.
- Background jobs and service adapters maintain compatibility with legacy DB schemas while enabling gradual API takeover. - Background jobs and service adapters maintain compatibility with legacy DB schemas while enabling gradual API takeover.
## Reverse Tunnels ## Reverse VPN Tunnels
- Full design and lifecycle are in `Docs/Codex/REVERSE_TUNNELS.md` (domains, limits, framing, APIs, stop path, UI hooks). - WireGuard reverse VPN design and lifecycle live in `Docs/Codex/REVERSE_TUNNELS.md` and `Docs/Codex/Reverse_VPN_Tunnel_Deployment.md`.
- Engine orchestrator is `Data/Engine/services/WebSocket/Agent/reverse_tunnel_orchestrator.py` with domain handlers under `Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/`. - Engine orchestrator: `Data/Engine/services/VPN/vpn_tunnel_service.py` with WireGuard manager `Data/Engine/services/VPN/wireguard_server.py`.
- UI shell bridge: `Data/Engine/services/WebSocket/vpn_shell.py`.
## WebUI & WebSocket Migration ## WebUI & WebSocket Migration
- Static/template handling: `Data/Engine/services/WebUI`; deployment copy paths are wired through `Borealis.ps1` with TLS-aware URL generation. - Static/template handling: `Data/Engine/services/WebUI`; deployment copy paths are wired through `Borealis.ps1` with TLS-aware URL generation.

View File

@@ -1,92 +1,61 @@
# Borealis Reverse Tunnels Operator & Developer Guide # Borealis Reverse VPN Tunnels (WireGuard) Operator & Developer Guide
This document is the single reference for how Borealis reverse tunnels are organized, secured, and orchestrated. It is written for Codex agents extending the feature (new protocols, UI, or policy changes). This document is the reference for Borealis reverse VPN tunnels built on WireGuard. The legacy WebSocket framing and domain-lane tunnel stack has been retired; the system now uses a single outbound WireGuard tunnel per agent with host-only routing and per-device ACLs.
## 1) High-Level Model ## 1) High-Level Model
- Outbound-only: Agents initiate all tunnel sockets to the Engine. No inbound openings on devices. - Outbound-only: agents establish WireGuard tunnels to the Engine; no inbound access on devices.
- Transport: WebSocket-over-TLS carrying a binary frame header (version | msg_type | flags | reserved | channel_id | length) plus payload. - Transport: WireGuard/UDP on port 30000.
- Leases: Engine issues short-lived leases per agent/domain/protocol. Each lease binds a tunnel_id to an ephemeral Engine port and a signed token. - Sessions: one live VPN tunnel per agent; multiple operators share it.
- Domains: Concurrency “lanes” keep protocols isolated: `remote-interactive-shell` (2), `remote-management` (1), `remote-video` (2). Legacy aliases (`ps`, etc.) normalize into these lanes. - Routing: host-only /32 per agent; AllowedIPs restricted to the agent /32 and engine /32; no client-to-client.
- Channels: Logical streams inside a tunnel (channel_id u32). PS uses channel 1; future protocols can open more channels per tunnel as needed. - Idle timeout: 15 minutes of no operator activity; no grace period.
- Tear-down: Idle/grace timeouts plus explicit operator stop. Closing a tunnel must close its protocol channel(s) and kill the agent process for interactive shells. - Keys: WireGuard server keys under `Engine/Certificates/VPN_Server`; client keys under `Agent/Borealis/Certificates/VPN_Client`.
## 2) Engine Components ## 2) Engine Components
- Orchestrator: `Data/Engine/services/WebSocket/Agent/reverse_tunnel_orchestrator.py` - Orchestrator: `Data/Engine/services/VPN/vpn_tunnel_service.py`
- Lease manager: Port pool allocator, domain limit enforcement, idle/grace sweeper. - Allocates per-agent /32, issues short-lived orchestration tokens, enforces single-session.
- Token issuer/validator: Binds agent_id, tunnel_id, domain, protocol, port, expires_at. - Starts/stops WireGuard listener, applies firewall rules, idles out on inactivity.
- Bridge: Maps agent sockets ↔ operator sockets; stores per-tunnel protocol server instances. - Emits Socket.IO events: `vpn_tunnel_start`, `vpn_tunnel_stop`, `vpn_tunnel_activity`.
- Logging: `Engine/Logs/reverse_tunnel.log` plus Device Activity start/stop entries. - WireGuard manager: `Data/Engine/services/VPN/wireguard_server.py`
- Stop path: `stop_tunnel` closes protocol servers, emits `reverse_tunnel_stop` to agents, releases lease/bridge. - Generates server keys, renders config, manages `wireguard.exe` tunnel service, applies ACL rules.
- Protocol registry: Domain/protocol handlers under `Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/`: - PowerShell bridge: `Data/Engine/services/WebSocket/vpn_shell.py`
- `remote_interactive_shell/Protocols/Powershell.py` (live), `Bash.py` (placeholder). - Proxies UI shell input/output to the agents TCP shell server over WireGuard.
- `remote_management/Protocols/SSH.py`, `WinRM.py` (placeholders). - Logging: `Engine/Logs/reverse_tunnel.log` plus Device Activity entries.
- `remote_video/Protocols/VNC.py`, `RDP.py`, `WebRTC.py` (placeholders).
- API Endpoints:
- `POST /api/tunnel/request` → allocates lease, returns {tunnel_id, port, token, idle_seconds, grace_seconds, domain, protocol}.
- `DELETE /api/tunnel/<tunnel_id>` → operator-driven stop; pushes stop to agent and releases the lease.
- Domain default for PowerShell requests is `remote-interactive-shell` (legacy `ps` still accepted).
- Operator Socket.IO namespace `/tunnel`:
- `join`, `send`, `poll`, `ps_open`, `ps_send`, `ps_resize`, `ps_poll`.
- Operator socket disconnect triggers `stop_tunnel` if no other operators remain attached.
- WebUI (current): `Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx` requests PS leases in `remote-interactive-shell`, sends CLOSE frames, and calls DELETE on disconnect/unload.
## 3) Agent Components ## 3) API Endpoints
- Role: `Data/Agent/Roles/role_ReverseTunnel.py` - `POST /api/tunnel/connect` → issues session material (tunnel_id, token, virtual_ip, endpoint, allowed_ports, idle_seconds).
- Validates signed lease tokens; enforces domain limits (2/1/2 with legacy fallbacks). - `GET /api/tunnel/status` → returns up/down status for an agent.
- Outbound TLS WS connect to assigned port; heartbeats + idle/grace watchdog; stop_all closes channels and sends CLOSE. - `GET /api/tunnel/connect/status` → alias for status (used by UI before shell open).
- Protocol registry: loads handlers from `Data/Agent/Roles/Reverse_Tunnels/*/Protocols/*` (PowerShell live; others stubbed to close unsupported channels cleanly). - `DELETE /api/tunnel/disconnect` → immediate teardown (agent + engine cleanup).
- PowerShell channel: `Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py` (pipes-only, no PTY); re-exported under `Reverse_Tunnels/remote_interactive_shell/Protocols/Powershell.py`. - `GET /api/device/vpn_config/<agent_id>` → read per-agent allowed ports.
- Logging: `Agent/Logs/reverse_tunnel.log` with channel/tunnel lifecycle. - `PUT /api/device/vpn_config/<agent_id>` → update allowed ports.
## 4) Framing, Heartbeats, Close ## 4) Agent Components
- Header: version(1) | msg_type(1) | flags(1) | reserved(1) | channel_id(u32 LE) | length(u32 LE). - Tunnel lifecycle: `Data/Agent/Roles/role_WireGuardTunnel.py`
- Messages: CONNECT/ACK, CHANNEL_OPEN/ACK, DATA, CONTROL (resize), WINDOW_UPDATE (reserved), HEARTBEAT (ping/pong), CLOSE. - Validates orchestration tokens, starts/stops WireGuard client service, enforces idle.
- Close codes: ok, idle_timeout, grace_expired, protocol_error, auth_failed, server_shutdown, agent_shutdown, domain_limit, unexpected_disconnect. - Shell server: `Data/Agent/Roles/role_VpnShell.py`
- Heartbeats: Engine → Agent loop; idle/grace sweeper ~15s on Engine; Agent watchdog closes on idle/grace. - TCP PowerShell server bound to `0.0.0.0:47001`, restricted to VPN subnet (10.255.x.x).
- Logging: `Agent/Logs/reverse_tunnel.log`.
## 5) Lifecycle (PowerShell example) ## 5) Security & Auth
1. UI calls `POST /api/tunnel/request` with agent_id, protocol=ps, domain=remote-interactive-shell. - TLS pinned for Engine API/Socket.IO.
2. Engine allocates port/tunnel_id, signs token, starts listener, pushes `reverse_tunnel_start` to agent. - Orchestration tokens signed via Engine Ed25519 key; agent verifies signatures and stores the signing key.
3. Agent dials WS to assigned port, sends CONNECT with token. Engine validates, binds bridge, sends CONNECT_ACK + heartbeat. - WireGuard AllowedIPs /32; no LAN routes; client-to-client blocked.
4. Operator Socket.IO `/tunnel` joins; Engine attaches operator, instantiates PS server, issues CHANNEL_OPEN. - Engine firewall rules enforce per-device allowed ports.
5. Agent launches PowerShell (pipes), streams stdout/stderr as DATA; operator input via `ps_send`; optional resize via `ps_resize` (no-op on agent pipes).
6. On operator Disconnect/tab close, UI sends CLOSE frame and calls DELETE; Engine stop path notifies agent (`reverse_tunnel_stop`), closes channel, releases lease/domain slot.
7. Idle/grace expiry or agent disconnect also triggers close/release; domain slots free immediately.
## 6) Security & Auth ## 6) UI
- TLS: Reuse existing pinned bundle; outbound-only agent sockets. - Device details now include an “Advanced Config” tab for per-device allowed ports.
- Token: short-lived, binds agent_id/tunnel_id/domain/protocol/port/expires_at; optional signature verification (Ed25519 signer when configured). - PowerShell MVP reuses `Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx` with WireGuard APIs + VPN shell events.
- Operator auth: uses existing Engine session/cookie/bearer for `/tunnel` namespace and API endpoints.
## 7) Configuration Knobs (defaults) ## 7) Extending to New Protocols
- Port pool: 3000040000; fixed port optional (context settings). - Add protocol ports to the device allowlist and UI toggles.
- Idle timeout: 3600s; Grace timeout: 3600s. - Reuse the existing VPN tunnel; no new transport/domain lanes required.
- Heartbeat interval: 20s (Engine → Agent).
- Domain limits: remote-interactive-shell=2, remote-management=1, remote-video=2; legacy aliases preserved.
- Log path: `Engine/Logs/reverse_tunnel.log`; `Agent/Logs/reverse_tunnel.log`.
## 8) Logs & Telemetry ## 8) Legacy Removal
- Engine: lease events, socket events, close reasons in `reverse_tunnel.log`; Device Activity start/stop with tunnel_id/operator_id when available. - WebSocket tunnel domains, protocol handlers, and domain limits are removed.
- Agent: role lifecycle, channel start/stop, errors in `reverse_tunnel.log`. - No `/tunnel` Socket.IO namespace or framed protocol messages remain.
## 9) Extending to New Protocols ## 9) Change Log (not exhaustive)
- Add Engine handler under the appropriate domain folder and register in the orchestrators protocol registry. - 2025-11-30: Legacy WebSocket tunnel scaffold introduced (lease manager, framing, tokens).
- Add Agent handler under matching domain folder; update role registry to load it. - 2025-12-06: Legacy PowerShell handler simplified to pipes-only; UI status tweaks.
- Define channel open semantics (metadata), DATA/CONTROL usage, and close behavior. - 2025-12-18: Legacy domain lanes added (`remote-interactive-shell`, `remote-management`, `remote-video`) with limits.
- Update API/UI to allow selecting the protocol/domain and to send protocol-specific controls. - 2025-12-20: WireGuard reverse VPN migration complete; legacy WebSocket tunnels retired; VPN shell bridge + new APIs.
## 10) Outstanding Work
- Implement real handlers for Bash/SSH/WinRM/RDP/VNC/WebRTC and surface in UI.
- Add tests for DELETE stop path, per-domain limits, and browser disconnect cleanup.
- Consider a binary WebSocket browser bridge to replace Socket.IO for high-throughput protocols.
## 11) Risks & Watchpoints
- Eventlet/asyncio coexistence: tunnel loop runs on its own thread/loop; avoid blocking Socket.IO handlers.
- Port exhaustion: handle allocation failures cleanly; always release on stop/idle/grace.
- Buffer growth: add back-pressure before enabling high-throughput protocols.
- Security: strict token binding (agent_id/tunnel_id/domain/protocol/port/expiry) and TLS; reject framing errors.
## 12) Change Log (not exhaustive)
- 2025-11-30: Initial scaffold (lease manager, framing, tokens, API, Agent role, PS handlers).
- 2025-12-06: Simplified PS to pipes-only; improved handler imports; UI status tweaks.
- 2025-12-18: Domain lanes introduced (`remote-interactive-shell`, `remote-management`, `remote-video`) with limits 2/1/2; protocol handlers reorganized under `Reverse_Tunnels/*/Protocols/*`; orchestrator renamed to `reverse_tunnel_orchestrator.py`; explicit stop API/Socket.IO cleanup; WebUI Disconnect/unload calls DELETE + CLOSE for immediate teardown.

View File

@@ -42,8 +42,8 @@ At each milestone: pause, run the listed checks, talk to the operator, and commi
- [x] Do not start any tunnel yet. - [x] Do not start any tunnel yet.
- Linux: do nothing yet (see later section). - Linux: do nothing yet (see later section).
- Checkpoint tests: - Checkpoint tests:
- [x] WireGuard binaries available in agent runtime. - [ ] WireGuard binaries available in agent runtime.
- [x] WireGuard driver installed and visible. - [ ] WireGuard driver installed and visible.
### 2) Engine VPN Server & ACLs — Milestone: Engine VPN Server & ACLs (Windows) ### 2) Engine VPN Server & ACLs — Milestone: Engine VPN Server & ACLs (Windows)
- Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise). - Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise).
@@ -54,15 +54,15 @@ At each milestone: pause, run the listed checks, talk to the operator, and commi
- [x] Do not push DNS or LAN routes; host-only reachability engine IP ↔ agent virtual /32. - [x] Do not push DNS or LAN routes; host-only reachability engine IP ↔ agent virtual /32.
- ACL layer: - ACL layer:
- [x] Default allowlist per agent derived from OS (Windows: RDP 3389, WinRM 5985/5986, PS remoting ports; include VNC/WebRTC defaults as desired). - [x] Default allowlist per agent derived from OS (Windows: RDP 3389, WinRM 5985/5986, PS remoting ports; include VNC/WebRTC defaults as desired).
- [x] Allow operator overrides per agent; enforce at engine firewall layer. (rule plans produced; application wiring pending) - [x] Allow operator overrides per agent; enforce at engine firewall layer.
- Keys/Certs: - Keys/Certs:
- [x] Prefer reusing existing Engine cert infrastructure for signing orchestration tokens. Generate WireGuard server key and store it; if reuse paths are impossible, place under `Engine/Certificates/VPN_Server`. - [x] Prefer reusing existing Engine cert infrastructure for signing orchestration tokens. Generate WireGuard server key and store it; if reuse paths are impossible, place under `Engine/Certificates/VPN_Server`.
- [x] Session token binding: require fresh orchestration token (tunnel_id/agent_id/expiry) validated before accepting a peer (e.g., via pre-shared keys or control-plane validation before adding peer). - [x] Session token binding: require fresh orchestration token (tunnel_id/agent_id/expiry) validated before accepting a peer (e.g., via pre-shared keys or control-plane validation before adding peer).
- Logging: server logs to `Engine/Logs/reverse_tunnel.log` (or renamed consistently). [x] - Logging: server logs to `Engine/Logs/reverse_tunnel.log` (or renamed consistently). [x]
- Checkpoint tests: - Checkpoint tests:
- [x] Engine starts WireGuard listener locally on 30000. - [ ] Engine starts WireGuard listener locally on 30000.
- [x] Only engine IP reachable; client-to-client blocked. - [ ] Only engine IP reachable; client-to-client blocked.
- [x] Peers without valid token/key are rejected. - [ ] Peers without valid token/key are rejected.
### 3) Agent VPN Client & Lifecycle — Milestone: Agent VPN Client & Lifecycle (Windows) ### 3) Agent VPN Client & Lifecycle — Milestone: Agent VPN Client & Lifecycle (Windows)
- Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise). - Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise).
@@ -83,60 +83,64 @@ At each milestone: pause, run the listed checks, talk to the operator, and commi
- [ ] Idle timeout fires at ~15 minutes of inactivity. - [ ] Idle timeout fires at ~15 minutes of inactivity.
### 4) API & Service Orchestration — Milestone: API & Service Orchestration (Windows) ### 4) API & Service Orchestration — Milestone: API & Service Orchestration (Windows)
- Replace legacy tunnel APIs with: - Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise).
- `POST /api/tunnel/connect` → tunnel_id, token, WG client config (keys, endpoint, allowed IPs), virtual IP, idle_seconds (900). - [x] Replace legacy tunnel APIs with:
- `GET /api/tunnel/status` → up/down, virtual IP, connected operators. - [x] `POST /api/tunnel/connect` → tunnel_id, token, WG client config (keys, endpoint, allowed IPs), virtual IP, idle_seconds (900).
- `DELETE /api/tunnel/disconnect` → immediate teardown and lease release. - [x] `GET /api/tunnel/status` → up/down, virtual IP, connected operators.
- Engine orchestrator: - [x] `DELETE /api/tunnel/disconnect` → immediate teardown and lease release.
- Manages single tunnel per agent; tracks tunnel_id, virtual IP, token expiry. - [x] Engine orchestrator:
- Emits start/stop signals to agent (rename events as needed). - [x] Manages single tunnel per agent; tracks tunnel_id, virtual IP, token expiry.
- Cleans peer/routing state on stop. - [x] Emits start/stop signals to agent (rename events as needed).
- Token issuance: short-lived, binds agent_id/tunnel_id/port/expiry; validated before adding peer. - [x] Cleans peer/routing state on stop.
- Remove domain limits; remove channel/protocol handler registry for tunnels. - [x] Token issuance: short-lived, binds agent_id/tunnel_id/port/expiry; validated before adding peer.
- [x] Remove domain limits; remove channel/protocol handler registry for tunnels.
- Checkpoint tests: - Checkpoint tests:
- API happy path: connect → status → disconnect. - [ ] API happy path: connect → status → disconnect.
- Reject stale/second connect for same agent while active. - [ ] Reject stale/second connect for same agent while active.
### 5) UI Advanced Config & Operator Flow (PowerShell MVP) — Milestone: UI Advanced Config & Operator Flow (Windows, PowerShell MVP) ### 5) UI Advanced Config & Operator Flow (PowerShell MVP) — Milestone: UI Advanced Config & Operator Flow (Windows, PowerShell MVP)
- In `Data/Engine/web-interface/src/Devices/Device_Details.jsx`, add “Advanced Config” tab: - Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise).
- “Reverse VPN Tunnel - Allowed Ports” with toggles per protocol. - [x] In `Data/Engine/web-interface/src/Devices/Device_Details.jsx`, add “Advanced Config” tab:
- Defaults by OS (Windows: RDP/WinRM/PS; All: VNC/WebRTC; allow operator overrides). - [x] “Reverse VPN Tunnel - Allowed Ports” with toggles per protocol.
- PowerShell MVP: - [x] Defaults by OS (Windows: RDP/WinRM/PS; All: VNC/WebRTC; allow operator overrides).
- Reuse `Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx` as the base UI. - [x] PowerShell MVP:
- Rewire to new APIs and virtual IP flow. - [x] Reuse `Data/Engine/web-interface/src/Devices/ReverseTunnel/Powershell.jsx` as the base UI.
- Keep live web terminal behavior (WebSocket or equivalent) so operator input streams to remote PowerShell and outputs stream back in real time over the VPN tunnel. - [x] Rewire to new APIs and virtual IP flow.
- Ensure tunnel is up via `/api/tunnel/connect/status` before opening the terminal; call `/api/tunnel/disconnect` on exit/tab close. - [x] Keep live web terminal behavior (WebSocket or equivalent) so operator input streams to remote PowerShell and outputs stream back in real time over the VPN tunnel.
- [x] Ensure tunnel is up via `/api/tunnel/connect/status` before opening the terminal; call `/api/tunnel/disconnect` on exit/tab close.
- Later protocols (RDP/SSH/etc.) can follow once MVP is proven, but do not block on them for this milestone. - Later protocols (RDP/SSH/etc.) can follow once MVP is proven, but do not block on them for this milestone.
- Checkpoint tests: - Checkpoint tests:
- UI can start a tunnel, launch PowerShell terminal, send commands, receive live output, and tear down. - [ ] UI can start a tunnel, launch PowerShell terminal, send commands, receive live output, and tear down.
- Toggles change ACL behavior (engine→agent reachability) as expected. - [ ] Toggles change ACL behavior (engine→agent reachability) as expected.
### 6) Legacy Tunnel Removal & Cleanup — Milestone: Legacy Tunnel Removal & Cleanup (Windows) ### 6) Legacy Tunnel Removal & Cleanup — Milestone: Legacy Tunnel Removal & Cleanup (Windows)
- Remove/retire: - Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise).
- Engine `reverse_tunnel_orchestrator` and domain handlers under `Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/`. - [x] Remove/retire:
- Agent `role_ReverseTunnel.py` and protocol handlers. - [x] Engine `reverse_tunnel_orchestrator` and domain handlers under `Data/Engine/services/WebSocket/Agent/Reverse_Tunnels/`.
- WebUI components tied to the old Socket.IO tunnel namespace. - [x] Agent `role_ReverseTunnel.py` and protocol handlers.
- Update docs and references to point to the new WireGuard VPN flow; keep change log entries. - [x] WebUI components tied to the old Socket.IO tunnel namespace.
- Ensure no lingering domain limits/config knobs remain. - [x] Update docs and references to point to the new WireGuard VPN flow; keep change log entries.
- [x] Ensure no lingering domain limits/config knobs remain.
- Checkpoint tests: - Checkpoint tests:
- Codebase builds/starts without references to legacy tunnel modules. - [ ] Codebase builds/starts without references to legacy tunnel modules.
- UI no longer calls old APIs or Socket.IO tunnel namespace. - [ ] UI no longer calls old APIs or Socket.IO tunnel namespace.
### 7) End-to-End Validation — Milestone: End-to-End Validation (Windows) ### 7) End-to-End Validation — Milestone: End-to-End Validation (Windows)
- Agents editing this document should mark tasks they complete with `[x]` (leave `[ ]` otherwise).
- Functional: - Functional:
- Windows agent: WireGuard connect on port 30000; PowerShell MVP fully live in the web terminal; RDP/WinRM reachable over tunnel as configured. - [ ] Windows agent: WireGuard connect on port 30000; PowerShell MVP fully live in the web terminal; RDP/WinRM reachable over tunnel as configured.
- Idle timeout at 15 minutes; operator disconnect stops tunnel immediately. - [ ] Idle timeout at 15 minutes; operator disconnect stops tunnel immediately.
- Security: - Security:
- Client-to-client blocked. - [ ] Client-to-client blocked.
- Only engine IP reachable; per-agent ACL enforces allowed ports. - [ ] Only engine IP reachable; per-agent ACL enforces allowed ports.
- Token enforcement blocks stale/unauthorized sessions. - [ ] Token enforcement blocks stale/unauthorized sessions.
- Resilience: - Resilience:
- Restart engine: WireGuard server starts; no orphaned routes. - [ ] Restart engine: WireGuard server starts; no orphaned routes.
- Restart agent: adapter persists; tunnel stays down until requested. - [ ] Restart agent: adapter persists; tunnel stays down until requested.
- Logging/audit: - Logging/audit:
- Connect/disconnect/idle/stop reasons recorded in reverse_tunnel.log (Engine/Agent) and Device Activity. - [ ] Connect/disconnect/idle/stop reasons recorded in reverse_tunnel.log (Engine/Agent) and Device Activity.
- Checkpoint tests: - Checkpoint tests:
- Run the above matrix; gather logs for operator review before final commit. - [ ] Run the above matrix; gather logs for operator review before final commit.
## Linux (Deferred) — Do Not Implement Yet ## Linux (Deferred) — Do Not Implement Yet
- When greenlit, mirror the structure above for Linux: - When greenlit, mirror the structure above for Linux: