mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-18 17:55:48 -07:00
Overhaul of VPN Codebase
This commit is contained in:
@@ -1,2 +0,0 @@
|
||||
"""Reverse tunnel protocol modules (placeholder package)."""
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
"""PowerShell channel implementation for reverse tunnel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import subprocess
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
# Message types mirrored from the tunnel framing (kept local to avoid import cycles).
|
||||
MSG_DATA = 0x05
|
||||
MSG_WINDOW_UPDATE = 0x06
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
|
||||
# Close codes (mirrored from engine framing)
|
||||
CLOSE_OK = 0
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class PowershellChannel:
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
self._reader_task = None
|
||||
self._writer_task = None
|
||||
self._stdin_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._proc: Optional[asyncio.subprocess.Process] = None
|
||||
self._exit_code: Optional[int] = None
|
||||
self._frame_cls = getattr(role, "_frame_cls", None)
|
||||
|
||||
# ------------------------------------------------------------------ Helpers
|
||||
def _make_frame(self, msg_type: int, payload: bytes = b"", *, flags: int = 0):
|
||||
frame_cls = self._frame_cls
|
||||
if frame_cls is None:
|
||||
return None
|
||||
try:
|
||||
return frame_cls(msg_type=msg_type, channel_id=self.channel_id, payload=payload or b"", flags=flags)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _send_frame(self, frame) -> None:
|
||||
if frame is None:
|
||||
return
|
||||
await self.role._send_frame(self.tunnel, frame)
|
||||
|
||||
async def _send_close(self, code: int, reason: str) -> None:
|
||||
try:
|
||||
close_frame = getattr(self.role, "close_frame")
|
||||
if callable(close_frame):
|
||||
await self._send_frame(close_frame(self.channel_id, code, reason))
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
frame = self._make_frame(
|
||||
MSG_CLOSE,
|
||||
payload=f'{{"code":{code},"reason":"{reason}"}}'.encode("utf-8"),
|
||||
)
|
||||
await self._send_frame(frame)
|
||||
|
||||
def _powershell_argv(self) -> list:
|
||||
preferred = self.metadata.get("shell") if isinstance(self.metadata, dict) else None
|
||||
shell = preferred.strip() if isinstance(preferred, str) and preferred.strip() else "powershell.exe"
|
||||
# Keep the process alive and read commands from stdin; -Command - tells PS to consume stdin.
|
||||
return [shell, "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"]
|
||||
|
||||
# ------------------------------------------------------------------ Lifecycle
|
||||
async def start(self) -> None:
|
||||
if sys.platform.lower().startswith("win") is False:
|
||||
self.role._log("reverse_tunnel ps start aborted: non-windows platform", error=True)
|
||||
await self._send_close(CLOSE_PROTOCOL_ERROR, "windows_only")
|
||||
return
|
||||
|
||||
argv = self._powershell_argv()
|
||||
self.role._log(f"reverse_tunnel ps start channel={self.channel_id} argv={' '.join(argv)} mode=pipes")
|
||||
|
||||
# Pipes (no PTY).
|
||||
try:
|
||||
self._proc = await asyncio.create_subprocess_exec(
|
||||
*argv,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
|
||||
)
|
||||
except Exception as exc:
|
||||
self.role._log(f"reverse_tunnel ps channel spawn failed argv={' '.join(argv)}: {exc}", error=True)
|
||||
await self._send_close(CLOSE_PROTOCOL_ERROR, "spawn_failed")
|
||||
return
|
||||
|
||||
self._reader_task = self.loop.create_task(self._pump_proc_stdout())
|
||||
self._writer_task = self.loop.create_task(self._pump_proc_stdin())
|
||||
self.role._log(f"reverse_tunnel ps channel started (pipes) argv={' '.join(argv)}")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type == MSG_DATA:
|
||||
if frame.payload:
|
||||
try:
|
||||
self._stdin_queue.put_nowait(frame.payload)
|
||||
except Exception:
|
||||
await self._stdin_queue.put(frame.payload)
|
||||
elif frame.msg_type == MSG_CONTROL:
|
||||
await self._handle_control(frame.payload)
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
elif frame.msg_type == MSG_WINDOW_UPDATE:
|
||||
# Reserved for back-pressure; ignore for now.
|
||||
return
|
||||
|
||||
async def _handle_control(self, payload: bytes) -> None:
|
||||
# No-op for pipe mode; resize is not supported here.
|
||||
return
|
||||
|
||||
# -------------------- Pipe fallback pumps --------------------
|
||||
async def _pump_proc_stdout(self) -> None:
|
||||
try:
|
||||
while self._proc and not self._closed:
|
||||
chunk = await self._proc.stdout.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
frame = self._make_frame(MSG_DATA, payload=bytes(chunk))
|
||||
await self._send_frame(frame)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
self.role._log("reverse_tunnel ps pipe stdout pump error", error=True)
|
||||
finally:
|
||||
if self._proc and not self._closed:
|
||||
try:
|
||||
self._exit_code = await self._proc.wait()
|
||||
except Exception:
|
||||
pass
|
||||
await self.stop(reason="stdout_closed")
|
||||
|
||||
async def _pump_proc_stdin(self) -> None:
|
||||
try:
|
||||
while self._proc and not self._closed:
|
||||
data = await self._stdin_queue.get()
|
||||
if self._closed or not self._proc or not self._proc.stdin:
|
||||
break
|
||||
try:
|
||||
self._proc.stdin.write(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||
await self._proc.stdin.drain()
|
||||
except Exception:
|
||||
self.role._log("reverse_tunnel ps pipe stdin pump error", error=True)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
self.role._log("reverse_tunnel ps pipe stdin pump error", error=True)
|
||||
finally:
|
||||
await self.stop(reason="stdin_closed")
|
||||
|
||||
async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if self._proc is not None:
|
||||
try:
|
||||
self._proc.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
current = asyncio.current_task()
|
||||
if self._reader_task and self._reader_task is not current:
|
||||
try:
|
||||
self._reader_task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
if self._writer_task and self._writer_task is not current:
|
||||
try:
|
||||
self._writer_task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
# Include exit code in the close reason for debugging.
|
||||
exit_suffix = f" (exit={self._exit_code})" if self._exit_code is not None else ""
|
||||
close_reason = (reason or "powershell_exit") + exit_suffix
|
||||
# Always send CLOSE before socket teardown so engine/UI see the reason.
|
||||
try:
|
||||
await self._send_close(code, close_reason)
|
||||
except Exception:
|
||||
self.role._log("reverse_tunnel ps close send failed", error=True)
|
||||
self.role._log(
|
||||
f"reverse_tunnel ps channel stopped channel={self.channel_id} reason={close_reason}"
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Namespace package for reverse tunnel domains (Agent side)."""
|
||||
|
||||
__all__ = ["remote_interactive_shell", "remote_management", "remote_video"]
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Placeholder Bash channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class BashChannel:
|
||||
"""Stub Bash handler that immediately reports unsupported."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
# Until Bash support is implemented, close the channel to free resources.
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="bash_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
# Ignore payloads but acknowledge by stopping the channel to avoid leaks.
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="bash_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "bash_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel bash channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["BashChannel"]
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Expose the PowerShell channel under the domain path, with file-based import fallback."""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
powershell_module = None
|
||||
|
||||
# Attempt package-relative import first
|
||||
try: # pragma: no cover - best effort
|
||||
from ....ReverseTunnel import tunnel_Powershell as powershell_module # type: ignore
|
||||
except Exception:
|
||||
powershell_module = None
|
||||
|
||||
# Fallback: load directly from file path to survive non-package runtimes
|
||||
if powershell_module is None:
|
||||
try:
|
||||
base = Path(__file__).resolve().parents[3] / "ReverseTunnel" / "tunnel_Powershell.py"
|
||||
spec = importlib.util.spec_from_file_location("tunnel_Powershell", base)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
powershell_module = module
|
||||
except Exception:
|
||||
powershell_module = None
|
||||
|
||||
if powershell_module and hasattr(powershell_module, "PowershellChannel"):
|
||||
PowershellChannel = powershell_module.PowershellChannel # type: ignore
|
||||
else: # pragma: no cover - safety guard
|
||||
class PowershellChannel: # type: ignore
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError("PowerShell channel unavailable")
|
||||
|
||||
|
||||
__all__ = ["PowershellChannel"]
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Protocol handlers for interactive shell tunnels (Agent side)."""
|
||||
|
||||
from .Powershell import PowershellChannel
|
||||
from .Bash import BashChannel
|
||||
|
||||
__all__ = ["PowershellChannel", "BashChannel"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Interactive shell domain (PowerShell/Bash) handlers."""
|
||||
|
||||
__all__ = ["tunnel", "Protocols"]
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Placeholder module for remote interactive shell tunnel domain (Agent side)."""
|
||||
|
||||
DOMAIN_NAME = "remote-interactive-shell"
|
||||
|
||||
__all__ = ["DOMAIN_NAME"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Placeholder SSH channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class SSHChannel:
|
||||
"""Stub SSH handler that marks the channel unsupported for now."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="ssh_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="ssh_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "ssh_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel ssh channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["SSHChannel"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Placeholder WinRM channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class WinRMChannel:
|
||||
"""Stub WinRM handler that marks the channel unsupported for now."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="winrm_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="winrm_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "winrm_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel winrm channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["WinRMChannel"]
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Protocol handlers for remote management tunnels (Agent side)."""
|
||||
|
||||
from .SSH import SSHChannel
|
||||
from .WinRM import WinRMChannel
|
||||
|
||||
__all__ = ["SSHChannel", "WinRMChannel"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Remote management domain (SSH/WinRM) handlers."""
|
||||
|
||||
__all__ = ["tunnel", "Protocols"]
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Placeholder module for remote management domain (Agent side)."""
|
||||
|
||||
DOMAIN_NAME = "remote-management"
|
||||
|
||||
__all__ = ["DOMAIN_NAME"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Placeholder RDP channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class RDPChannel:
|
||||
"""Stub RDP handler that marks the channel unsupported for now."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="rdp_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="rdp_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "rdp_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel rdp channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["RDPChannel"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Placeholder VNC channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class VNCChannel:
|
||||
"""Stub VNC handler that marks the channel unsupported for now."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="vnc_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="vnc_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "vnc_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel vnc channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["VNCChannel"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Placeholder WebRTC channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class WebRTCChannel:
|
||||
"""Stub WebRTC handler that marks the channel unsupported for now."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="webrtc_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="webrtc_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "webrtc_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel webrtc channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["WebRTCChannel"]
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Protocol handlers for remote video tunnels (Agent side)."""
|
||||
|
||||
from .WebRTC import WebRTCChannel
|
||||
from .RDP import RDPChannel
|
||||
from .VNC import VNCChannel
|
||||
|
||||
__all__ = ["WebRTCChannel", "RDPChannel", "VNCChannel"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Remote video/desktop domain (RDP/VNC/WebRTC) handlers."""
|
||||
|
||||
__all__ = ["tunnel", "Protocols"]
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Placeholder module for remote video domain (Agent side)."""
|
||||
|
||||
DOMAIN_NAME = "remote-video"
|
||||
|
||||
__all__ = ["DOMAIN_NAME"]
|
||||
@@ -1,939 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
# Capture import errors for protocol handlers so we can report why they are missing.
|
||||
PS_IMPORT_ERROR: Optional[str] = None
|
||||
BASH_IMPORT_ERROR: Optional[str] = None
|
||||
tunnel_SSH = None
|
||||
tunnel_WinRM = None
|
||||
tunnel_VNC = None
|
||||
tunnel_RDP = None
|
||||
tunnel_WebRTC = None
|
||||
tunnel_Powershell = None
|
||||
tunnel_Bash = None
|
||||
|
||||
|
||||
def _load_protocol_module(module_name: str, rel_parts: list[str]) -> tuple[Optional[object], Optional[str]]:
|
||||
"""Load a protocol handler directly from a file path to survive non-package runtimes."""
|
||||
|
||||
base = Path(__file__).parent
|
||||
path = base
|
||||
for part in rel_parts:
|
||||
path = path / part
|
||||
if not path.exists():
|
||||
return None, f"path_missing:{path}"
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(module_name, path)
|
||||
if not spec or not spec.loader:
|
||||
return None, "spec_failed"
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
return module, None
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
return None, repr(exc)
|
||||
|
||||
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_interactive_shell.Protocols import Powershell as tunnel_Powershell # type: ignore
|
||||
except Exception as exc: # pragma: no cover - best-effort logging only
|
||||
PS_IMPORT_ERROR = repr(exc)
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_Powershell",
|
||||
["Reverse_Tunnels", "remote_interactive_shell", "Protocols", "Powershell.py"],
|
||||
)
|
||||
if _module:
|
||||
tunnel_Powershell = _module
|
||||
PS_IMPORT_ERROR = None
|
||||
else:
|
||||
try:
|
||||
from .ReverseTunnel import tunnel_Powershell # type: ignore # legacy fallback
|
||||
PS_IMPORT_ERROR = None
|
||||
except Exception as exc2: # pragma: no cover - diagnostic only
|
||||
PS_IMPORT_ERROR = f"{PS_IMPORT_ERROR} | legacy_fallback={exc2!r} | file_load_failed={_err}"
|
||||
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_interactive_shell.Protocols import Bash as tunnel_Bash # type: ignore
|
||||
except Exception as exc: # pragma: no cover - best-effort logging only
|
||||
BASH_IMPORT_ERROR = repr(exc)
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_Bash",
|
||||
["Reverse_Tunnels", "remote_interactive_shell", "Protocols", "Bash.py"],
|
||||
)
|
||||
if _module:
|
||||
tunnel_Bash = _module
|
||||
BASH_IMPORT_ERROR = None
|
||||
else:
|
||||
BASH_IMPORT_ERROR = f"{BASH_IMPORT_ERROR} | file_load_failed={_err}"
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_management.Protocols import SSH as tunnel_SSH # type: ignore
|
||||
except Exception:
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_SSH",
|
||||
["Reverse_Tunnels", "remote_management", "Protocols", "SSH.py"],
|
||||
)
|
||||
tunnel_SSH = _module
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_management.Protocols import WinRM as tunnel_WinRM # type: ignore
|
||||
except Exception:
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_WinRM",
|
||||
["Reverse_Tunnels", "remote_management", "Protocols", "WinRM.py"],
|
||||
)
|
||||
tunnel_WinRM = _module
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_video.Protocols import VNC as tunnel_VNC # type: ignore
|
||||
except Exception:
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_VNC",
|
||||
["Reverse_Tunnels", "remote_video", "Protocols", "VNC.py"],
|
||||
)
|
||||
tunnel_VNC = _module
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_video.Protocols import RDP as tunnel_RDP # type: ignore
|
||||
except Exception:
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_RDP",
|
||||
["Reverse_Tunnels", "remote_video", "Protocols", "RDP.py"],
|
||||
)
|
||||
tunnel_RDP = _module
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_video.Protocols import WebRTC as tunnel_WebRTC # type: ignore
|
||||
except Exception:
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_WebRTC",
|
||||
["Reverse_Tunnels", "remote_video", "Protocols", "WebRTC.py"],
|
||||
)
|
||||
tunnel_WebRTC = _module
|
||||
|
||||
ROLE_NAME = "reverse_tunnel"
|
||||
ROLE_CONTEXTS = ["interactive", "system"]
|
||||
|
||||
# Message types (keep in sync with Engine service)
|
||||
MSG_CONNECT = 0x01
|
||||
MSG_CONNECT_ACK = 0x02
|
||||
MSG_CHANNEL_OPEN = 0x03
|
||||
MSG_CHANNEL_ACK = 0x04
|
||||
MSG_DATA = 0x05
|
||||
MSG_WINDOW_UPDATE = 0x06
|
||||
MSG_HEARTBEAT = 0x07
|
||||
MSG_CLOSE = 0x08
|
||||
MSG_CONTROL = 0x09
|
||||
|
||||
# Close codes
|
||||
CLOSE_OK = 0
|
||||
CLOSE_IDLE_TIMEOUT = 1
|
||||
CLOSE_GRACE_EXPIRED = 2
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AUTH_FAILED = 4
|
||||
CLOSE_SERVER_SHUTDOWN = 5
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
CLOSE_DOMAIN_LIMIT = 7
|
||||
CLOSE_UNEXPECTED_DISCONNECT = 8
|
||||
|
||||
FRAME_VERSION = 1
|
||||
FRAME_HEADER_STRUCT = struct.Struct("<BBBBII") # version, msg_type, flags, reserved, channel_id, length
|
||||
|
||||
|
||||
@dataclass
|
||||
class TunnelFrame:
|
||||
msg_type: int
|
||||
channel_id: int
|
||||
payload: bytes = field(default_factory=bytes)
|
||||
flags: int = 0
|
||||
version: int = FRAME_VERSION
|
||||
reserved: int = 0
|
||||
|
||||
def encode(self) -> bytes:
|
||||
payload = self.payload or b""
|
||||
header = FRAME_HEADER_STRUCT.pack(
|
||||
self.version,
|
||||
self.msg_type,
|
||||
self.flags,
|
||||
self.reserved,
|
||||
int(self.channel_id),
|
||||
len(payload),
|
||||
)
|
||||
return header + payload
|
||||
|
||||
|
||||
def decode_frame(raw: bytes) -> TunnelFrame:
|
||||
if len(raw) < FRAME_HEADER_STRUCT.size:
|
||||
raise ValueError("frame_too_small")
|
||||
version, msg_type, flags, reserved, channel_id, length = FRAME_HEADER_STRUCT.unpack_from(raw, 0)
|
||||
if version != FRAME_VERSION:
|
||||
raise ValueError(f"unsupported_version:{version}")
|
||||
if length < 0 or len(raw) < FRAME_HEADER_STRUCT.size + length:
|
||||
raise ValueError("invalid_length")
|
||||
payload = raw[FRAME_HEADER_STRUCT.size : FRAME_HEADER_STRUCT.size + length]
|
||||
if len(payload) != length:
|
||||
raise ValueError("length_mismatch")
|
||||
return TunnelFrame(msg_type=msg_type, channel_id=channel_id, payload=payload, flags=flags, version=version, reserved=reserved)
|
||||
|
||||
|
||||
def heartbeat_frame(channel_id: int = 0, *, is_ack: bool = False) -> TunnelFrame:
|
||||
flags = 0x1 if is_ack else 0x0
|
||||
return TunnelFrame(msg_type=MSG_HEARTBEAT, channel_id=channel_id, flags=flags, payload=b"")
|
||||
|
||||
|
||||
def close_frame(channel_id: int, code: int, reason: str = "") -> TunnelFrame:
|
||||
payload = json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8")
|
||||
return TunnelFrame(msg_type=MSG_CLOSE, channel_id=channel_id, payload=payload)
|
||||
|
||||
|
||||
def _norm_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
try:
|
||||
return str(value).strip()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _is_literal_ip(value: str) -> bool:
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
ipaddress.ip_address(value.strip().strip("[]"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTunnel:
|
||||
tunnel_id: str
|
||||
domain: str
|
||||
protocol: str
|
||||
port: int
|
||||
token: str
|
||||
url: str
|
||||
heartbeat_seconds: int
|
||||
idle_seconds: int
|
||||
grace_seconds: int
|
||||
expires_at: Optional[int]
|
||||
signing_key_hint: Optional[str] = None
|
||||
session: Optional[aiohttp.ClientSession] = None
|
||||
websocket: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
tasks: list = field(default_factory=list)
|
||||
send_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
|
||||
channels: Dict[int, Any] = field(default_factory=dict)
|
||||
last_activity: float = field(default_factory=lambda: time.time())
|
||||
connected: bool = False
|
||||
stopping: bool = False
|
||||
stop_reason: Optional[str] = None
|
||||
stop_origin: Optional[str] = None
|
||||
|
||||
|
||||
class BaseChannel:
|
||||
"""Placeholder channel handler; protocol-specific handlers plug in later."""
|
||||
|
||||
def __init__(self, role: "Role", tunnel: ActiveTunnel, channel_id: int, metadata: Optional[dict]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
|
||||
async def start(self) -> None:
|
||||
# Nothing to prime for placeholder channels.
|
||||
return
|
||||
|
||||
async def on_frame(self, frame: TunnelFrame) -> None:
|
||||
# Drop frames until protocol module is provided.
|
||||
return
|
||||
|
||||
async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None:
|
||||
await self.role._send_frame(self.tunnel, close_frame(self.channel_id, code, reason))
|
||||
|
||||
|
||||
class Role:
|
||||
def __init__(self, ctx):
|
||||
self.ctx = ctx
|
||||
self.sio = ctx.sio
|
||||
self.loop = ctx.loop or asyncio.get_event_loop()
|
||||
self.hooks = ctx.hooks or {}
|
||||
self._http_client_factory = self.hooks.get("http_client")
|
||||
self._log_hook = self.hooks.get("log_agent")
|
||||
self._active: Dict[str, ActiveTunnel] = {}
|
||||
self._domain_claims: Dict[str, str] = {}
|
||||
self._domain_limits: Dict[str, Optional[int]] = {
|
||||
"remote-interactive-shell": 2,
|
||||
"remote-management": 1,
|
||||
"remote-video": 2,
|
||||
# Legacy / protocol fallbacks
|
||||
"ps": 2,
|
||||
"rdp": 1,
|
||||
"vnc": 1,
|
||||
"webrtc": 2,
|
||||
"ssh": None,
|
||||
"winrm": None,
|
||||
}
|
||||
self._default_heartbeat = 20
|
||||
self._protocol_handlers: Dict[str, Any] = {}
|
||||
self._frame_cls = TunnelFrame
|
||||
self.close_frame = close_frame
|
||||
try:
|
||||
if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"):
|
||||
self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel
|
||||
module_path = getattr(tunnel_Powershell, "__file__", None)
|
||||
self._log(f"reverse_tunnel ps handler registered (PowershellChannel) module={module_path}")
|
||||
else:
|
||||
hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
|
||||
module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
|
||||
exists_hint = f" exists={module_path.exists()}"
|
||||
self._log(
|
||||
f"reverse_tunnel ps handler NOT registered (missing module/class){hint}{exists_hint}",
|
||||
error=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._log(f"reverse_tunnel ps handler registration failed: {exc}", error=True)
|
||||
try:
|
||||
if tunnel_Bash and hasattr(tunnel_Bash, "BashChannel"):
|
||||
self._protocol_handlers["bash"] = tunnel_Bash.BashChannel
|
||||
module_path = getattr(tunnel_Bash, "__file__", None)
|
||||
self._log(f"reverse_tunnel bash handler registered (BashChannel) module={module_path}")
|
||||
elif BASH_IMPORT_ERROR:
|
||||
self._log(f"reverse_tunnel bash handler NOT registered (missing module/class) import_error={BASH_IMPORT_ERROR}", error=True)
|
||||
except Exception as exc:
|
||||
self._log(f"reverse_tunnel bash handler registration failed: {exc}", error=True)
|
||||
try:
|
||||
if tunnel_SSH and hasattr(tunnel_SSH, "SSHChannel"):
|
||||
self._protocol_handlers["ssh"] = tunnel_SSH.SSHChannel
|
||||
if tunnel_WinRM and hasattr(tunnel_WinRM, "WinRMChannel"):
|
||||
self._protocol_handlers["winrm"] = tunnel_WinRM.WinRMChannel
|
||||
if tunnel_VNC and hasattr(tunnel_VNC, "VNCChannel"):
|
||||
self._protocol_handlers["vnc"] = tunnel_VNC.VNCChannel
|
||||
if tunnel_RDP and hasattr(tunnel_RDP, "RDPChannel"):
|
||||
self._protocol_handlers["rdp"] = tunnel_RDP.RDPChannel
|
||||
if tunnel_WebRTC and hasattr(tunnel_WebRTC, "WebRTCChannel"):
|
||||
self._protocol_handlers["webrtc"] = tunnel_WebRTC.WebRTCChannel
|
||||
except Exception as exc:
|
||||
self._log(f"reverse_tunnel protocol handler registration failed: {exc}", error=True)
|
||||
|
||||
# ------------------------------------------------------------------ Logging
|
||||
def _log(self, message: str, *, error: bool = False) -> None:
|
||||
fname = "reverse_tunnel.log"
|
||||
try:
|
||||
if callable(self._log_hook):
|
||||
self._log_hook(message, fname=fname)
|
||||
if error:
|
||||
self._log_hook(message, fname="agent.error.log")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------ Event wiring
|
||||
def register_events(self):
|
||||
@self.sio.on("reverse_tunnel_start")
|
||||
async def _reverse_tunnel_start(payload):
|
||||
await self._handle_tunnel_start(payload)
|
||||
|
||||
@self.sio.on("reverse_tunnel_stop")
|
||||
async def _reverse_tunnel_stop(payload):
|
||||
tid = ""
|
||||
if isinstance(payload, dict):
|
||||
tid = _norm_text(payload.get("tunnel_id"))
|
||||
await self._stop_tunnel(tid, code=CLOSE_AGENT_SHUTDOWN, reason="server_stop")
|
||||
|
||||
# ------------------------------------------------------------------ Token helpers
|
||||
def _decode_token_payload(self, token: str, *, signing_key_hint: Optional[str] = None) -> Dict[str, Any]:
|
||||
if not token:
|
||||
raise ValueError("token_missing")
|
||||
|
||||
def _b64decode(segment: str) -> bytes:
|
||||
padding = "=" * (-len(segment) % 4)
|
||||
return base64.urlsafe_b64decode(segment + padding)
|
||||
|
||||
parts = token.split(".")
|
||||
payload_segment = parts[0]
|
||||
payload_bytes = _b64decode(payload_segment)
|
||||
try:
|
||||
payload = json.loads(payload_bytes.decode("utf-8"))
|
||||
except Exception as exc:
|
||||
raise ValueError("token_decode_error") from exc
|
||||
|
||||
if len(parts) == 2:
|
||||
candidates = []
|
||||
hint = _norm_text(signing_key_hint)
|
||||
if hint:
|
||||
candidates.append(hint)
|
||||
client = self._http_client()
|
||||
if client and hasattr(client, "load_server_signing_key"):
|
||||
try:
|
||||
stored = client.load_server_signing_key()
|
||||
except Exception:
|
||||
stored = None
|
||||
if isinstance(stored, str) and stored.strip():
|
||||
candidates.append(stored.strip())
|
||||
|
||||
signature = _b64decode(parts[1])
|
||||
verified = False
|
||||
for candidate in candidates:
|
||||
try:
|
||||
key_bytes = base64.b64decode(candidate, validate=True)
|
||||
public_key = serialization.load_der_public_key(key_bytes)
|
||||
except Exception:
|
||||
continue
|
||||
if not isinstance(public_key, ed25519.Ed25519PublicKey):
|
||||
continue
|
||||
try:
|
||||
public_key.verify(signature, payload_bytes)
|
||||
verified = True
|
||||
if client and hasattr(client, "store_server_signing_key"):
|
||||
try:
|
||||
client.store_server_signing_key(candidate)
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if not verified:
|
||||
raise ValueError("token_signature_invalid")
|
||||
return payload
|
||||
|
||||
def _validate_token(self, token: str, *, expected_agent: str, expected_domain: str, expected_protocol: str, expected_tunnel: str, signing_key_hint: Optional[str]) -> Dict[str, Any]:
|
||||
payload = self._decode_token_payload(token, signing_key_hint=signing_key_hint)
|
||||
|
||||
def _matches(expected: str, actual: Any) -> bool:
|
||||
return _norm_text(expected).lower() == _norm_text(actual).lower()
|
||||
|
||||
if expected_agent and not _matches(expected_agent, payload.get("agent_id")):
|
||||
raise ValueError("token_agent_mismatch")
|
||||
if expected_tunnel and not _matches(expected_tunnel, payload.get("tunnel_id")):
|
||||
raise ValueError("token_id_mismatch")
|
||||
if expected_domain and not _matches(expected_domain, payload.get("domain")):
|
||||
raise ValueError("token_domain_mismatch")
|
||||
if expected_protocol and not _matches(expected_protocol, payload.get("protocol")):
|
||||
raise ValueError("token_protocol_mismatch")
|
||||
|
||||
expires_at = payload.get("expires_at")
|
||||
try:
|
||||
expires_ts = int(expires_at) if expires_at is not None else None
|
||||
except Exception:
|
||||
expires_ts = None
|
||||
if expires_ts is not None and expires_ts < int(time.time()):
|
||||
raise ValueError("token_expired")
|
||||
return payload
|
||||
|
||||
# ------------------------------------------------------------------ Utility
|
||||
def _http_client(self):
|
||||
try:
|
||||
if callable(self._http_client_factory):
|
||||
return self._http_client_factory()
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _port_allowed(self, port: int) -> bool:
|
||||
return isinstance(port, int) and 1 <= port <= 65535
|
||||
|
||||
def _domain_allowed(self, domain: str) -> bool:
|
||||
limit = self._domain_limits.get(domain)
|
||||
if limit is None:
|
||||
return True
|
||||
active = [tid for tid, t in self._active.items() if t.domain == domain and not t.stopping]
|
||||
pending = [tid for dom, tid in self._domain_claims.items() if dom == domain and tid not in active]
|
||||
count = len(set(active + pending))
|
||||
return count < limit
|
||||
|
||||
def _build_ws_url(self, port: int) -> str:
|
||||
client = self._http_client()
|
||||
if client:
|
||||
try:
|
||||
client.refresh_base_url()
|
||||
except Exception:
|
||||
pass
|
||||
base_url = getattr(client, "base_url", "") or "https://localhost:5000"
|
||||
else:
|
||||
base_url = "https://localhost:5000"
|
||||
parsed = urlparse(base_url)
|
||||
host = parsed.hostname or "localhost"
|
||||
scheme = "wss" if (parsed.scheme or "").lower() == "https" else "ws"
|
||||
return f"{scheme}://{host}:{port}/"
|
||||
|
||||
def _ssl_context(self, host: str):
|
||||
client = self._http_client()
|
||||
verify = getattr(getattr(client, "session", None), "verify", True)
|
||||
if verify is False:
|
||||
return False
|
||||
if isinstance(verify, str) and os.path.isfile(verify) and client and hasattr(client, "key_store"):
|
||||
try:
|
||||
ctx = client.key_store.build_ssl_context()
|
||||
if ctx and _is_literal_ip(host):
|
||||
ctx.check_hostname = False
|
||||
return ctx
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def _emit_status(self, payload: Dict[str, Any]) -> None:
|
||||
try:
|
||||
await self.sio.emit("reverse_tunnel_status", payload)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _mark_activity(self, tunnel: ActiveTunnel) -> None:
|
||||
tunnel.last_activity = time.time()
|
||||
|
||||
# ------------------------------------------------------------------ Event handlers
|
||||
async def _handle_tunnel_start(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
self._log("reverse_tunnel_start ignored: payload not a dict", error=True)
|
||||
return
|
||||
tunnel_id = _norm_text(payload.get("tunnel_id"))
|
||||
token = _norm_text(payload.get("token"))
|
||||
port = payload.get("port") or payload.get("assigned_port")
|
||||
protocol = _norm_text(payload.get("protocol") or "ps").lower() or "ps"
|
||||
domain = _norm_text(payload.get("domain") or protocol).lower() or protocol
|
||||
heartbeat_seconds = int(payload.get("heartbeat_seconds") or self._default_heartbeat or 20)
|
||||
idle_seconds = int(payload.get("idle_seconds") or 3600)
|
||||
grace_seconds = int(payload.get("grace_seconds") or 3600)
|
||||
signing_key_hint = _norm_text(payload.get("signing_key"))
|
||||
agent_hint = _norm_text(payload.get("agent_id"))
|
||||
|
||||
# Ignore broadcasts targeting other agents (Socket.IO fanout sends to both contexts).
|
||||
if agent_hint and agent_hint.lower() != _norm_text(self.ctx.agent_id).lower():
|
||||
return
|
||||
|
||||
if not token:
|
||||
self._log("reverse_tunnel_start rejected: missing token", error=True)
|
||||
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_missing"})
|
||||
return
|
||||
|
||||
if not tunnel_id:
|
||||
tunnel_id = _norm_text(payload.get("lease_id")) # fallback alias
|
||||
|
||||
try:
|
||||
claims = self._validate_token(
|
||||
token,
|
||||
expected_agent=self.ctx.agent_id,
|
||||
expected_domain=domain,
|
||||
expected_protocol=protocol,
|
||||
expected_tunnel=tunnel_id,
|
||||
signing_key_hint=signing_key_hint,
|
||||
)
|
||||
except Exception as exc:
|
||||
if str(exc) == "token_agent_mismatch":
|
||||
# Broadcast hit the wrong agent context; ignore quietly.
|
||||
return
|
||||
self._log(f"reverse_tunnel_start rejected: token validation failed ({exc})", error=True)
|
||||
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_invalid"})
|
||||
return
|
||||
|
||||
tunnel_id = _norm_text(claims.get("tunnel_id") or tunnel_id)
|
||||
if not tunnel_id:
|
||||
self._log("reverse_tunnel_start rejected: tunnel_id missing after token parse", error=True)
|
||||
return
|
||||
|
||||
try:
|
||||
port = int(port or claims.get("assigned_port") or 0)
|
||||
except Exception:
|
||||
port = 0
|
||||
if not self._port_allowed(port):
|
||||
self._log(f"reverse_tunnel_start rejected: invalid port {port}", error=True)
|
||||
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "invalid_port"})
|
||||
return
|
||||
|
||||
domain = _norm_text(claims.get("domain") or domain).lower()
|
||||
protocol = _norm_text(claims.get("protocol") or protocol).lower()
|
||||
expires_at = claims.get("expires_at")
|
||||
|
||||
if tunnel_id in self._active:
|
||||
self._log(f"reverse_tunnel_start ignored: tunnel already active tunnel_id={tunnel_id}")
|
||||
return
|
||||
if not self._domain_allowed(domain):
|
||||
self._log(f"reverse_tunnel_start rejected: domain limit for {domain}", error=True)
|
||||
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "domain_limit"})
|
||||
return
|
||||
|
||||
url = self._build_ws_url(port)
|
||||
parsed = urlparse(url)
|
||||
heartbeat_seconds = max(5, min(heartbeat_seconds, 120))
|
||||
idle_seconds = max(30, idle_seconds)
|
||||
grace_seconds = max(60, grace_seconds)
|
||||
|
||||
tunnel = ActiveTunnel(
|
||||
tunnel_id=tunnel_id,
|
||||
domain=domain,
|
||||
protocol=protocol,
|
||||
port=port,
|
||||
token=token,
|
||||
url=url,
|
||||
heartbeat_seconds=heartbeat_seconds,
|
||||
idle_seconds=idle_seconds,
|
||||
grace_seconds=grace_seconds,
|
||||
expires_at=int(expires_at) if expires_at is not None else None,
|
||||
signing_key_hint=signing_key_hint or None,
|
||||
)
|
||||
|
||||
self._active[tunnel_id] = tunnel
|
||||
self._domain_claims[domain] = tunnel_id
|
||||
self._log(f"reverse_tunnel_start accepted tunnel_id={tunnel_id} domain={domain} protocol={protocol} url={url}")
|
||||
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "connecting", "url": url})
|
||||
task = self.loop.create_task(self._run_tunnel(tunnel, host=parsed.hostname or "localhost"))
|
||||
tunnel.tasks.append(task)
|
||||
|
||||
# ------------------------------------------------------------------ Core tunnel handling
|
||||
async def _run_tunnel(self, tunnel: ActiveTunnel, *, host: str) -> None:
|
||||
ssl_ctx = self._ssl_context(host)
|
||||
timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None)
|
||||
self._log(
|
||||
f"reverse_tunnel dialing ws url={tunnel.url} tunnel_id={tunnel.tunnel_id} "
|
||||
f"agent_id={self.ctx.agent_id} ssl={'yes' if ssl_ctx else 'no'}"
|
||||
)
|
||||
try:
|
||||
tunnel.session = aiohttp.ClientSession(timeout=timeout)
|
||||
tunnel.websocket = await tunnel.session.ws_connect(
|
||||
tunnel.url,
|
||||
ssl=ssl_ctx,
|
||||
heartbeat=None,
|
||||
max_msg_size=0,
|
||||
timeout=timeout,
|
||||
)
|
||||
self._mark_activity(tunnel)
|
||||
self._log(f"reverse_tunnel connected ws tunnel_id={tunnel.tunnel_id} peer={getattr(tunnel.websocket, 'remote', None)}")
|
||||
await tunnel.websocket.send_bytes(
|
||||
TunnelFrame(
|
||||
msg_type=MSG_CONNECT,
|
||||
channel_id=0,
|
||||
payload=json.dumps(
|
||||
{
|
||||
"agent_id": self.ctx.agent_id,
|
||||
"tunnel_id": tunnel.tunnel_id,
|
||||
"token": tunnel.token,
|
||||
"protocol": tunnel.protocol,
|
||||
"domain": tunnel.domain,
|
||||
"version": FRAME_VERSION,
|
||||
},
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8"),
|
||||
).encode()
|
||||
)
|
||||
self._log(f"reverse_tunnel CONNECT sent tunnel_id={tunnel.tunnel_id}")
|
||||
|
||||
sender = self.loop.create_task(self._pump_sender(tunnel))
|
||||
receiver = self.loop.create_task(self._pump_receiver(tunnel))
|
||||
heartbeats = self.loop.create_task(self._heartbeat_loop(tunnel))
|
||||
watchdog = self.loop.create_task(self._watchdog(tunnel))
|
||||
tunnel.tasks.extend([sender, receiver, heartbeats, watchdog])
|
||||
task_labels = {
|
||||
sender: "sender",
|
||||
receiver: "receiver",
|
||||
heartbeats: "heartbeat",
|
||||
watchdog: "watchdog",
|
||||
}
|
||||
done, pending = await asyncio.wait(task_labels.keys(), return_when=asyncio.FIRST_COMPLETED)
|
||||
for finished in done:
|
||||
label = task_labels.get(finished) or "unknown"
|
||||
exc_text = ""
|
||||
try:
|
||||
exc_obj = finished.exception()
|
||||
except asyncio.CancelledError:
|
||||
exc_obj = None
|
||||
exc_text = " (cancelled)"
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
exc_obj = exc
|
||||
if exc_obj:
|
||||
exc_text = f" (exc={exc_obj!r})"
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = f"{label}_stopped{exc_text}"
|
||||
if not tunnel.stop_origin:
|
||||
tunnel.stop_origin = label
|
||||
self._log(
|
||||
f"reverse_tunnel task completed tunnel_id={tunnel.tunnel_id} task={label} stop_reason={tunnel.stop_reason}{exc_text}"
|
||||
)
|
||||
if pending:
|
||||
try:
|
||||
self._log(
|
||||
"reverse_tunnel pending tasks after first completion tunnel_id=%s pending=%s",
|
||||
# Represent pending tasks by label for debugging.
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
self._log(f"reverse_tunnel connection failed tunnel_id={tunnel.tunnel_id}: {exc}", error=True)
|
||||
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "connect_failed"})
|
||||
finally:
|
||||
try:
|
||||
ws = tunnel.websocket
|
||||
self._log(
|
||||
f"reverse_tunnel ws closing tunnel_id={tunnel.tunnel_id} "
|
||||
f"code={getattr(ws, 'close_code', None)} reason={getattr(ws, 'close_reason', None)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
await self._shutdown_tunnel(tunnel)
|
||||
|
||||
async def _pump_sender(self, tunnel: ActiveTunnel) -> None:
|
||||
try:
|
||||
while tunnel.websocket and not tunnel.websocket.closed:
|
||||
frame: TunnelFrame = await tunnel.send_queue.get()
|
||||
try:
|
||||
await tunnel.websocket.send_bytes(frame.encode())
|
||||
self._mark_activity(tunnel)
|
||||
self._log(
|
||||
f"reverse_tunnel send frame tunnel_id={tunnel.tunnel_id} "
|
||||
f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
|
||||
)
|
||||
except Exception:
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = "sender_error"
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = "sender_failed"
|
||||
self._log(f"reverse_tunnel sender failed tunnel_id={tunnel.tunnel_id}", error=True)
|
||||
finally:
|
||||
self._log(f"reverse_tunnel sender stopped tunnel_id={tunnel.tunnel_id}")
|
||||
|
||||
async def _pump_receiver(self, tunnel: ActiveTunnel) -> None:
|
||||
ws = tunnel.websocket
|
||||
if ws is None:
|
||||
return
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
try:
|
||||
frame = decode_frame(msg.data)
|
||||
except Exception:
|
||||
self._log(f"reverse_tunnel frame decode failed tunnel_id={tunnel.tunnel_id}", error=True)
|
||||
continue
|
||||
self._mark_activity(tunnel)
|
||||
await self._handle_frame(tunnel, frame)
|
||||
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE):
|
||||
self._log(
|
||||
f"reverse_tunnel websocket closed tunnel_id={tunnel.tunnel_id} "
|
||||
f"code={ws.close_code} reason={ws.close_reason}"
|
||||
)
|
||||
tunnel.stop_reason = ws.close_reason or "ws_closed"
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = "receiver_failed"
|
||||
self._log(f"reverse_tunnel receiver failed tunnel_id={tunnel.tunnel_id}", error=True)
|
||||
finally:
|
||||
self._log(f"reverse_tunnel receiver stopped tunnel_id={tunnel.tunnel_id}")
|
||||
# If no stop_reason was set, emit a CLOSE so engine/UI see a reason.
|
||||
if not tunnel.stop_reason:
|
||||
try:
|
||||
await self._send_frame(tunnel, close_frame(0, CLOSE_UNEXPECTED_DISCONNECT, "receiver_stop"))
|
||||
tunnel.stop_reason = "receiver_stop"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _heartbeat_loop(self, tunnel: ActiveTunnel) -> None:
|
||||
try:
|
||||
while tunnel.websocket and not tunnel.websocket.closed:
|
||||
await asyncio.sleep(tunnel.heartbeat_seconds)
|
||||
await self._send_frame(tunnel, heartbeat_frame())
|
||||
self._log(f"reverse_tunnel heartbeat sent tunnel_id={tunnel.tunnel_id}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = "heartbeat_failed"
|
||||
self._log(f"reverse_tunnel heartbeat failed tunnel_id={tunnel.tunnel_id}", error=True)
|
||||
finally:
|
||||
self._log(f"reverse_tunnel heartbeat loop stopped tunnel_id={tunnel.tunnel_id}")
|
||||
|
||||
async def _watchdog(self, tunnel: ActiveTunnel) -> None:
|
||||
try:
|
||||
while tunnel.websocket and not tunnel.websocket.closed:
|
||||
await asyncio.sleep(10)
|
||||
now = time.time()
|
||||
if tunnel.idle_seconds and (now - tunnel.last_activity) >= tunnel.idle_seconds:
|
||||
await self._send_frame(tunnel, close_frame(0, CLOSE_IDLE_TIMEOUT, "idle_timeout"))
|
||||
tunnel.stop_reason = tunnel.stop_reason or "idle_timeout"
|
||||
self._log(f"reverse_tunnel watchdog idle_timeout tunnel_id={tunnel.tunnel_id}")
|
||||
break
|
||||
if tunnel.expires_at and (now - tunnel.expires_at) >= tunnel.grace_seconds:
|
||||
await self._send_frame(tunnel, close_frame(0, CLOSE_GRACE_EXPIRED, "grace_expired"))
|
||||
tunnel.stop_reason = tunnel.stop_reason or "grace_expired"
|
||||
self._log(f"reverse_tunnel watchdog grace_expired tunnel_id={tunnel.tunnel_id}")
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = "watchdog_failed"
|
||||
self._log(f"reverse_tunnel watchdog failed tunnel_id={tunnel.tunnel_id}", error=True)
|
||||
finally:
|
||||
self._log(f"reverse_tunnel watchdog stopped tunnel_id={tunnel.tunnel_id}")
|
||||
|
||||
async def _handle_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
|
||||
self._log(
|
||||
f"reverse_tunnel recv frame tunnel_id={tunnel.tunnel_id} "
|
||||
f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
|
||||
)
|
||||
if frame.msg_type == MSG_HEARTBEAT:
|
||||
if frame.flags & 0x1:
|
||||
self._log(f"reverse_tunnel heartbeat ack tunnel_id={tunnel.tunnel_id}")
|
||||
return
|
||||
await self._send_frame(tunnel, heartbeat_frame(channel_id=frame.channel_id, is_ack=True))
|
||||
return
|
||||
if frame.msg_type == MSG_CONNECT_ACK:
|
||||
tunnel.connected = True
|
||||
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "connected"})
|
||||
self._log(f"reverse_tunnel CONNECT_ACK tunnel_id={tunnel.tunnel_id}")
|
||||
return
|
||||
if frame.msg_type == MSG_CHANNEL_OPEN:
|
||||
await self._handle_channel_open(tunnel, frame)
|
||||
return
|
||||
if frame.msg_type == MSG_CLOSE:
|
||||
try:
|
||||
reason = json.loads(frame.payload.decode("utf-8"))
|
||||
except Exception:
|
||||
reason = {"code": CLOSE_UNEXPECTED_DISCONNECT, "reason": "close"}
|
||||
tunnel.stop_reason = reason.get("reason") if isinstance(reason, dict) else None
|
||||
await self._shutdown_tunnel(tunnel, send_close=False)
|
||||
return
|
||||
|
||||
handler = tunnel.channels.get(frame.channel_id)
|
||||
if handler:
|
||||
try:
|
||||
await handler.on_frame(frame)
|
||||
except Exception:
|
||||
self._log(f"reverse_tunnel channel handler failed tunnel_id={tunnel.tunnel_id} channel={frame.channel_id}", error=True)
|
||||
else:
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL, MSG_WINDOW_UPDATE):
|
||||
await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "unknown_channel"))
|
||||
|
||||
async def _handle_channel_open(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
|
||||
try:
|
||||
payload = json.loads(frame.payload.decode("utf-8"))
|
||||
except Exception:
|
||||
payload = {}
|
||||
protocol = _norm_text(payload.get("protocol") or tunnel.protocol)
|
||||
metadata = payload.get("metadata") if isinstance(payload, dict) else {}
|
||||
|
||||
if frame.channel_id in tunnel.channels:
|
||||
await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "channel_exists"))
|
||||
return
|
||||
|
||||
if protocol.lower() == "ps" and "ps" not in self._protocol_handlers:
|
||||
hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
|
||||
module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
|
||||
exists_hint = f" exists={module_path.exists()}"
|
||||
self._log(
|
||||
f"reverse_tunnel ps handler missing; falling back to BaseChannel{hint}{exists_hint}",
|
||||
error=True,
|
||||
)
|
||||
|
||||
handler_cls = self._protocol_handlers.get(protocol.lower()) or BaseChannel
|
||||
try:
|
||||
handler = handler_cls(self, tunnel, frame.channel_id, metadata)
|
||||
except Exception:
|
||||
self._log(f"reverse_tunnel channel handler fallback to BaseChannel protocol={protocol}", error=True)
|
||||
handler = BaseChannel(self, tunnel, frame.channel_id, metadata)
|
||||
tunnel.channels[frame.channel_id] = handler
|
||||
await handler.start()
|
||||
await self._send_frame(
|
||||
tunnel,
|
||||
TunnelFrame(
|
||||
msg_type=MSG_CHANNEL_ACK,
|
||||
channel_id=frame.channel_id,
|
||||
payload=json.dumps({"status": "ok", "protocol": protocol}, separators=(",", ":")).encode("utf-8"),
|
||||
),
|
||||
)
|
||||
self._log(
|
||||
f"reverse_tunnel channel_opened tunnel_id={tunnel.tunnel_id} channel={frame.channel_id} "
|
||||
f"protocol={protocol} handler={handler.__class__.__name__} metadata={metadata}"
|
||||
)
|
||||
|
||||
async def _send_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
|
||||
if tunnel.stopping and getattr(frame, "msg_type", None) != MSG_CLOSE:
|
||||
return
|
||||
try:
|
||||
tunnel.send_queue.put_nowait(frame)
|
||||
except Exception:
|
||||
await tunnel.send_queue.put(frame)
|
||||
|
||||
async def _stop_tunnel(self, tunnel_id: str, *, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "requested") -> None:
|
||||
tunnel = self._active.get(tunnel_id)
|
||||
if not tunnel:
|
||||
return
|
||||
if not tunnel.stop_origin:
|
||||
tunnel.stop_origin = "stop_tunnel"
|
||||
self._log(f"reverse_tunnel stop_tunnel requested tunnel_id={tunnel_id} code={code} reason={reason}")
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = reason or "requested"
|
||||
await self._send_frame(tunnel, close_frame(0, code, reason))
|
||||
await self._shutdown_tunnel(tunnel, send_close=False)
|
||||
|
||||
async def _shutdown_tunnel(self, tunnel: ActiveTunnel, *, send_close: bool = True) -> None:
|
||||
if tunnel.stopping:
|
||||
return
|
||||
tunnel.stopping = True
|
||||
reason_text = tunnel.stop_reason or "closed"
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = reason_text
|
||||
if not tunnel.stop_origin:
|
||||
tunnel.stop_origin = "shutdown"
|
||||
self._log(
|
||||
f"reverse_tunnel shutdown start tunnel_id={tunnel.tunnel_id} stop_reason={tunnel.stop_reason} "
|
||||
f"stop_origin={tunnel.stop_origin} ws_closed={getattr(tunnel.websocket, 'closed', None)}"
|
||||
)
|
||||
# Stop all channels first so CLOSE frames (with reasons) are sent upstream.
|
||||
for handler in list(tunnel.channels.values()):
|
||||
try:
|
||||
await handler.stop(code=CLOSE_UNEXPECTED_DISCONNECT, reason=reason_text or "tunnel_shutdown")
|
||||
except Exception:
|
||||
pass
|
||||
if send_close:
|
||||
close_payload = close_frame(0, CLOSE_AGENT_SHUTDOWN, reason_text or "agent_shutdown")
|
||||
try:
|
||||
await self._send_frame(tunnel, close_payload)
|
||||
# Give the sender loop a brief window to flush the CLOSE upstream.
|
||||
await asyncio.sleep(0.05)
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback: if sender task died, try sending directly on the websocket.
|
||||
try:
|
||||
if tunnel.websocket and not tunnel.websocket.closed:
|
||||
await tunnel.websocket.send_bytes(close_payload.encode())
|
||||
except Exception:
|
||||
pass
|
||||
for task in list(tunnel.tasks):
|
||||
try:
|
||||
task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
if tunnel.websocket is not None:
|
||||
try:
|
||||
message = (reason_text or "agent_shutdown").encode("utf-8", "ignore")[:120]
|
||||
await tunnel.websocket.close(message=message)
|
||||
except Exception:
|
||||
pass
|
||||
if tunnel.session is not None:
|
||||
try:
|
||||
await tunnel.session.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._active.pop(tunnel.tunnel_id, None)
|
||||
if tunnel.domain in self._domain_claims and self._domain_claims.get(tunnel.domain) == tunnel.tunnel_id:
|
||||
self._domain_claims.pop(tunnel.domain, None)
|
||||
self._log(f"reverse_tunnel stopped tunnel_id={tunnel.tunnel_id} reason={tunnel.stop_reason or 'closed'}")
|
||||
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "closed", "reason": tunnel.stop_reason or "closed"})
|
||||
|
||||
# ------------------------------------------------------------------ Lifecycle
|
||||
def stop_all(self):
|
||||
for tunnel_id in list(self._active.keys()):
|
||||
try:
|
||||
self.loop.create_task(self._stop_tunnel(tunnel_id, code=CLOSE_AGENT_SHUTDOWN, reason="agent_shutdown"))
|
||||
except Exception:
|
||||
pass
|
||||
167
Data/Agent/Roles/role_VpnShell.py
Normal file
167
Data/Agent/Roles/role_VpnShell.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# ======================================================
|
||||
# Data\Agent\Roles\role_VpnShell.py
|
||||
# Description: PowerShell TCP server for VPN shell access (Engine connects over WireGuard /32).
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""VPN PowerShell server for the WireGuard tunnel."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import socket
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
import os
|
||||
|
||||
ROLE_NAME = "VpnShell"
|
||||
ROLE_CONTEXTS = ["system"]
|
||||
|
||||
|
||||
def _log_path() -> Path:
|
||||
root = Path(__file__).resolve().parents[2] / "Logs"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root / "reverse_tunnel.log"
|
||||
|
||||
|
||||
def _write_log(message: str) -> None:
|
||||
ts = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
|
||||
try:
|
||||
_log_path().open("a", encoding="utf-8").write(f"[{ts}] [vpn-shell] {message}\n")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _b64encode(data: bytes) -> str:
|
||||
return base64.b64encode(data).decode("ascii").strip()
|
||||
|
||||
|
||||
def _b64decode(value: str) -> bytes:
|
||||
return base64.b64decode(value.encode("ascii"))
|
||||
|
||||
def _resolve_shell_port() -> int:
|
||||
raw = os.environ.get("BOREALIS_WIREGUARD_SHELL_PORT")
|
||||
try:
|
||||
value = int(raw) if raw is not None else 47001
|
||||
except Exception:
|
||||
value = 47001
|
||||
if value < 1 or value > 65535:
|
||||
return 47001
|
||||
return value
|
||||
|
||||
|
||||
class ShellSession:
|
||||
def __init__(self, conn: socket.socket, address: tuple[str, int]) -> None:
|
||||
self.conn = conn
|
||||
self.address = address
|
||||
self.proc: Optional[subprocess.Popen] = None
|
||||
self._stop = threading.Event()
|
||||
|
||||
def start(self) -> None:
|
||||
self.proc = subprocess.Popen(
|
||||
["powershell.exe", "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
|
||||
bufsize=0,
|
||||
)
|
||||
threading.Thread(target=self._reader_loop, daemon=True).start()
|
||||
self._writer_loop()
|
||||
|
||||
def _reader_loop(self) -> None:
|
||||
if not self.proc or not self.proc.stdout:
|
||||
return
|
||||
try:
|
||||
while not self._stop.is_set():
|
||||
chunk = self.proc.stdout.readline()
|
||||
if not chunk:
|
||||
break
|
||||
payload = json.dumps({"type": "stdout", "data": _b64encode(chunk)})
|
||||
self.conn.sendall(payload.encode("utf-8") + b"\n")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _writer_loop(self) -> None:
|
||||
buffer = b""
|
||||
try:
|
||||
while not self._stop.is_set():
|
||||
data = self.conn.recv(4096)
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(line.decode("utf-8"))
|
||||
except Exception:
|
||||
continue
|
||||
if msg.get("type") == "stdin":
|
||||
payload = msg.get("data") or ""
|
||||
if self.proc and self.proc.stdin:
|
||||
try:
|
||||
self.proc.stdin.write(_b64decode(str(payload)))
|
||||
self.proc.stdin.flush()
|
||||
except Exception:
|
||||
pass
|
||||
if msg.get("type") == "close":
|
||||
self._stop.set()
|
||||
break
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self._stop.set()
|
||||
try:
|
||||
self.conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
if self.proc:
|
||||
try:
|
||||
self.proc.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class ShellServer:
|
||||
def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None) -> None:
|
||||
self.host = host
|
||||
self.port = port or _resolve_shell_port()
|
||||
self._thread = threading.Thread(target=self._serve, daemon=True)
|
||||
self._thread.start()
|
||||
_write_log(f"VPN shell server listening on {self.host}:{self.port}")
|
||||
|
||||
def _serve(self) -> None:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
server.bind((self.host, self.port))
|
||||
server.listen(5)
|
||||
while True:
|
||||
conn, addr = server.accept()
|
||||
remote_ip = addr[0]
|
||||
if not remote_ip.startswith("10.255."):
|
||||
_write_log(f"Rejected shell connection from {remote_ip}")
|
||||
conn.close()
|
||||
continue
|
||||
_write_log(f"Accepted shell connection from {remote_ip}")
|
||||
session = ShellSession(conn, addr)
|
||||
threading.Thread(target=session.start, daemon=True).start()
|
||||
|
||||
|
||||
class Role:
|
||||
def __init__(self, ctx) -> None:
|
||||
self.ctx = ctx
|
||||
self.server = ShellServer()
|
||||
|
||||
def register_events(self) -> None:
|
||||
return
|
||||
|
||||
def stop_all(self) -> None:
|
||||
return
|
||||
@@ -9,13 +9,15 @@
|
||||
|
||||
This role prepares the WireGuard client config, manages a single active
|
||||
session, enforces idle teardown, and logs lifecycle events to
|
||||
Agent/Logs/reverse_tunnel.log. It does not yet bind to engine signals; higher
|
||||
layers should call start_session/stop_session with the issued config/token.
|
||||
Agent/Logs/reverse_tunnel.log. It binds to Engine Socket.IO events
|
||||
(`vpn_tunnel_start`, `vpn_tunnel_stop`, `vpn_tunnel_activity`) to start/stop
|
||||
the client session with the issued config/token.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
@@ -26,6 +28,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import x25519
|
||||
from signature_utils import verify_and_store_script_signature
|
||||
|
||||
ROLE_NAME = "WireGuardTunnel"
|
||||
ROLE_CONTEXTS = ["system"]
|
||||
@@ -95,6 +98,8 @@ class SessionConfig:
|
||||
allowed_ports: str
|
||||
idle_seconds: int = 900
|
||||
preshared_key: Optional[str] = None
|
||||
client_private_key: Optional[str] = None
|
||||
client_public_key: Optional[str] = None
|
||||
|
||||
|
||||
class WireGuardClient:
|
||||
@@ -122,17 +127,35 @@ class WireGuardClient:
|
||||
return candidate
|
||||
return "wireguard.exe"
|
||||
|
||||
def _validate_token(self, token: Dict[str, Any]) -> None:
|
||||
required = ("agent_id", "tunnel_id", "expires_at")
|
||||
def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None:
|
||||
payload = dict(token or {})
|
||||
signature = payload.pop("signature", None)
|
||||
signing_key = payload.pop("signing_key", None)
|
||||
sig_alg = payload.pop("sig_alg", None)
|
||||
|
||||
required = ("agent_id", "tunnel_id", "expires_at", "port")
|
||||
missing = [field for field in required if field not in token or token[field] in ("", None)]
|
||||
if missing:
|
||||
raise ValueError(f"Missing token fields: {', '.join(missing)}")
|
||||
try:
|
||||
exp = float(token["expires_at"])
|
||||
exp = float(payload["expires_at"])
|
||||
except Exception:
|
||||
raise ValueError("Invalid token expiry")
|
||||
if exp <= time.time():
|
||||
raise ValueError("Token expired")
|
||||
try:
|
||||
port = int(payload["port"])
|
||||
except Exception:
|
||||
raise ValueError("Invalid token port")
|
||||
if port < 1 or port > 65535:
|
||||
raise ValueError("Invalid token port")
|
||||
|
||||
if signature:
|
||||
if sig_alg and str(sig_alg).lower() not in ("ed25519", "eddsa"):
|
||||
raise ValueError("Unsupported token signature algorithm")
|
||||
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
|
||||
if not verify_and_store_script_signature(signing_client, payload_bytes, str(signature), signing_key):
|
||||
raise ValueError("Token signature invalid")
|
||||
|
||||
def _run(self, args: list[str]) -> tuple[int, str, str]:
|
||||
try:
|
||||
@@ -142,9 +165,10 @@ class WireGuardClient:
|
||||
return 1, "", str(exc)
|
||||
|
||||
def _render_config(self, session: SessionConfig) -> str:
|
||||
private_key = session.client_private_key or self._client_keys["private"]
|
||||
lines = [
|
||||
"[Interface]",
|
||||
f"PrivateKey = {self._client_keys['private']}",
|
||||
f"PrivateKey = {private_key}",
|
||||
f"Address = {session.virtual_ip}",
|
||||
"",
|
||||
"[Peer]",
|
||||
@@ -172,13 +196,13 @@ class WireGuardClient:
|
||||
t.start()
|
||||
self._idle_thread = t
|
||||
|
||||
def start_session(self, session: SessionConfig) -> None:
|
||||
def start_session(self, session: SessionConfig, *, signing_client: Optional[Any] = None) -> None:
|
||||
if self.session:
|
||||
_write_log("Rejecting start_session: existing session already active.")
|
||||
return
|
||||
|
||||
try:
|
||||
self._validate_token(session.token)
|
||||
self._validate_token(session.token, signing_client=signing_client)
|
||||
except Exception as exc:
|
||||
_write_log(f"Refusing to start WireGuard session: {exc}")
|
||||
return
|
||||
@@ -243,6 +267,7 @@ class Role:
|
||||
self.client = client
|
||||
hooks = getattr(ctx, "hooks", {}) or {}
|
||||
self._log_hook = hooks.get("log_agent")
|
||||
self._http_client_factory = hooks.get("http_client")
|
||||
|
||||
def _log(self, message: str, *, error: bool = False) -> None:
|
||||
if callable(self._log_hook):
|
||||
@@ -254,6 +279,14 @@ class Role:
|
||||
pass
|
||||
_write_log(message)
|
||||
|
||||
def _http_client(self) -> Optional[Any]:
|
||||
try:
|
||||
if callable(self._http_client_factory):
|
||||
return self._http_client_factory()
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _build_session(self, payload: Any) -> Optional[SessionConfig]:
|
||||
if not isinstance(payload, dict):
|
||||
self._log("WireGuard start payload missing/invalid.", error=True)
|
||||
@@ -299,6 +332,8 @@ class Role:
|
||||
allowed_ports=allowed_ports,
|
||||
idle_seconds=idle_seconds,
|
||||
preshared_key=payload.get("preshared_key"),
|
||||
client_private_key=payload.get("client_private_key"),
|
||||
client_public_key=payload.get("client_public_key"),
|
||||
)
|
||||
|
||||
def register_events(self) -> None:
|
||||
@@ -310,7 +345,7 @@ class Role:
|
||||
if not session:
|
||||
return
|
||||
self._log("WireGuard start request received.")
|
||||
self.client.start_session(session)
|
||||
self.client.start_session(session, signing_client=self._http_client())
|
||||
|
||||
@sio.on("vpn_tunnel_stop")
|
||||
async def _vpn_tunnel_stop(payload):
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
# ======================================================
|
||||
# Data\Engine\Unit_Tests\test_reverse_tunnel.py
|
||||
# Description: Validates reverse tunnel lease API basics (allocation, token contents, and domain limit).
|
||||
#
|
||||
# API Endpoints (if applicable):
|
||||
# - POST /api/tunnel/request
|
||||
# ======================================================
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from .conftest import EngineTestHarness
|
||||
|
||||
|
||||
def _client_with_admin_session(harness: EngineTestHarness):
|
||||
client = harness.app.test_client()
|
||||
with client.session_transaction() as sess:
|
||||
sess["username"] = "admin"
|
||||
sess["role"] = "Admin"
|
||||
return client
|
||||
|
||||
|
||||
def _decode_token_segment(token: str) -> dict:
|
||||
"""Decode the unsigned payload segment from the tunnel token."""
|
||||
if not token:
|
||||
return {}
|
||||
segment = token.split(".")[0]
|
||||
padding = "=" * (-len(segment) % 4)
|
||||
raw = base64.urlsafe_b64decode(segment + padding)
|
||||
try:
|
||||
return json.loads(raw.decode("utf-8"))
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("agent_id", ["test-device-agent"])
|
||||
def test_tunnel_request_happy_path(engine_harness: EngineTestHarness, agent_id: str) -> None:
|
||||
client = _client_with_admin_session(engine_harness)
|
||||
resp = client.post(
|
||||
"/api/tunnel/request",
|
||||
json={"agent_id": agent_id, "protocol": "ps", "domain": "ps"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
payload = resp.get_json()
|
||||
assert payload["agent_id"] == agent_id
|
||||
assert payload["protocol"] == "ps"
|
||||
assert payload["domain"] == "ps"
|
||||
assert isinstance(payload["port"], int) and payload["port"] >= 30000
|
||||
assert payload.get("token")
|
||||
|
||||
claims = _decode_token_segment(payload["token"])
|
||||
assert claims.get("agent_id") == agent_id
|
||||
assert claims.get("protocol") == "ps"
|
||||
assert claims.get("domain") == "ps"
|
||||
assert claims.get("tunnel_id") == payload["tunnel_id"]
|
||||
assert claims.get("assigned_port") == payload["port"]
|
||||
|
||||
|
||||
def test_tunnel_request_domain_limit(engine_harness: EngineTestHarness) -> None:
|
||||
client = _client_with_admin_session(engine_harness)
|
||||
first = client.post(
|
||||
"/api/tunnel/request",
|
||||
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
|
||||
)
|
||||
assert first.status_code == 200
|
||||
|
||||
second = client.post(
|
||||
"/api/tunnel/request",
|
||||
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
|
||||
)
|
||||
assert second.status_code == 409
|
||||
data = second.get_json()
|
||||
assert data.get("error") == "domain_limit"
|
||||
|
||||
|
||||
def test_tunnel_request_includes_timeouts(engine_harness: EngineTestHarness) -> None:
|
||||
client = _client_with_admin_session(engine_harness)
|
||||
resp = client.post(
|
||||
"/api/tunnel/request",
|
||||
json={"agent_id": "test-device-agent", "protocol": "ps", "domain": "ps"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
payload = resp.get_json()
|
||||
assert payload.get("idle_seconds") and payload["idle_seconds"] > 0
|
||||
assert payload.get("grace_seconds") and payload["grace_seconds"] > 0
|
||||
assert payload.get("expires_at") and int(payload["expires_at"]) > 0
|
||||
@@ -1,101 +0,0 @@
|
||||
# ======================================================
|
||||
# Data\Engine\Unit_Tests\test_reverse_tunnel_integration.py
|
||||
# Description: Integration test that exercises a full reverse tunnel PowerShell round-trip
|
||||
# against a running Engine + Agent (requires live services).
|
||||
#
|
||||
# Requirements:
|
||||
# - Environment variables must be set to point at a live Engine + Agent:
|
||||
# TUNNEL_TEST_HOST (e.g., https://localhost:5000)
|
||||
# TUNNEL_TEST_AGENT_ID (agent_id/agent_guid for the target device)
|
||||
# TUNNEL_TEST_BEARER (Authorization bearer token for an admin/operator)
|
||||
# - A live Agent must be reachable and allowed to establish the reverse tunnel.
|
||||
# - TLS verification can be controlled via TUNNEL_TEST_VERIFY ("false" to disable).
|
||||
#
|
||||
# API Endpoints (if applicable):
|
||||
# - POST /api/tunnel/request
|
||||
# - Socket.IO namespace /tunnel (join, ps_open, ps_send, ps_poll)
|
||||
# ======================================================
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import socketio
|
||||
|
||||
|
||||
HOST = os.environ.get("TUNNEL_TEST_HOST", "").strip()
|
||||
AGENT_ID = os.environ.get("TUNNEL_TEST_AGENT_ID", "").strip()
|
||||
BEARER = os.environ.get("TUNNEL_TEST_BEARER", "").strip()
|
||||
VERIFY_ENV = os.environ.get("TUNNEL_TEST_VERIFY", "").strip().lower()
|
||||
VERIFY = False if VERIFY_ENV in {"false", "0", "no"} else True
|
||||
|
||||
SKIP_MSG = (
|
||||
"Live tunnel test skipped (set TUNNEL_TEST_HOST, TUNNEL_TEST_AGENT_ID, TUNNEL_TEST_BEARER to run)"
|
||||
)
|
||||
|
||||
|
||||
def _require_env() -> None:
|
||||
if not HOST or not AGENT_ID or not BEARER:
|
||||
pytest.skip(SKIP_MSG)
|
||||
|
||||
|
||||
def _make_session() -> requests.Session:
|
||||
sess = requests.Session()
|
||||
sess.verify = VERIFY
|
||||
sess.headers.update({"Authorization": f"Bearer {BEARER}"})
|
||||
return sess
|
||||
|
||||
|
||||
def test_reverse_tunnel_powershell_roundtrip() -> None:
|
||||
_require_env()
|
||||
sess = _make_session()
|
||||
|
||||
# 1) Request a tunnel lease
|
||||
resp = sess.post(
|
||||
f"{HOST}/api/tunnel/request",
|
||||
json={"agent_id": AGENT_ID, "protocol": "ps", "domain": "remote-interactive-shell"},
|
||||
)
|
||||
assert resp.status_code == 200, f"lease request failed: {resp.status_code} {resp.text}"
|
||||
lease = resp.json()
|
||||
tunnel_id = lease["tunnel_id"]
|
||||
|
||||
# 2) Connect to Socket.IO /tunnel namespace
|
||||
sio = socketio.Client()
|
||||
sio.connect(
|
||||
HOST,
|
||||
namespaces=["/tunnel"],
|
||||
headers={"Authorization": f"Bearer {BEARER}"},
|
||||
transports=["websocket"],
|
||||
wait_timeout=10,
|
||||
)
|
||||
|
||||
# 3) Join tunnel and open PS channel
|
||||
join_resp = sio.call("join", {"tunnel_id": tunnel_id}, namespace="/tunnel", timeout=10)
|
||||
assert join_resp.get("status") == "ok", f"join failed: {join_resp}"
|
||||
|
||||
open_resp = sio.call("ps_open", {"cols": 120, "rows": 32}, namespace="/tunnel", timeout=10)
|
||||
assert not open_resp.get("error"), f"ps_open failed: {open_resp}"
|
||||
|
||||
# 4) Send a command
|
||||
send_resp = sio.call("ps_send", {"data": 'Write-Host "Hello World"\r\n'}, namespace="/tunnel", timeout=10)
|
||||
assert not send_resp.get("error"), f"ps_send failed: {send_resp}"
|
||||
|
||||
# 5) Poll for output
|
||||
output_text = ""
|
||||
deadline = time.time() + 15
|
||||
while time.time() < deadline:
|
||||
poll_resp = sio.call("ps_poll", {}, namespace="/tunnel", timeout=10)
|
||||
if poll_resp.get("error"):
|
||||
pytest.fail(f"ps_poll failed: {poll_resp}")
|
||||
lines = poll_resp.get("output") or []
|
||||
output_text += "".join(lines)
|
||||
if "Hello World" in output_text:
|
||||
break
|
||||
time.sleep(0.5)
|
||||
|
||||
sio.disconnect()
|
||||
|
||||
assert "Hello World" in output_text, f"expected command output not found; got: {output_text!r}"
|
||||
@@ -77,17 +77,12 @@ LOG_ROOT = PROJECT_ROOT / "Engine" / "Logs"
|
||||
LOG_FILE_PATH = LOG_ROOT / "engine.log"
|
||||
ERROR_LOG_FILE_PATH = LOG_ROOT / "error.log"
|
||||
API_LOG_FILE_PATH = LOG_ROOT / "api.log"
|
||||
REVERSE_TUNNEL_LOG_FILE_PATH = LOG_ROOT / "reverse_tunnel.log"
|
||||
|
||||
DEFAULT_TUNNEL_FIXED_PORT = 8443
|
||||
DEFAULT_TUNNEL_PORT_RANGE = (30000, 40000)
|
||||
DEFAULT_TUNNEL_IDLE_TIMEOUT_SECONDS = 3600
|
||||
DEFAULT_TUNNEL_GRACE_TIMEOUT_SECONDS = 3600
|
||||
DEFAULT_TUNNEL_HEARTBEAT_INTERVAL_SECONDS = 20
|
||||
VPN_TUNNEL_LOG_FILE_PATH = LOG_ROOT / "reverse_tunnel.log"
|
||||
DEFAULT_WIREGUARD_PORT = 30000
|
||||
DEFAULT_WIREGUARD_ENGINE_VIRTUAL_IP = "10.255.0.1/32"
|
||||
DEFAULT_WIREGUARD_PEER_NETWORK = "10.255.0.0/24"
|
||||
DEFAULT_WIREGUARD_ACL_WINDOWS = (3389, 5985, 5986, 5900, 3478)
|
||||
DEFAULT_WIREGUARD_SHELL_PORT = 47001
|
||||
DEFAULT_WIREGUARD_ACL_WINDOWS = (3389, 5985, 5986, 5900, 3478, DEFAULT_WIREGUARD_SHELL_PORT)
|
||||
VPN_SERVER_CERT_ROOT = PROJECT_ROOT / "Engine" / "Certificates" / "VPN_Server"
|
||||
|
||||
|
||||
@@ -282,18 +277,14 @@ class EngineSettings:
|
||||
error_log_file: str
|
||||
api_log_file: str
|
||||
api_groups: Tuple[str, ...]
|
||||
reverse_tunnel_fixed_port: int
|
||||
reverse_tunnel_port_range: Tuple[int, int]
|
||||
reverse_tunnel_idle_timeout_seconds: int
|
||||
reverse_tunnel_grace_timeout_seconds: int
|
||||
reverse_tunnel_heartbeat_seconds: int
|
||||
reverse_tunnel_log_file: str
|
||||
vpn_tunnel_log_file: str
|
||||
wireguard_port: int
|
||||
wireguard_engine_virtual_ip: str
|
||||
wireguard_peer_network: str
|
||||
wireguard_server_private_key_path: str
|
||||
wireguard_server_public_key_path: str
|
||||
wireguard_acl_allowlist_windows: Tuple[int, ...]
|
||||
wireguard_shell_port: int
|
||||
raw: MutableMapping[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_flask_config(self) -> MutableMapping[str, Any]:
|
||||
@@ -390,10 +381,14 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
|
||||
api_log_file = str(runtime_config.get("API_LOG_FILE") or API_LOG_FILE_PATH)
|
||||
_ensure_parent(Path(api_log_file))
|
||||
|
||||
reverse_tunnel_log_file = str(
|
||||
runtime_config.get("REVERSE_TUNNEL_LOG_FILE") or REVERSE_TUNNEL_LOG_FILE_PATH
|
||||
vpn_tunnel_log_file = str(
|
||||
runtime_config.get("VPN_TUNNEL_LOG_FILE")
|
||||
or runtime_config.get("WIREGUARD_LOG_FILE")
|
||||
or os.environ.get("BOREALIS_VPN_TUNNEL_LOG_FILE")
|
||||
or os.environ.get("BOREALIS_WIREGUARD_LOG_FILE")
|
||||
or VPN_TUNNEL_LOG_FILE_PATH
|
||||
)
|
||||
_ensure_parent(Path(reverse_tunnel_log_file))
|
||||
_ensure_parent(Path(vpn_tunnel_log_file))
|
||||
|
||||
wireguard_port = _parse_int(
|
||||
runtime_config.get("WIREGUARD_PORT") or os.environ.get("BOREALIS_WIREGUARD_PORT"),
|
||||
@@ -416,6 +411,13 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
|
||||
or os.environ.get("BOREALIS_WIREGUARD_WINDOWS_ALLOWLIST"),
|
||||
default=DEFAULT_WIREGUARD_ACL_WINDOWS,
|
||||
)
|
||||
wireguard_shell_port = _parse_int(
|
||||
runtime_config.get("WIREGUARD_SHELL_PORT")
|
||||
or os.environ.get("BOREALIS_WIREGUARD_SHELL_PORT"),
|
||||
default=DEFAULT_WIREGUARD_SHELL_PORT,
|
||||
minimum=1,
|
||||
maximum=65535,
|
||||
)
|
||||
wireguard_key_root = Path(
|
||||
runtime_config.get("WIREGUARD_KEY_ROOT")
|
||||
or os.environ.get("BOREALIS_WIREGUARD_KEY_ROOT")
|
||||
@@ -440,35 +442,6 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
|
||||
"scheduled_jobs",
|
||||
)
|
||||
|
||||
tunnel_fixed_port = _parse_int(
|
||||
runtime_config.get("TUNNEL_FIXED_PORT") or os.environ.get("BOREALIS_TUNNEL_FIXED_PORT"),
|
||||
default=DEFAULT_TUNNEL_FIXED_PORT,
|
||||
minimum=1,
|
||||
maximum=65535,
|
||||
)
|
||||
tunnel_port_range = _parse_port_range(
|
||||
runtime_config.get("TUNNEL_PORT_RANGE") or os.environ.get("BOREALIS_TUNNEL_PORT_RANGE"),
|
||||
default=DEFAULT_TUNNEL_PORT_RANGE,
|
||||
)
|
||||
tunnel_idle_timeout_seconds = _parse_int(
|
||||
runtime_config.get("TUNNEL_IDLE_TIMEOUT_SECONDS")
|
||||
or os.environ.get("BOREALIS_TUNNEL_IDLE_TIMEOUT_SECONDS"),
|
||||
default=DEFAULT_TUNNEL_IDLE_TIMEOUT_SECONDS,
|
||||
minimum=60,
|
||||
)
|
||||
tunnel_grace_timeout_seconds = _parse_int(
|
||||
runtime_config.get("TUNNEL_GRACE_TIMEOUT_SECONDS")
|
||||
or os.environ.get("BOREALIS_TUNNEL_GRACE_TIMEOUT_SECONDS"),
|
||||
default=DEFAULT_TUNNEL_GRACE_TIMEOUT_SECONDS,
|
||||
minimum=60,
|
||||
)
|
||||
tunnel_heartbeat_seconds = _parse_int(
|
||||
runtime_config.get("TUNNEL_HEARTBEAT_SECONDS")
|
||||
or os.environ.get("BOREALIS_TUNNEL_HEARTBEAT_SECONDS"),
|
||||
default=DEFAULT_TUNNEL_HEARTBEAT_INTERVAL_SECONDS,
|
||||
minimum=5,
|
||||
)
|
||||
|
||||
settings = EngineSettings(
|
||||
database_path=database_path,
|
||||
static_folder=static_folder,
|
||||
@@ -484,18 +457,14 @@ def load_runtime_config(overrides: Optional[Mapping[str, Any]] = None) -> Engine
|
||||
error_log_file=str(error_log_file),
|
||||
api_log_file=str(api_log_file),
|
||||
api_groups=api_groups,
|
||||
reverse_tunnel_fixed_port=tunnel_fixed_port,
|
||||
reverse_tunnel_port_range=tunnel_port_range,
|
||||
reverse_tunnel_idle_timeout_seconds=tunnel_idle_timeout_seconds,
|
||||
reverse_tunnel_grace_timeout_seconds=tunnel_grace_timeout_seconds,
|
||||
reverse_tunnel_heartbeat_seconds=tunnel_heartbeat_seconds,
|
||||
reverse_tunnel_log_file=reverse_tunnel_log_file,
|
||||
vpn_tunnel_log_file=vpn_tunnel_log_file,
|
||||
wireguard_port=wireguard_port,
|
||||
wireguard_engine_virtual_ip=wireguard_engine_virtual_ip,
|
||||
wireguard_peer_network=wireguard_peer_network,
|
||||
wireguard_server_private_key_path=wireguard_server_private_key_path,
|
||||
wireguard_server_public_key_path=wireguard_server_public_key_path,
|
||||
wireguard_acl_allowlist_windows=wireguard_acl_allowlist_windows,
|
||||
wireguard_shell_port=wireguard_shell_port,
|
||||
raw=runtime_config,
|
||||
)
|
||||
return settings
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
# Data\Engine\database_migrations.py
|
||||
# Description: Provides schema evolution helpers for the Engine sqlite
|
||||
# database without importing the legacy ``Modules`` package.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""Engine database schema migration helpers."""
|
||||
@@ -24,6 +26,7 @@ def apply_all(conn: sqlite3.Connection) -> None:
|
||||
|
||||
_ensure_devices_table(conn)
|
||||
_ensure_device_aux_tables(conn)
|
||||
_ensure_device_vpn_config_table(conn)
|
||||
_ensure_refresh_token_table(conn)
|
||||
_ensure_install_code_table(conn)
|
||||
_ensure_install_code_persistence_table(conn)
|
||||
@@ -112,6 +115,20 @@ def _ensure_device_aux_tables(conn: sqlite3.Connection) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _ensure_device_vpn_config_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS device_vpn_config (
|
||||
agent_id TEXT PRIMARY KEY,
|
||||
allowed_ports TEXT,
|
||||
updated_at TEXT,
|
||||
updated_by TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _ensure_refresh_token_table(conn: sqlite3.Connection) -> None:
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
|
||||
@@ -120,18 +120,14 @@ class EngineContext:
|
||||
config: Mapping[str, Any]
|
||||
api_groups: Sequence[str]
|
||||
api_log_path: str
|
||||
reverse_tunnel_fixed_port: int
|
||||
reverse_tunnel_port_range: Tuple[int, int]
|
||||
reverse_tunnel_idle_timeout_seconds: int
|
||||
reverse_tunnel_grace_timeout_seconds: int
|
||||
reverse_tunnel_heartbeat_seconds: int
|
||||
reverse_tunnel_log_path: str
|
||||
vpn_tunnel_log_path: str
|
||||
wireguard_port: int
|
||||
wireguard_engine_virtual_ip: str
|
||||
wireguard_peer_network: str
|
||||
wireguard_server_private_key_path: str
|
||||
wireguard_server_public_key_path: str
|
||||
wireguard_acl_allowlist_windows: Tuple[int, ...]
|
||||
wireguard_shell_port: int
|
||||
wireguard_server_manager: Optional[Any] = None
|
||||
assembly_cache: Optional[Any] = None
|
||||
|
||||
@@ -151,18 +147,14 @@ def _build_engine_context(settings: EngineSettings, logger: logging.Logger) -> E
|
||||
config=settings.as_dict(),
|
||||
api_groups=settings.api_groups,
|
||||
api_log_path=settings.api_log_file,
|
||||
reverse_tunnel_fixed_port=settings.reverse_tunnel_fixed_port,
|
||||
reverse_tunnel_port_range=settings.reverse_tunnel_port_range,
|
||||
reverse_tunnel_idle_timeout_seconds=settings.reverse_tunnel_idle_timeout_seconds,
|
||||
reverse_tunnel_grace_timeout_seconds=settings.reverse_tunnel_grace_timeout_seconds,
|
||||
reverse_tunnel_heartbeat_seconds=settings.reverse_tunnel_heartbeat_seconds,
|
||||
reverse_tunnel_log_path=settings.reverse_tunnel_log_file,
|
||||
vpn_tunnel_log_path=settings.vpn_tunnel_log_file,
|
||||
wireguard_port=settings.wireguard_port,
|
||||
wireguard_engine_virtual_ip=settings.wireguard_engine_virtual_ip,
|
||||
wireguard_peer_network=settings.wireguard_peer_network,
|
||||
wireguard_server_private_key_path=settings.wireguard_server_private_key_path,
|
||||
wireguard_server_public_key_path=settings.wireguard_server_public_key_path,
|
||||
wireguard_acl_allowlist_windows=settings.wireguard_acl_allowlist_windows,
|
||||
wireguard_shell_port=settings.wireguard_shell_port,
|
||||
assembly_cache=None,
|
||||
)
|
||||
|
||||
@@ -249,7 +241,7 @@ def create_app(config: Optional[Mapping[str, Any]] = None) -> Tuple[Flask, Socke
|
||||
private_key_path=Path(context.wireguard_server_private_key_path),
|
||||
public_key_path=Path(context.wireguard_server_public_key_path),
|
||||
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
|
||||
log_path=Path(context.reverse_tunnel_log_path),
|
||||
log_path=Path(context.vpn_tunnel_log_path),
|
||||
)
|
||||
context.wireguard_server_manager = WireGuardServerManager(wg_config)
|
||||
except Exception:
|
||||
@@ -325,7 +317,7 @@ def register_engine_api(app: Flask, *, config: Optional[Mapping[str, Any]] = Non
|
||||
private_key_path=Path(context.wireguard_server_private_key_path),
|
||||
public_key_path=Path(context.wireguard_server_public_key_path),
|
||||
acl_allowlist_windows=tuple(context.wireguard_acl_allowlist_windows),
|
||||
log_path=Path(context.reverse_tunnel_log_path),
|
||||
log_path=Path(context.vpn_tunnel_log_path),
|
||||
)
|
||||
context.wireguard_server_manager = WireGuardServerManager(wg_config)
|
||||
except Exception:
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
# - GET /api/devices/<guid> (Token Authenticated) - Retrieves a single device record by GUID, including summary fields.
|
||||
# - GET /api/device/details/<hostname> (Token Authenticated) - Returns full device details keyed by hostname.
|
||||
# - POST /api/device/description/<hostname> (Token Authenticated) - Updates the human-readable description for a device.
|
||||
# - GET /api/device/vpn_config/<agent_id> (Token Authenticated) - Returns per-device VPN allowed port settings.
|
||||
# - PUT /api/device/vpn_config/<agent_id> (Token Authenticated) - Updates per-device VPN allowed port settings.
|
||||
# - GET /api/device_list_views (Token Authenticated) - Lists saved device table view definitions.
|
||||
# - GET /api/device_list_views/<int:view_id> (Token Authenticated) - Retrieves a specific saved device table view definition.
|
||||
# - POST /api/device_list_views (Token Authenticated) - Creates a custom device list view for the signed-in operator.
|
||||
@@ -426,6 +428,131 @@ class DeviceManagementService:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _parse_ports(self, raw: Any) -> List[int]:
|
||||
ports: List[int] = []
|
||||
if isinstance(raw, str):
|
||||
parts = [part.strip() for part in raw.split(",") if part.strip()]
|
||||
elif isinstance(raw, list):
|
||||
parts = raw
|
||||
else:
|
||||
parts = []
|
||||
for part in parts:
|
||||
try:
|
||||
value = int(part)
|
||||
except Exception:
|
||||
continue
|
||||
if 1 <= value <= 65535:
|
||||
ports.append(value)
|
||||
return list(dict.fromkeys(ports))
|
||||
|
||||
def _default_vpn_ports(self, os_name: Optional[str]) -> List[int]:
|
||||
ports = list(self.adapters.context.wireguard_acl_allowlist_windows or [])
|
||||
os_text = (os_name or "").strip().lower()
|
||||
if os_text and "windows" not in os_text:
|
||||
baseline = {5900, 3478}
|
||||
filtered = [p for p in ports if p in baseline]
|
||||
return filtered or ports
|
||||
return ports
|
||||
|
||||
def get_vpn_config(self, agent_id: str) -> Tuple[Dict[str, Any], int]:
|
||||
agent_id = (agent_id or "").strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}, 400
|
||||
default_ports: List[int] = []
|
||||
shell_port = int(self.adapters.context.wireguard_shell_port)
|
||||
try:
|
||||
conn = self._db_conn()
|
||||
cur = conn.cursor()
|
||||
os_name = ""
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row[0]:
|
||||
os_name = str(row[0])
|
||||
except Exception:
|
||||
os_name = ""
|
||||
default_ports = self._default_vpn_ports(os_name)
|
||||
cur.execute(
|
||||
"SELECT allowed_ports, updated_at, updated_by FROM device_vpn_config WHERE agent_id=?",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"allowed_ports": default_ports,
|
||||
"default_ports": default_ports,
|
||||
"shell_port": shell_port,
|
||||
"source": "default",
|
||||
}, 200
|
||||
raw_ports = row[0] or ""
|
||||
ports = []
|
||||
try:
|
||||
ports = json.loads(raw_ports) if raw_ports else []
|
||||
except Exception:
|
||||
ports = []
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"allowed_ports": ports or default_ports,
|
||||
"default_ports": default_ports,
|
||||
"shell_port": shell_port,
|
||||
"updated_at": row[1],
|
||||
"updated_by": row[2],
|
||||
"source": "custom" if ports else "default",
|
||||
}, 200
|
||||
except Exception as exc:
|
||||
self.logger.debug("Failed to load vpn config", exc_info=True)
|
||||
return {"error": "internal_error"}, 500
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def set_vpn_config(self, agent_id: str, payload: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
|
||||
agent_id = (agent_id or "").strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}, 400
|
||||
ports = self._parse_ports(payload.get("allowed_ports"))
|
||||
if not ports:
|
||||
return {"error": "allowed_ports_required"}, 400
|
||||
user = self._current_user() or {}
|
||||
updated_by = user.get("username") or ""
|
||||
updated_at = datetime.now(timezone.utc).isoformat()
|
||||
try:
|
||||
conn = self._db_conn()
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO device_vpn_config(agent_id, allowed_ports, updated_at, updated_by)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(agent_id) DO UPDATE SET
|
||||
allowed_ports=excluded.allowed_ports,
|
||||
updated_at=excluded.updated_at,
|
||||
updated_by=excluded.updated_by
|
||||
""",
|
||||
(agent_id, json.dumps(ports), updated_at, updated_by),
|
||||
)
|
||||
conn.commit()
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"allowed_ports": ports,
|
||||
"updated_at": updated_at,
|
||||
"updated_by": updated_by,
|
||||
"source": "custom",
|
||||
}, 200
|
||||
except Exception:
|
||||
self.logger.debug("Failed to save vpn config", exc_info=True)
|
||||
return {"error": "internal_error"}, 500
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _require_login(self) -> Optional[Tuple[Dict[str, Any], int]]:
|
||||
if not self._current_user():
|
||||
return {"error": "unauthorized"}, 401
|
||||
@@ -1793,6 +1920,19 @@ def register_management(app, adapters: "EngineServiceAdapters") -> None:
|
||||
payload, status = service.set_device_description(hostname, description)
|
||||
return jsonify(payload), status
|
||||
|
||||
@blueprint.route("/api/device/vpn_config/<agent_id>", methods=["GET", "PUT"])
|
||||
def _vpn_config(agent_id: str):
|
||||
requirement = service._require_login()
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
if request.method == "GET":
|
||||
payload, status = service.get_vpn_config(agent_id)
|
||||
else:
|
||||
body = request.get_json(silent=True) or {}
|
||||
payload, status = service.set_vpn_config(agent_id, body)
|
||||
return jsonify(payload), status
|
||||
|
||||
@blueprint.route("/api/device_list_views", methods=["GET"])
|
||||
def _list_views():
|
||||
requirement = service._require_login()
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\API\devices\tunnel.py
|
||||
# Description: Negotiation endpoint for reverse tunnel leases (operator-initiated; dormant until tunnel listener is wired).
|
||||
# Description: WireGuard VPN tunnel API (connect/status/disconnect).
|
||||
#
|
||||
# API Endpoints (if applicable):
|
||||
# - POST /api/tunnel/request (Token Authenticated) - Allocates a reverse tunnel lease for the requested agent/protocol.
|
||||
# - POST /api/tunnel/connect (Token Authenticated) - Issues VPN session material for an agent.
|
||||
# - GET /api/tunnel/status (Token Authenticated) - Returns VPN status for an agent.
|
||||
# - DELETE /api/tunnel/disconnect (Token Authenticated) - Tears down VPN session for an agent.
|
||||
# ======================================================
|
||||
|
||||
"""Reverse tunnel negotiation API (Engine side)."""
|
||||
"""WireGuard VPN tunnel API (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
@@ -15,15 +17,13 @@ from typing import Any, Dict, Optional, Tuple
|
||||
from flask import Blueprint, jsonify, request, session
|
||||
from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer
|
||||
|
||||
from ...WebSocket.Agent.reverse_tunnel_orchestrator import ReverseTunnelService
|
||||
from ...VPN import VpnTunnelService
|
||||
|
||||
if False: # pragma: no cover - import cycle hint for type checkers
|
||||
from .. import EngineServiceAdapters
|
||||
|
||||
|
||||
def _current_user(app) -> Optional[Dict[str, str]]:
|
||||
"""Resolve operator identity from session or signed token."""
|
||||
|
||||
username = session.get("username")
|
||||
role = session.get("role") or "User"
|
||||
if username:
|
||||
@@ -58,18 +58,22 @@ def _require_login(app) -> Optional[Tuple[Dict[str, Any], int]]:
|
||||
return None
|
||||
|
||||
|
||||
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> ReverseTunnelService:
|
||||
service = getattr(adapters.context, "reverse_tunnel_service", None) or getattr(adapters, "_reverse_tunnel_service", None)
|
||||
def _get_tunnel_service(adapters: "EngineServiceAdapters") -> VpnTunnelService:
|
||||
service = getattr(adapters.context, "vpn_tunnel_service", None) or getattr(adapters, "_vpn_tunnel_service", None)
|
||||
if service is None:
|
||||
service = ReverseTunnelService(
|
||||
adapters.context,
|
||||
signer=getattr(adapters, "script_signer", None),
|
||||
manager = getattr(adapters.context, "wireguard_server_manager", None)
|
||||
if manager is None:
|
||||
raise RuntimeError("wireguard_manager_unavailable")
|
||||
service = VpnTunnelService(
|
||||
context=adapters.context,
|
||||
wireguard_manager=manager,
|
||||
db_conn_factory=adapters.db_conn_factory,
|
||||
socketio=getattr(adapters.context, "socketio", None),
|
||||
service_log=adapters.service_log,
|
||||
signer=getattr(adapters, "script_signer", None),
|
||||
)
|
||||
service.start()
|
||||
setattr(adapters, "_reverse_tunnel_service", service)
|
||||
setattr(adapters.context, "reverse_tunnel_service", service)
|
||||
setattr(adapters, "_vpn_tunnel_service", service)
|
||||
setattr(adapters.context, "vpn_tunnel_service", service)
|
||||
return service
|
||||
|
||||
|
||||
@@ -83,14 +87,11 @@ def _normalize_text(value: Any) -> str:
|
||||
|
||||
|
||||
def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
|
||||
"""Register reverse tunnel negotiation endpoints."""
|
||||
blueprint = Blueprint("vpn_tunnel", __name__)
|
||||
logger = adapters.context.logger.getChild("vpn_tunnel.api")
|
||||
|
||||
blueprint = Blueprint("reverse_tunnel", __name__)
|
||||
service_log = adapters.service_log
|
||||
logger = adapters.context.logger.getChild("tunnel.api")
|
||||
|
||||
@blueprint.route("/api/tunnel/request", methods=["POST"])
|
||||
def request_tunnel():
|
||||
@blueprint.route("/api/tunnel/connect", methods=["POST"])
|
||||
def connect_tunnel():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
@@ -101,69 +102,67 @@ def register_tunnel(app, adapters: "EngineServiceAdapters") -> None:
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
protocol = _normalize_text(body.get("protocol") or "ps").lower() or "ps"
|
||||
domain = _normalize_text(body.get("domain") or protocol).lower() or protocol
|
||||
if protocol == "ps" and domain == "ps":
|
||||
domain = "remote-interactive-shell"
|
||||
|
||||
if not agent_id:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
try:
|
||||
lease = tunnel_service.request_lease(
|
||||
agent_id=agent_id,
|
||||
protocol=protocol,
|
||||
domain=domain,
|
||||
operator_id=operator_id,
|
||||
)
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
payload = tunnel_service.connect(agent_id=agent_id, operator_id=operator_id)
|
||||
except RuntimeError as exc:
|
||||
message = str(exc)
|
||||
if message.startswith("domain_limit:"):
|
||||
domain_name = message.split(":", 1)[-1] if ":" in message else domain
|
||||
return jsonify({"error": "domain_limit", "domain": domain_name}), 409
|
||||
if message == "port_pool_exhausted":
|
||||
return jsonify({"error": "port_pool_exhausted"}), 503
|
||||
logger.warning("tunnel lease request failed for agent_id=%s: %s", agent_id, message)
|
||||
return jsonify({"error": "lease_allocation_failed"}), 500
|
||||
logger.warning("vpn connect failed for agent_id=%s: %s", agent_id, exc)
|
||||
return jsonify({"error": "connect_failed"}), 500
|
||||
|
||||
summary = tunnel_service.lease_summary(lease)
|
||||
summary["fixed_port"] = tunnel_service.fixed_port
|
||||
summary["heartbeat_seconds"] = tunnel_service.heartbeat_seconds
|
||||
return jsonify(payload), 200
|
||||
|
||||
service_log(
|
||||
"reverse_tunnel",
|
||||
f"lease created tunnel_id={lease.tunnel_id} agent_id={lease.agent_id} domain={lease.domain} protocol={lease.protocol}",
|
||||
)
|
||||
return jsonify(summary), 200
|
||||
|
||||
@blueprint.route("/api/tunnel/<tunnel_id>", methods=["DELETE"])
|
||||
def stop_tunnel(tunnel_id: str):
|
||||
@blueprint.route("/api/tunnel/status", methods=["GET"])
|
||||
def tunnel_status():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
|
||||
tunnel_id_norm = _normalize_text(tunnel_id)
|
||||
if not tunnel_id_norm:
|
||||
return jsonify({"error": "tunnel_id_required"}), 400
|
||||
agent_id = _normalize_text(request.args.get("agent_id") or "")
|
||||
if not agent_id:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
payload = tunnel_service.status(agent_id)
|
||||
if not payload:
|
||||
return jsonify({"status": "down", "agent_id": agent_id}), 200
|
||||
payload["status"] = "up"
|
||||
bump = _normalize_text(request.args.get("bump") or "")
|
||||
if bump:
|
||||
tunnel_service.bump_activity(agent_id)
|
||||
return jsonify(payload), 200
|
||||
|
||||
@blueprint.route("/api/tunnel/connect/status", methods=["GET"])
|
||||
def tunnel_connect_status():
|
||||
return tunnel_status()
|
||||
|
||||
@blueprint.route("/api/tunnel/disconnect", methods=["DELETE"])
|
||||
def disconnect_tunnel():
|
||||
requirement = _require_login(app)
|
||||
if requirement:
|
||||
payload, status = requirement
|
||||
return jsonify(payload), status
|
||||
|
||||
body = request.get_json(silent=True) or {}
|
||||
agent_id = _normalize_text(body.get("agent_id"))
|
||||
tunnel_id = _normalize_text(body.get("tunnel_id"))
|
||||
reason = _normalize_text(body.get("reason") or "operator_stop")
|
||||
|
||||
tunnel_service = _get_tunnel_service(adapters)
|
||||
stopped = False
|
||||
try:
|
||||
stopped = tunnel_service.stop_tunnel(tunnel_id_norm, reason=reason)
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("stop_tunnel failed tunnel_id=%s: %s", tunnel_id_norm, exc, exc_info=True)
|
||||
if tunnel_id:
|
||||
stopped = tunnel_service.disconnect_by_tunnel(tunnel_id, reason=reason)
|
||||
elif agent_id:
|
||||
stopped = tunnel_service.disconnect(agent_id, reason=reason)
|
||||
else:
|
||||
return jsonify({"error": "agent_id_required"}), 400
|
||||
|
||||
if not stopped:
|
||||
return jsonify({"error": "not_found"}), 404
|
||||
|
||||
service_log(
|
||||
"reverse_tunnel",
|
||||
f"lease stopped tunnel_id={tunnel_id_norm} reason={reason or '-'}",
|
||||
)
|
||||
return jsonify({"status": "stopped", "tunnel_id": tunnel_id_norm}), 200
|
||||
return jsonify({"status": "stopped", "reason": reason}), 200
|
||||
|
||||
app.register_blueprint(blueprint)
|
||||
|
||||
@@ -8,4 +8,4 @@
|
||||
"""VPN service helpers for the Engine runtime."""
|
||||
|
||||
from .wireguard_server import WireGuardServerConfig, WireGuardServerManager # noqa: F401
|
||||
|
||||
from .vpn_tunnel_service import VpnTunnelService # noqa: F401
|
||||
|
||||
473
Data/Engine/services/VPN/vpn_tunnel_service.py
Normal file
473
Data/Engine/services/VPN/vpn_tunnel_service.py
Normal file
@@ -0,0 +1,473 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\VPN\vpn_tunnel_service.py
|
||||
# Description: WireGuard tunnel orchestration (single tunnel per agent, token issuance, idle handling).
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""WireGuard tunnel orchestration helpers for the Engine runtime."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
from .wireguard_server import WireGuardServerManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class VpnSession:
|
||||
tunnel_id: str
|
||||
agent_id: str
|
||||
virtual_ip: str
|
||||
token: Dict[str, Any]
|
||||
client_public_key: str
|
||||
client_private_key: str
|
||||
allowed_ports: Tuple[int, ...]
|
||||
created_at: float
|
||||
expires_at: float
|
||||
last_activity: float
|
||||
operator_ids: set[str] = field(default_factory=set)
|
||||
firewall_rules: List[str] = field(default_factory=list)
|
||||
activity_id: Optional[int] = None
|
||||
hostname: Optional[str] = None
|
||||
|
||||
|
||||
class VpnTunnelService:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
context: Any,
|
||||
wireguard_manager: WireGuardServerManager,
|
||||
db_conn_factory,
|
||||
socketio,
|
||||
service_log,
|
||||
signer: Optional[Any] = None,
|
||||
idle_seconds: int = 900,
|
||||
) -> None:
|
||||
self.context = context
|
||||
self.wg = wireguard_manager
|
||||
self.db_conn_factory = db_conn_factory
|
||||
self.socketio = socketio
|
||||
self.service_log = service_log
|
||||
self.signer = signer
|
||||
self.logger = context.logger.getChild("vpn_tunnel")
|
||||
self.activity_logger = self.wg.logger.getChild("device_activity")
|
||||
self.idle_seconds = max(60, int(idle_seconds))
|
||||
self._lock = threading.Lock()
|
||||
self._sessions_by_agent: Dict[str, VpnSession] = {}
|
||||
self._sessions_by_tunnel: Dict[str, VpnSession] = {}
|
||||
self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
|
||||
self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
|
||||
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
|
||||
self._idle_thread.start()
|
||||
|
||||
def _idle_loop(self) -> None:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
now = time.time()
|
||||
expired: List[VpnSession] = []
|
||||
with self._lock:
|
||||
for session in list(self._sessions_by_agent.values()):
|
||||
if session.last_activity + self.idle_seconds <= now:
|
||||
expired.append(session)
|
||||
for session in expired:
|
||||
self.disconnect(session.agent_id, reason="idle_timeout")
|
||||
|
||||
def _allocate_virtual_ip(self, agent_id: str) -> str:
|
||||
existing = self._sessions_by_agent.get(agent_id)
|
||||
if existing:
|
||||
return existing.virtual_ip
|
||||
|
||||
used = {s.virtual_ip for s in self._sessions_by_agent.values()}
|
||||
for host in self._peer_network.hosts():
|
||||
if host == self._engine_ip.ip:
|
||||
continue
|
||||
candidate = f"{host}/32"
|
||||
if candidate not in used:
|
||||
return candidate
|
||||
raise RuntimeError("vpn_ip_pool_exhausted")
|
||||
|
||||
def _load_allowed_ports(self, agent_id: str) -> Tuple[int, ...]:
|
||||
default = tuple(self.context.wireguard_acl_allowlist_windows or ())
|
||||
try:
|
||||
conn = self.db_conn_factory()
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
os_name = str(row[0]).lower() if row and row[0] else ""
|
||||
except Exception:
|
||||
os_name = ""
|
||||
if os_name and "windows" not in os_name:
|
||||
baseline = {5900, 3478}
|
||||
filtered = [p for p in default if p in baseline]
|
||||
if filtered:
|
||||
default = tuple(filtered)
|
||||
cur.execute(
|
||||
"SELECT allowed_ports FROM device_vpn_config WHERE agent_id=?",
|
||||
(agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
return default
|
||||
raw = row[0] or ""
|
||||
ports = json.loads(raw) if raw else []
|
||||
ports = [int(p) for p in ports if isinstance(p, (int, float, str))]
|
||||
ports = [p for p in ports if 1 <= p <= 65535]
|
||||
return tuple(dict.fromkeys(ports)) or default
|
||||
except Exception:
|
||||
return default
|
||||
finally:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _generate_client_keys(self) -> Tuple[str, str]:
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import x25519
|
||||
|
||||
key = x25519.X25519PrivateKey.generate()
|
||||
priv = base64.b64encode(
|
||||
key.private_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PrivateFormat.Raw,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
).decode("ascii").strip()
|
||||
pub = base64.b64encode(
|
||||
key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.Raw,
|
||||
format=serialization.PublicFormat.Raw,
|
||||
)
|
||||
).decode("ascii").strip()
|
||||
return priv, pub
|
||||
|
||||
def _issue_token(self, agent_id: str, tunnel_id: str, expires_at: float) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"agent_id": agent_id,
|
||||
"tunnel_id": tunnel_id,
|
||||
"port": self.context.wireguard_port,
|
||||
"expires_at": expires_at,
|
||||
"issued_at": time.time(),
|
||||
}
|
||||
if not self.signer:
|
||||
return dict(payload)
|
||||
|
||||
token = dict(payload)
|
||||
try:
|
||||
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
|
||||
signature = self.signer.sign(payload_bytes)
|
||||
token["signature"] = base64.b64encode(signature).decode("ascii")
|
||||
if hasattr(self.signer, "public_base64_spki"):
|
||||
token["signing_key"] = self.signer.public_base64_spki()
|
||||
token["sig_alg"] = "ed25519"
|
||||
except Exception:
|
||||
self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
|
||||
return token
|
||||
|
||||
def _refresh_listener(self) -> None:
|
||||
peers: List[Mapping[str, object]] = []
|
||||
for session in self._sessions_by_agent.values():
|
||||
peer = self.wg.build_peer_profile(
|
||||
session.agent_id,
|
||||
session.virtual_ip,
|
||||
allowed_ports=session.allowed_ports,
|
||||
)
|
||||
peer = dict(peer)
|
||||
peer["public_key"] = session.client_public_key
|
||||
peers.append(peer)
|
||||
if not peers:
|
||||
self.wg.stop_listener()
|
||||
return
|
||||
self.wg.start_listener(peers)
|
||||
|
||||
def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]:
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
existing = self._sessions_by_agent.get(agent_id)
|
||||
if existing:
|
||||
if operator_id:
|
||||
existing.operator_ids.add(operator_id)
|
||||
existing.last_activity = now
|
||||
return self._session_payload(existing)
|
||||
|
||||
tunnel_id = uuid.uuid4().hex
|
||||
virtual_ip = self._allocate_virtual_ip(agent_id)
|
||||
allowed_ports = self._load_allowed_ports(agent_id)
|
||||
client_private, client_public = self._generate_client_keys()
|
||||
token = self._issue_token(agent_id, tunnel_id, now + 300)
|
||||
self.wg.require_orchestration_token(token)
|
||||
|
||||
session = VpnSession(
|
||||
tunnel_id=tunnel_id,
|
||||
agent_id=agent_id,
|
||||
virtual_ip=virtual_ip,
|
||||
token=token,
|
||||
client_public_key=client_public,
|
||||
client_private_key=client_private,
|
||||
allowed_ports=allowed_ports,
|
||||
created_at=now,
|
||||
expires_at=now + 300,
|
||||
last_activity=now,
|
||||
)
|
||||
if operator_id:
|
||||
session.operator_ids.add(operator_id)
|
||||
self._sessions_by_agent[agent_id] = session
|
||||
self._sessions_by_tunnel[tunnel_id] = session
|
||||
|
||||
try:
|
||||
self._refresh_listener()
|
||||
|
||||
peer = self.wg.build_peer_profile(
|
||||
agent_id,
|
||||
virtual_ip,
|
||||
allowed_ports=allowed_ports,
|
||||
)
|
||||
rule_names = self.wg.apply_firewall_rules(peer)
|
||||
session.firewall_rules = rule_names
|
||||
except Exception:
|
||||
with self._lock:
|
||||
self._sessions_by_agent.pop(agent_id, None)
|
||||
self._sessions_by_tunnel.pop(tunnel_id, None)
|
||||
try:
|
||||
self._refresh_listener()
|
||||
except Exception:
|
||||
self.logger.debug("Failed to refresh WireGuard listener after connect rollback.", exc_info=True)
|
||||
raise
|
||||
|
||||
payload = self._session_payload(session)
|
||||
self._emit_start(payload)
|
||||
self._log_device_activity(session, event="start")
|
||||
return payload
|
||||
|
||||
def status(self, agent_id: str) -> Optional[Mapping[str, Any]]:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.get(agent_id)
|
||||
if not session:
|
||||
return None
|
||||
return self._session_payload(session, include_token=False)
|
||||
|
||||
def bump_activity(self, agent_id: str) -> None:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.get(agent_id)
|
||||
if not session:
|
||||
return
|
||||
session.last_activity = time.time()
|
||||
try:
|
||||
if self.socketio:
|
||||
self.socketio.emit("vpn_tunnel_activity", {"agent_id": agent_id}, namespace="/")
|
||||
except Exception:
|
||||
self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True)
|
||||
|
||||
def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.pop(agent_id, None)
|
||||
if not session:
|
||||
return False
|
||||
self._sessions_by_tunnel.pop(session.tunnel_id, None)
|
||||
|
||||
try:
|
||||
self.wg.remove_firewall_rules(session.firewall_rules)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to remove firewall rules for agent=%s", agent_id, exc_info=True)
|
||||
|
||||
self._refresh_listener()
|
||||
self._emit_stop(session, reason)
|
||||
self._log_device_activity(session, event="stop", reason=reason)
|
||||
return True
|
||||
|
||||
def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool:
|
||||
with self._lock:
|
||||
session = self._sessions_by_tunnel.get(tunnel_id)
|
||||
if not session:
|
||||
return False
|
||||
return self.disconnect(session.agent_id, reason=reason)
|
||||
|
||||
def _emit_start(self, payload: Mapping[str, Any]) -> None:
|
||||
if not self.socketio:
|
||||
return
|
||||
try:
|
||||
self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
|
||||
except Exception:
|
||||
self.logger.debug("vpn_tunnel_start emit failed", exc_info=True)
|
||||
|
||||
def _emit_stop(self, session: VpnSession, reason: str) -> None:
|
||||
if not self.socketio:
|
||||
return
|
||||
try:
|
||||
self.socketio.emit(
|
||||
"vpn_tunnel_stop",
|
||||
{"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
|
||||
namespace="/",
|
||||
)
|
||||
except Exception:
|
||||
self.logger.debug("vpn_tunnel_stop emit failed", exc_info=True)
|
||||
|
||||
def _log_device_activity(self, session: VpnSession, *, event: str, reason: Optional[str] = None) -> None:
|
||||
if self.db_conn_factory is None:
|
||||
self.activity_logger.info(
|
||||
"device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s",
|
||||
event,
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
",".join(sorted(filter(None, session.operator_ids))) or "-",
|
||||
reason or "-",
|
||||
)
|
||||
return
|
||||
|
||||
conn = None
|
||||
try:
|
||||
conn = self.db_conn_factory()
|
||||
cur = conn.cursor()
|
||||
hostname = session.hostname
|
||||
if not hostname:
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT hostname FROM devices WHERE agent_id = ? ORDER BY last_seen DESC LIMIT 1",
|
||||
(session.agent_id,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row and row[0]:
|
||||
hostname = str(row[0]).strip()
|
||||
session.hostname = hostname
|
||||
except Exception:
|
||||
hostname = None
|
||||
|
||||
if not hostname:
|
||||
self.activity_logger.info(
|
||||
"device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s hostname=unknown",
|
||||
event,
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
",".join(sorted(filter(None, session.operator_ids))) or "-",
|
||||
reason or "-",
|
||||
)
|
||||
return
|
||||
|
||||
now_ts = int(time.time())
|
||||
script_name = "Reverse VPN Tunnel (WireGuard)"
|
||||
|
||||
if event == "start":
|
||||
cur.execute(
|
||||
"""
|
||||
INSERT INTO activity_history(hostname, script_path, script_name, script_type, ran_at, status, stdout, stderr)
|
||||
VALUES(?,?,?,?,?,?,?,?)
|
||||
""",
|
||||
(
|
||||
hostname,
|
||||
session.tunnel_id,
|
||||
script_name,
|
||||
"vpn_tunnel",
|
||||
now_ts,
|
||||
"Running",
|
||||
"",
|
||||
"",
|
||||
),
|
||||
)
|
||||
session.activity_id = cur.lastrowid
|
||||
conn.commit()
|
||||
if self.socketio:
|
||||
try:
|
||||
self.socketio.emit(
|
||||
"device_activity_changed",
|
||||
{
|
||||
"hostname": hostname,
|
||||
"activity_id": session.activity_id,
|
||||
"change": "created",
|
||||
"source": "vpn_tunnel",
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self.activity_logger.info(
|
||||
"device_activity_start hostname=%s agent_id=%s tunnel_id=%s operator=%s activity_id=%s",
|
||||
hostname,
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
",".join(sorted(filter(None, session.operator_ids))) or "-",
|
||||
session.activity_id or "-",
|
||||
)
|
||||
return
|
||||
|
||||
if session.activity_id:
|
||||
status = "Completed" if event == "stop" else "Closed"
|
||||
cur.execute(
|
||||
"""
|
||||
UPDATE activity_history
|
||||
SET status=?,
|
||||
stderr=COALESCE(stderr, '') || ?
|
||||
WHERE id=?
|
||||
""",
|
||||
(
|
||||
status,
|
||||
f"\nreason: {reason}" if reason else "",
|
||||
session.activity_id,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
if self.socketio:
|
||||
try:
|
||||
self.socketio.emit(
|
||||
"device_activity_changed",
|
||||
{
|
||||
"hostname": hostname,
|
||||
"activity_id": session.activity_id,
|
||||
"change": "updated",
|
||||
"source": "vpn_tunnel",
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.activity_logger.info(
|
||||
"device_activity event=%s hostname=%s agent_id=%s tunnel_id=%s operator=%s reason=%s activity_id=%s",
|
||||
event,
|
||||
hostname,
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
",".join(sorted(filter(None, session.operator_ids))) or "-",
|
||||
reason or "-",
|
||||
session.activity_id or "-",
|
||||
)
|
||||
except Exception:
|
||||
self.activity_logger.debug(
|
||||
"device_activity logging failed for tunnel_id=%s",
|
||||
session.tunnel_id,
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
|
||||
payload: Dict[str, Any] = {
|
||||
"tunnel_id": session.tunnel_id,
|
||||
"agent_id": session.agent_id,
|
||||
"virtual_ip": session.virtual_ip,
|
||||
"engine_virtual_ip": str(self._engine_ip.ip),
|
||||
"allowed_ips": f"{self._engine_ip.ip}/32",
|
||||
"endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}",
|
||||
"server_public_key": self.wg.server_public_key,
|
||||
"client_public_key": session.client_public_key,
|
||||
"client_private_key": session.client_private_key,
|
||||
"idle_seconds": self.idle_seconds,
|
||||
"allowed_ports": list(session.allowed_ports),
|
||||
"connected_operators": len([o for o in session.operator_ids if o]),
|
||||
}
|
||||
if include_token:
|
||||
payload["token"] = session.token
|
||||
return payload
|
||||
@@ -70,7 +70,7 @@ class WireGuardServerManager:
|
||||
self.logger = _build_logger(config.log_path)
|
||||
self._ensure_cert_dir()
|
||||
self.server_private_key, self.server_public_key = self._ensure_server_keys()
|
||||
self._service_name = "BorealisWireGuard"
|
||||
self._service_name = "borealis-wg"
|
||||
self._temp_dir = Path(tempfile.gettempdir()) / "borealis-wg-engine"
|
||||
|
||||
def _ensure_cert_dir(self) -> None:
|
||||
@@ -157,7 +157,7 @@ class WireGuardServerManager:
|
||||
if not token:
|
||||
raise ValueError("Missing orchestration token for WireGuard peer")
|
||||
|
||||
required_fields = ("agent_id", "tunnel_id", "expires_at")
|
||||
required_fields = ("agent_id", "tunnel_id", "expires_at", "port")
|
||||
missing = [field for field in required_fields if field not in token or token[field] in (None, "")]
|
||||
if missing:
|
||||
raise ValueError(f"Invalid orchestration token; missing {', '.join(missing)}")
|
||||
@@ -167,6 +167,13 @@ class WireGuardServerManager:
|
||||
except Exception:
|
||||
raise ValueError("Invalid orchestration token expiry")
|
||||
|
||||
try:
|
||||
port = int(token["port"])
|
||||
except Exception:
|
||||
raise ValueError("Invalid orchestration token port")
|
||||
if port != int(self.config.port):
|
||||
raise ValueError("Orchestration token port mismatch")
|
||||
|
||||
now = time.time()
|
||||
if expires_at <= now:
|
||||
raise ValueError("Orchestration token expired")
|
||||
@@ -253,12 +260,14 @@ class WireGuardServerManager:
|
||||
"host_only": True,
|
||||
}
|
||||
|
||||
def apply_firewall_rules(self, peer: Mapping[str, object]) -> None:
|
||||
def apply_firewall_rules(self, peer: Mapping[str, object]) -> List[str]:
|
||||
"""Apply outbound firewall allow rules for the agent's virtual IP/ports (Windows netsh)."""
|
||||
|
||||
rules = self.build_firewall_rules(peer)
|
||||
rule_names: List[str] = []
|
||||
for idx, rule in enumerate(rules):
|
||||
name = f"Borealis-WG-Agent-{peer.get('agent_id','')}-{idx}"
|
||||
protocol = str(rule.get("protocol") or "TCP").upper()
|
||||
args = [
|
||||
"netsh",
|
||||
"advfirewall",
|
||||
@@ -269,7 +278,7 @@ class WireGuardServerManager:
|
||||
"dir=out",
|
||||
"action=allow",
|
||||
f"remoteip={rule.get('remote_address','')}",
|
||||
f"protocol=TCP",
|
||||
f"protocol={protocol}",
|
||||
f"localport={rule.get('local_port','')}",
|
||||
]
|
||||
code, out, err = self._run_command(args)
|
||||
@@ -277,6 +286,19 @@ class WireGuardServerManager:
|
||||
self.logger.warning("Failed to apply firewall rule %s code=%s err=%s", name, code, err)
|
||||
else:
|
||||
self.logger.info("Applied firewall rule %s", name)
|
||||
rule_names.append(name)
|
||||
return rule_names
|
||||
|
||||
def remove_firewall_rules(self, rule_names: Sequence[str]) -> None:
|
||||
for name in rule_names:
|
||||
if not name:
|
||||
continue
|
||||
args = ["netsh", "advfirewall", "firewall", "delete", "rule", f"name={name}"]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0:
|
||||
self.logger.warning("Failed to remove firewall rule %s code=%s err=%s", name, code, err)
|
||||
else:
|
||||
self.logger.info("Removed firewall rule %s", name)
|
||||
|
||||
def start_listener(self, peers: Sequence[Mapping[str, object]]) -> None:
|
||||
"""Render a temporary WireGuard config and start the service."""
|
||||
@@ -291,6 +313,9 @@ class WireGuardServerManager:
|
||||
config_path.write_text(rendered, encoding="utf-8")
|
||||
self.logger.info("Rendered WireGuard config to %s", config_path)
|
||||
|
||||
# Ensure old service is removed before re-installing.
|
||||
self.stop_listener()
|
||||
|
||||
args = ["wireguard.exe", "/installtunnelservice", str(config_path)]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0:
|
||||
@@ -301,7 +326,7 @@ class WireGuardServerManager:
|
||||
def stop_listener(self) -> None:
|
||||
"""Stop and remove the WireGuard tunnel service."""
|
||||
|
||||
args = ["wireguard.exe", "/uninstalltunnelservice", "borealis-wg"]
|
||||
args = ["wireguard.exe", "/uninstalltunnelservice", self._service_name]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0:
|
||||
self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)
|
||||
@@ -323,15 +348,17 @@ class WireGuardServerManager:
|
||||
port_list = []
|
||||
|
||||
for port in port_list:
|
||||
rules.append(
|
||||
{
|
||||
"direction": "outbound",
|
||||
"remote_address": ip,
|
||||
"local_port": port,
|
||||
"action": "allow",
|
||||
"description": f"WireGuard engine->agent allow port {port}",
|
||||
}
|
||||
)
|
||||
for protocol in ("TCP", "UDP"):
|
||||
rules.append(
|
||||
{
|
||||
"direction": "outbound",
|
||||
"remote_address": ip,
|
||||
"local_port": port,
|
||||
"protocol": protocol,
|
||||
"action": "allow",
|
||||
"description": f"WireGuard engine->agent allow port {port}/{protocol}",
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
"Prepared firewall rule plan for agent=%s rules=%s",
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Namespace package for reverse tunnel domain handlers (Engine side)."""
|
||||
|
||||
__all__ = ["remote_interactive_shell", "remote_management", "remote_video"]
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Placeholder Bash channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class BashChannelServer:
|
||||
"""Stub Bash handler until the agent-side channel is implemented."""
|
||||
|
||||
protocol_name = "bash"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"bash.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
# No-op placeholder; output collection for future Bash support.
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"bash channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
# Placeholder: no agent-side Bash yet.
|
||||
self.logger.info("bash placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
# Placeholder: not implemented.
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["BashChannelServer"]
|
||||
@@ -1,139 +0,0 @@
|
||||
"""Engine-side PowerShell tunnel channel helper (remote interactive shell domain)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional
|
||||
|
||||
# Mirror framing constants to avoid circular imports.
|
||||
MSG_CHANNEL_OPEN = 0x03
|
||||
MSG_CHANNEL_ACK = 0x04
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_OK = 0
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class PowershellChannelServer:
|
||||
"""Coordinate PowerShell channel frames over a TunnelBridge."""
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"ps.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output: Deque[str] = deque()
|
||||
self._close_reason: Optional[str] = None
|
||||
self._close_code: Optional[int] = None
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
# ------------------------------------------------------------------ Agent frame handling
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
if frame.channel_id != self.channel_id:
|
||||
return
|
||||
if frame.msg_type == MSG_CHANNEL_ACK:
|
||||
self._ack_received = True
|
||||
self.logger.info("ps channel acked tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
return
|
||||
if frame.msg_type == MSG_DATA:
|
||||
try:
|
||||
text = frame.payload.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
text = ""
|
||||
if text:
|
||||
self._append_output(text)
|
||||
return
|
||||
if frame.msg_type == MSG_CLOSE:
|
||||
try:
|
||||
payload = json.loads(frame.payload.decode("utf-8"))
|
||||
except Exception:
|
||||
payload = {}
|
||||
self._closed = True
|
||||
self._close_code = payload.get("code") if isinstance(payload, dict) else None
|
||||
self._close_reason = payload.get("reason") if isinstance(payload, dict) else None
|
||||
self.logger.info(
|
||||
"ps channel closed tunnel_id=%s code=%s reason=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self._close_code,
|
||||
self._close_reason or "-",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ Operator actions
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
if self._open_sent:
|
||||
return
|
||||
payload = json.dumps(
|
||||
{"protocol": "ps", "metadata": {"cols": cols, "rows": rows}},
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
frame = self._frame_cls(msg_type=MSG_CHANNEL_OPEN, channel_id=self.channel_id, payload=payload)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"ps channel open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
payload = data.encode("utf-8", errors="replace")
|
||||
frame = self._frame_cls(msg_type=MSG_DATA, channel_id=self.channel_id, payload=payload)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
payload = json.dumps({"cols": cols, "rows": rows}, separators=(",", ":")).encode("utf-8")
|
||||
frame = self._frame_cls(msg_type=MSG_CONTROL, channel_id=self.channel_id, payload=payload)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
|
||||
def close(self, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
else:
|
||||
frame = self._frame_cls(
|
||||
msg_type=MSG_CLOSE,
|
||||
channel_id=self.channel_id,
|
||||
payload=json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8"),
|
||||
)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
|
||||
# ------------------------------------------------------------------ Output polling
|
||||
def drain_output(self) -> List[str]:
|
||||
items: List[str] = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def _append_output(self, text: str) -> None:
|
||||
self._output.append(text)
|
||||
# Cap buffer to avoid unbounded memory growth.
|
||||
while len(self._output) > 500:
|
||||
self._output.popleft()
|
||||
|
||||
# ------------------------------------------------------------------ Status helpers
|
||||
def status(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": self._close_reason,
|
||||
"close_code": self._close_code,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["PowershellChannelServer"]
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Protocol handlers for remote interactive shell tunnels (Engine side)."""
|
||||
|
||||
from .Powershell import PowershellChannelServer
|
||||
from .Bash import BashChannelServer
|
||||
|
||||
__all__ = ["PowershellChannelServer", "BashChannelServer"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Domain handlers for remote interactive shells (PowerShell/Bash)."""
|
||||
|
||||
__all__ = ["Protocols"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Placeholder SSH channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class SSHChannelServer:
|
||||
protocol_name = "ssh"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"ssh.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"ssh channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
self.logger.info("ssh placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["SSHChannelServer"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Placeholder WinRM channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class WinRMChannelServer:
|
||||
protocol_name = "winrm"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"winrm.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"winrm channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
self.logger.info("winrm placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["WinRMChannelServer"]
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Protocol handlers for remote management tunnels (Engine side)."""
|
||||
|
||||
from .SSH import SSHChannelServer
|
||||
from .WinRM import WinRMChannelServer
|
||||
|
||||
__all__ = ["SSHChannelServer", "WinRMChannelServer"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Domain handlers for remote management tunnels (SSH/WinRM)."""
|
||||
|
||||
__all__ = ["Protocols"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Placeholder RDP channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class RDPChannelServer:
|
||||
protocol_name = "rdp"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"rdp.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"rdp channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
self.logger.info("rdp placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["RDPChannelServer"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Placeholder VNC channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class VNCChannelServer:
|
||||
protocol_name = "vnc"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"vnc.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"vnc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
self.logger.info("vnc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["VNCChannelServer"]
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Placeholder WebRTC channel server (Engine side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
class WebRTCChannelServer:
|
||||
protocol_name = "webrtc"
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1, frame_cls=None, close_frame_fn=None):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"webrtc.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output = deque()
|
||||
self._frame_cls = frame_cls
|
||||
self._close_frame_fn = close_frame_fn
|
||||
|
||||
def handle_agent_frame(self, frame) -> None:
|
||||
try:
|
||||
if frame.msg_type == 0x04: # MSG_CHANNEL_ACK
|
||||
self._ack_received = True
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"webrtc channel placeholder open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
self.logger.info("webrtc placeholder send_input ignored tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
return
|
||||
|
||||
def close(self, code: int = 6, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if callable(self._close_frame_fn):
|
||||
try:
|
||||
frame = self._close_frame_fn(self.channel_id, code, reason)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def drain_output(self):
|
||||
items = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def status(self):
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": None,
|
||||
"close_code": None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["WebRTCChannelServer"]
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Protocol handlers for remote video tunnels (Engine side)."""
|
||||
|
||||
from .WebRTC import WebRTCChannelServer
|
||||
from .RDP import RDPChannelServer
|
||||
from .VNC import VNCChannelServer
|
||||
|
||||
__all__ = ["WebRTCChannelServer", "RDPChannelServer", "VNCChannelServer"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Domain handlers for remote video/desktop tunnels (RDP/VNC/WebRTC)."""
|
||||
|
||||
__all__ = ["Protocols"]
|
||||
@@ -1,10 +0,0 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\WebSocket\Agent\__init__.py
|
||||
# Description: Package marker for Agent-facing WebSocket services (reverse tunnel scaffolding).
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""Agent-facing WebSocket services for the Engine runtime."""
|
||||
|
||||
__all__ = []
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\WebSocket\__init__.py
|
||||
# Description: Socket.IO handlers for Engine runtime quick job updates and realtime notifications.
|
||||
# Description: Socket.IO handlers for Engine runtime quick job updates and VPN shell bridging.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
@@ -8,24 +8,20 @@
|
||||
"""WebSocket service registration for the Borealis Engine runtime."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import sqlite3
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from flask import session, request
|
||||
from flask import request
|
||||
from flask_socketio import SocketIO
|
||||
|
||||
from ...database import initialise_engine_database
|
||||
from ...security import signing
|
||||
from ...server import EngineContext
|
||||
from .Agent.reverse_tunnel_orchestrator import (
|
||||
ReverseTunnelService,
|
||||
TunnelBridge,
|
||||
decode_frame,
|
||||
TunnelFrame,
|
||||
)
|
||||
from ..VPN import VpnTunnelService
|
||||
from .vpn_shell import VpnShellBridge
|
||||
|
||||
|
||||
def _now_ts() -> int:
|
||||
@@ -70,20 +66,31 @@ class EngineRealtimeAdapters:
|
||||
def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
"""Register Socket.IO event handlers for the Engine runtime."""
|
||||
|
||||
from ..API import _make_db_conn_factory, _make_service_logger # Local import to avoid circular import at module load
|
||||
|
||||
adapters = EngineRealtimeAdapters(context)
|
||||
logger = context.logger.getChild("realtime.quick_jobs")
|
||||
tunnel_service = getattr(context, "reverse_tunnel_service", None)
|
||||
if tunnel_service is None:
|
||||
tunnel_service = ReverseTunnelService(
|
||||
context,
|
||||
signer=None,
|
||||
shell_bridge = VpnShellBridge(socket_server, context)
|
||||
|
||||
def _get_tunnel_service() -> Optional[VpnTunnelService]:
|
||||
service = getattr(context, "vpn_tunnel_service", None)
|
||||
if service is not None:
|
||||
return service
|
||||
manager = getattr(context, "wireguard_server_manager", None)
|
||||
if manager is None:
|
||||
return None
|
||||
try:
|
||||
signer = signing.load_signer()
|
||||
except Exception:
|
||||
signer = None
|
||||
service = VpnTunnelService(
|
||||
context=context,
|
||||
wireguard_manager=manager,
|
||||
db_conn_factory=adapters.db_conn_factory,
|
||||
socketio=socket_server,
|
||||
service_log=adapters.service_log,
|
||||
signer=signer,
|
||||
)
|
||||
tunnel_service.start()
|
||||
setattr(context, "reverse_tunnel_service", tunnel_service)
|
||||
setattr(context, "vpn_tunnel_service", service)
|
||||
return service
|
||||
|
||||
@socket_server.on("quick_job_result")
|
||||
def _handle_quick_job_result(data: Any) -> None:
|
||||
@@ -246,252 +253,45 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
exc,
|
||||
)
|
||||
|
||||
@socket_server.on("tunnel_bridge_attach")
|
||||
def _tunnel_bridge_attach(data: Any) -> Any:
|
||||
"""Placeholder operator bridge attach handler (no data channel yet)."""
|
||||
@socket_server.on("vpn_shell_open")
|
||||
def _vpn_shell_open(data: Any) -> Dict[str, Any]:
|
||||
agent_id = ""
|
||||
if isinstance(data, dict):
|
||||
agent_id = str(data.get("agent_id") or "").strip()
|
||||
elif isinstance(data, str):
|
||||
agent_id = data.strip()
|
||||
if not agent_id:
|
||||
return {"error": "agent_id_required"}
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
service = _get_tunnel_service()
|
||||
if service is None:
|
||||
return {"error": "vpn_service_unavailable"}
|
||||
if not service.status(agent_id):
|
||||
return {"error": "tunnel_down"}
|
||||
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
operator_id = str(data.get("operator_id") or "").strip() or None
|
||||
if not tunnel_id:
|
||||
return {"error": "tunnel_id_required"}
|
||||
session = shell_bridge.open_session(request.sid, agent_id)
|
||||
if session is None:
|
||||
return {"error": "shell_connect_failed"}
|
||||
service.bump_activity(agent_id)
|
||||
return {"status": "ok"}
|
||||
|
||||
try:
|
||||
tunnel_service.operator_attach(tunnel_id, operator_id)
|
||||
except ValueError as exc:
|
||||
return {"error": str(exc)}
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("tunnel_bridge_attach failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "bridge_attach_failed"}
|
||||
|
||||
return {"status": "ok", "tunnel_id": tunnel_id, "operator_id": operator_id or "-"}
|
||||
|
||||
def _encode_frame(frame: TunnelFrame) -> str:
|
||||
return base64.b64encode(frame.encode()).decode("ascii")
|
||||
|
||||
def _decode_frame_payload(raw: Any) -> TunnelFrame:
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
raw_bytes = base64.b64decode(raw)
|
||||
except Exception:
|
||||
raise ValueError("invalid_frame")
|
||||
elif isinstance(raw, (bytes, bytearray)):
|
||||
raw_bytes = bytes(raw)
|
||||
@socket_server.on("vpn_shell_send")
|
||||
def _vpn_shell_send(data: Any) -> Dict[str, Any]:
|
||||
payload = None
|
||||
if isinstance(data, dict):
|
||||
payload = data.get("data")
|
||||
else:
|
||||
raise ValueError("invalid_frame")
|
||||
return decode_frame(raw_bytes)
|
||||
|
||||
@socket_server.on("tunnel_operator_send")
|
||||
def _tunnel_operator_send(data: Any) -> Any:
|
||||
"""Operator -> agent frame enqueue (placeholder queue)."""
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
frame_raw = data.get("frame")
|
||||
if not tunnel_id or frame_raw is None:
|
||||
return {"error": "tunnel_id_and_frame_required"}
|
||||
try:
|
||||
frame = _decode_frame_payload(frame_raw)
|
||||
except Exception as exc:
|
||||
return {"error": str(exc)}
|
||||
|
||||
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
bridge.operator_to_agent(frame)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("tunnel_operator_poll")
|
||||
def _tunnel_operator_poll(data: Any) -> Any:
|
||||
"""Operator polls queued frames from agent."""
|
||||
|
||||
tunnel_id = ""
|
||||
if isinstance(data, dict):
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
if not tunnel_id:
|
||||
return {"error": "tunnel_id_required"}
|
||||
bridge: Optional[TunnelBridge] = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
|
||||
frames = []
|
||||
while True:
|
||||
frame = bridge.next_for_operator()
|
||||
if frame is None:
|
||||
break
|
||||
frames.append(_encode_frame(frame))
|
||||
return {"frames": frames}
|
||||
|
||||
# WebUI operator bridge namespace for browser clients
|
||||
tunnel_namespace = "/tunnel"
|
||||
_operator_sessions: Dict[str, str] = {}
|
||||
|
||||
def _current_operator() -> Optional[str]:
|
||||
username = session.get("username")
|
||||
if username:
|
||||
return str(username)
|
||||
auth_header = (request.headers.get("Authorization") or "").strip()
|
||||
token = None
|
||||
if auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
if not token:
|
||||
token = request.cookies.get("borealis_auth")
|
||||
return token or None
|
||||
|
||||
@socket_server.on("join", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_join(data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
operator_id = _current_operator()
|
||||
if not operator_id:
|
||||
return {"error": "unauthorized"}
|
||||
tunnel_id = str(data.get("tunnel_id") or "").strip()
|
||||
if not tunnel_id:
|
||||
return {"error": "tunnel_id_required"}
|
||||
bridge = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
try:
|
||||
tunnel_service.operator_attach(tunnel_id, operator_id)
|
||||
except Exception as exc:
|
||||
logger.debug("ws_tunnel_join failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "attach_failed"}
|
||||
sid = request.sid
|
||||
_operator_sessions[sid] = tunnel_id
|
||||
return {"status": "ok", "tunnel_id": tunnel_id}
|
||||
|
||||
@socket_server.on("send", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_send(data: Any) -> Any:
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.get(sid)
|
||||
if not tunnel_id:
|
||||
return {"error": "not_joined"}
|
||||
if not isinstance(data, dict):
|
||||
return {"error": "invalid_payload"}
|
||||
frame_raw = data.get("frame")
|
||||
if frame_raw is None:
|
||||
return {"error": "frame_required"}
|
||||
try:
|
||||
frame = _decode_frame_payload(frame_raw)
|
||||
except Exception:
|
||||
return {"error": "invalid_frame"}
|
||||
bridge = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
bridge.operator_to_agent(frame)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("poll", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_poll() -> Any:
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.get(sid)
|
||||
if not tunnel_id:
|
||||
return {"error": "not_joined"}
|
||||
bridge = tunnel_service.get_bridge(tunnel_id)
|
||||
if bridge is None:
|
||||
return {"error": "unknown_tunnel"}
|
||||
frames = []
|
||||
while True:
|
||||
frame = bridge.next_for_operator()
|
||||
if frame is None:
|
||||
break
|
||||
frames.append(_encode_frame(frame))
|
||||
return {"frames": frames}
|
||||
|
||||
def _require_ps_server():
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.get(sid)
|
||||
if not tunnel_id:
|
||||
return None, None, {"error": "not_joined"}
|
||||
server = tunnel_service.ensure_protocol_server(tunnel_id)
|
||||
if server is None or not hasattr(server, "open_channel"):
|
||||
return None, tunnel_id, {"error": "ps_unsupported"}
|
||||
return server, tunnel_id, None
|
||||
|
||||
@socket_server.on("ps_open", namespace=tunnel_namespace)
|
||||
def _ws_ps_open(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
cols = 120
|
||||
rows = 32
|
||||
if isinstance(data, dict):
|
||||
try:
|
||||
cols = int(data.get("cols", cols))
|
||||
rows = int(data.get("rows", rows))
|
||||
except Exception:
|
||||
pass
|
||||
cols = max(20, min(cols, 300))
|
||||
rows = max(10, min(rows, 200))
|
||||
try:
|
||||
server.open_channel(cols=cols, rows=rows)
|
||||
except Exception as exc:
|
||||
logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_open_failed"}
|
||||
return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}
|
||||
|
||||
@socket_server.on("ps_send", namespace=tunnel_namespace)
|
||||
def _ws_ps_send(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
if data is None:
|
||||
payload = data
|
||||
if payload is None:
|
||||
return {"error": "payload_required"}
|
||||
text = data
|
||||
if isinstance(data, dict):
|
||||
text = data.get("data")
|
||||
if text is None:
|
||||
return {"error": "payload_required"}
|
||||
try:
|
||||
server.send_input(str(text))
|
||||
except Exception as exc:
|
||||
logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_send_failed"}
|
||||
shell_bridge.send(request.sid, str(payload))
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("ps_resize", namespace=tunnel_namespace)
|
||||
def _ws_ps_resize(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
cols = None
|
||||
rows = None
|
||||
if isinstance(data, dict):
|
||||
cols = data.get("cols")
|
||||
rows = data.get("rows")
|
||||
try:
|
||||
cols_int = int(cols) if cols is not None else 120
|
||||
rows_int = int(rows) if rows is not None else 32
|
||||
cols_int = max(20, min(cols_int, 300))
|
||||
rows_int = max(10, min(rows_int, 200))
|
||||
server.send_resize(cols_int, rows_int)
|
||||
return {"status": "ok", "cols": cols_int, "rows": rows_int}
|
||||
except Exception as exc:
|
||||
logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_resize_failed"}
|
||||
@socket_server.on("vpn_shell_close")
|
||||
def _vpn_shell_close() -> Dict[str, Any]:
|
||||
shell_bridge.close(request.sid)
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("ps_poll", namespace=tunnel_namespace)
|
||||
def _ws_ps_poll(data: Any = None) -> Any: # data is ignored; socketio passes it even when unused
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
try:
|
||||
output = server.drain_output()
|
||||
status = server.status()
|
||||
return {"output": output, "status": status}
|
||||
except Exception as exc:
|
||||
logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_poll_failed"}
|
||||
|
||||
@socket_server.on("disconnect", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_disconnect():
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.pop(sid, None)
|
||||
if tunnel_id and tunnel_id not in _operator_sessions.values():
|
||||
try:
|
||||
tunnel_service.stop_tunnel(tunnel_id, reason="operator_socket_disconnect")
|
||||
except Exception as exc:
|
||||
logger.debug("ws_tunnel_disconnect stop_tunnel failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
@socket_server.on("disconnect")
|
||||
def _ws_disconnect() -> None:
|
||||
shell_bridge.close(request.sid)
|
||||
|
||||
127
Data/Engine/services/WebSocket/vpn_shell.py
Normal file
127
Data/Engine/services/WebSocket/vpn_shell.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\WebSocket\vpn_shell.py
|
||||
# Description: Socket.IO handlers bridging UI shell to agent TCP server over WireGuard.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""WireGuard VPN PowerShell bridge (Engine side)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
def _b64encode(data: bytes) -> str:
|
||||
return base64.b64encode(data).decode("ascii").strip()
|
||||
|
||||
|
||||
def _b64decode(value: str) -> bytes:
|
||||
return base64.b64decode(value.encode("ascii"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShellSession:
|
||||
sid: str
|
||||
agent_id: str
|
||||
socketio: Any
|
||||
tcp: socket.socket
|
||||
_reader: Optional[threading.Thread] = None
|
||||
|
||||
def start_reader(self) -> None:
|
||||
t = threading.Thread(target=self._read_loop, daemon=True)
|
||||
t.start()
|
||||
self._reader = t
|
||||
|
||||
def _read_loop(self) -> None:
|
||||
buffer = b""
|
||||
try:
|
||||
while True:
|
||||
data = self.tcp.recv(4096)
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(line.decode("utf-8"))
|
||||
except Exception:
|
||||
continue
|
||||
if msg.get("type") == "stdout":
|
||||
payload = msg.get("data") or ""
|
||||
try:
|
||||
decoded = _b64decode(str(payload)).decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
decoded = ""
|
||||
self.socketio.emit("vpn_shell_output", {"data": decoded}, to=self.sid)
|
||||
finally:
|
||||
self.socketio.emit("vpn_shell_closed", {"agent_id": self.agent_id}, to=self.sid)
|
||||
try:
|
||||
self.tcp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def send(self, payload: str) -> None:
|
||||
data = json.dumps({"type": "stdin", "data": _b64encode(payload.encode("utf-8"))})
|
||||
self.tcp.sendall(data.encode("utf-8") + b"\n")
|
||||
|
||||
def close(self) -> None:
|
||||
try:
|
||||
data = json.dumps({"type": "close"})
|
||||
self.tcp.sendall(data.encode("utf-8") + b"\n")
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self.tcp.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class VpnShellBridge:
|
||||
def __init__(self, socketio, context) -> None:
|
||||
self.socketio = socketio
|
||||
self.context = context
|
||||
self._sessions: Dict[str, ShellSession] = {}
|
||||
self.logger = context.logger.getChild("vpn_shell")
|
||||
|
||||
def open_session(self, sid: str, agent_id: str) -> Optional[ShellSession]:
|
||||
service = getattr(self.context, "vpn_tunnel_service", None)
|
||||
if service is None:
|
||||
return None
|
||||
status = service.status(agent_id)
|
||||
if not status:
|
||||
return None
|
||||
host = str(status.get("virtual_ip") or "").split("/")[0]
|
||||
port = int(self.context.wireguard_shell_port)
|
||||
try:
|
||||
tcp = socket.create_connection((host, port), timeout=5)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to connect vpn shell to %s:%s", host, port, exc_info=True)
|
||||
return None
|
||||
session = ShellSession(sid=sid, agent_id=agent_id, socketio=self.socketio, tcp=tcp)
|
||||
self._sessions[sid] = session
|
||||
session.start_reader()
|
||||
return session
|
||||
|
||||
def send(self, sid: str, payload: str) -> None:
|
||||
session = self._sessions.get(sid)
|
||||
if not session:
|
||||
return
|
||||
session.send(payload)
|
||||
service = getattr(self.context, "vpn_tunnel_service", None)
|
||||
if service:
|
||||
service.bump_activity(session.agent_id)
|
||||
|
||||
def close(self, sid: str) -> None:
|
||||
session = self._sessions.pop(sid, None)
|
||||
if not session:
|
||||
return
|
||||
session.close()
|
||||
|
||||
@@ -8,13 +8,17 @@ import {
|
||||
Tab,
|
||||
Typography,
|
||||
Button,
|
||||
Switch,
|
||||
Chip,
|
||||
Divider,
|
||||
Menu,
|
||||
MenuItem,
|
||||
TextField,
|
||||
Dialog,
|
||||
DialogTitle,
|
||||
DialogContent,
|
||||
DialogActions
|
||||
DialogActions,
|
||||
LinearProgress
|
||||
} from "@mui/material";
|
||||
import InfoOutlinedIcon from "@mui/icons-material/InfoOutlined";
|
||||
import StorageRoundedIcon from "@mui/icons-material/StorageRounded";
|
||||
@@ -23,6 +27,7 @@ import LanRoundedIcon from "@mui/icons-material/LanRounded";
|
||||
import AppsRoundedIcon from "@mui/icons-material/AppsRounded";
|
||||
import ListAltRoundedIcon from "@mui/icons-material/ListAltRounded";
|
||||
import TerminalRoundedIcon from "@mui/icons-material/TerminalRounded";
|
||||
import TuneRoundedIcon from "@mui/icons-material/TuneRounded";
|
||||
import SpeedRoundedIcon from "@mui/icons-material/SpeedRounded";
|
||||
import DeveloperBoardRoundedIcon from "@mui/icons-material/DeveloperBoardRounded";
|
||||
import MoreHorizIcon from "@mui/icons-material/MoreHoriz";
|
||||
@@ -69,14 +74,51 @@ const SECTION_HEIGHTS = {
|
||||
network: 260,
|
||||
};
|
||||
|
||||
const buildVpnGroups = (shellPort) => {
|
||||
const normalizedShell = Number(shellPort) || 47001;
|
||||
return [
|
||||
{
|
||||
key: "shell",
|
||||
label: "Borealis PowerShell",
|
||||
description: "Web terminal access over the VPN tunnel.",
|
||||
ports: [normalizedShell],
|
||||
},
|
||||
{
|
||||
key: "rdp",
|
||||
label: "RDP",
|
||||
description: "Remote Desktop (TCP 3389).",
|
||||
ports: [3389],
|
||||
},
|
||||
{
|
||||
key: "winrm",
|
||||
label: "WinRM",
|
||||
description: "PowerShell/WinRM management (TCP 5985/5986).",
|
||||
ports: [5985, 5986],
|
||||
},
|
||||
{
|
||||
key: "vnc",
|
||||
label: "VNC",
|
||||
description: "Remote desktop streaming (TCP 5900).",
|
||||
ports: [5900],
|
||||
},
|
||||
{
|
||||
key: "webrtc",
|
||||
label: "WebRTC",
|
||||
description: "Real-time comms (UDP 3478).",
|
||||
ports: [3478],
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
const TOP_TABS = [
|
||||
{ label: "Device Summary", icon: InfoOutlinedIcon },
|
||||
{ label: "Storage", icon: StorageRoundedIcon },
|
||||
{ label: "Memory", icon: MemoryRoundedIcon },
|
||||
{ label: "Network", icon: LanRoundedIcon },
|
||||
{ label: "Installed Software", icon: AppsRoundedIcon },
|
||||
{ label: "Activity History", icon: ListAltRoundedIcon },
|
||||
{ label: "Remote Shell", icon: TerminalRoundedIcon },
|
||||
{ key: "summary", label: "Device Summary", icon: InfoOutlinedIcon },
|
||||
{ key: "storage", label: "Storage", icon: StorageRoundedIcon },
|
||||
{ key: "memory", label: "Memory", icon: MemoryRoundedIcon },
|
||||
{ key: "network", label: "Network", icon: LanRoundedIcon },
|
||||
{ key: "software", label: "Installed Software", icon: AppsRoundedIcon },
|
||||
{ key: "activity", label: "Activity History", icon: ListAltRoundedIcon },
|
||||
{ key: "advanced", label: "Advanced Config", icon: TuneRoundedIcon },
|
||||
{ key: "shell", label: "Remote Shell", icon: TerminalRoundedIcon },
|
||||
];
|
||||
|
||||
const myTheme = themeQuartz.withParams({
|
||||
@@ -286,6 +328,15 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
const [menuAnchor, setMenuAnchor] = useState(null);
|
||||
const [clearDialogOpen, setClearDialogOpen] = useState(false);
|
||||
const [assemblyNameMap, setAssemblyNameMap] = useState({});
|
||||
const [vpnLoading, setVpnLoading] = useState(false);
|
||||
const [vpnSaving, setVpnSaving] = useState(false);
|
||||
const [vpnError, setVpnError] = useState("");
|
||||
const [vpnSource, setVpnSource] = useState("default");
|
||||
const [vpnToggles, setVpnToggles] = useState({});
|
||||
const [vpnCustomPorts, setVpnCustomPorts] = useState([]);
|
||||
const [vpnDefaultPorts, setVpnDefaultPorts] = useState([]);
|
||||
const [vpnShellPort, setVpnShellPort] = useState(47001);
|
||||
const [vpnLoadedFor, setVpnLoadedFor] = useState("");
|
||||
// Snapshotted status for the lifetime of this page
|
||||
const [lockedStatus, setLockedStatus] = useState(() => {
|
||||
// Prefer status provided by the device list row if available
|
||||
@@ -655,6 +706,104 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
};
|
||||
}, [activityHostname, loadHistory]);
|
||||
|
||||
const applyVpnPorts = useCallback((ports, defaults, shellPort, source) => {
|
||||
const normalized = Array.isArray(ports) ? ports : [];
|
||||
const normalizedDefaults = Array.isArray(defaults) ? defaults : [];
|
||||
const numericPorts = normalized
|
||||
.map((p) => Number(p))
|
||||
.filter((p) => Number.isFinite(p) && p > 0);
|
||||
const numericDefaults = normalizedDefaults
|
||||
.map((p) => Number(p))
|
||||
.filter((p) => Number.isFinite(p) && p > 0);
|
||||
const effectiveShell = Number(shellPort) || 47001;
|
||||
const groups = buildVpnGroups(effectiveShell);
|
||||
const knownPorts = new Set(groups.flatMap((group) => group.ports));
|
||||
const allowedSet = new Set(numericPorts);
|
||||
const nextToggles = {};
|
||||
groups.forEach((group) => {
|
||||
nextToggles[group.key] = group.ports.every((port) => allowedSet.has(port));
|
||||
});
|
||||
const customPorts = numericPorts.filter((port) => !knownPorts.has(port));
|
||||
setVpnShellPort(effectiveShell);
|
||||
setVpnDefaultPorts(numericDefaults);
|
||||
setVpnCustomPorts(customPorts);
|
||||
setVpnToggles(nextToggles);
|
||||
setVpnSource(source || "default");
|
||||
}, []);
|
||||
|
||||
const loadVpnConfig = useCallback(async () => {
|
||||
if (!vpnAgentId) return;
|
||||
setVpnLoading(true);
|
||||
setVpnError("");
|
||||
setVpnLoadedFor(vpnAgentId);
|
||||
try {
|
||||
const resp = await fetch(`/api/device/vpn_config/${encodeURIComponent(vpnAgentId)}`);
|
||||
const data = await resp.json().catch(() => ({}));
|
||||
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
|
||||
const allowedPorts = Array.isArray(data?.allowed_ports) ? data.allowed_ports : [];
|
||||
const defaultPorts = Array.isArray(data?.default_ports) ? data.default_ports : [];
|
||||
const shellPort = data?.shell_port;
|
||||
applyVpnPorts(allowedPorts.length ? allowedPorts : defaultPorts, defaultPorts, shellPort, data?.source);
|
||||
setVpnLoadedFor(vpnAgentId);
|
||||
} catch (err) {
|
||||
setVpnError(String(err.message || err));
|
||||
} finally {
|
||||
setVpnLoading(false);
|
||||
}
|
||||
}, [applyVpnPorts, vpnAgentId]);
|
||||
|
||||
const saveVpnConfig = useCallback(async () => {
|
||||
if (!vpnAgentId) return;
|
||||
const ports = [];
|
||||
vpnPortGroups.forEach((group) => {
|
||||
if (vpnToggles[group.key]) {
|
||||
ports.push(...group.ports);
|
||||
}
|
||||
});
|
||||
vpnCustomPorts.forEach((port) => ports.push(port));
|
||||
const uniquePorts = Array.from(new Set(ports)).filter((p) => p > 0);
|
||||
if (!uniquePorts.length) {
|
||||
setVpnError("Enable at least one port before saving.");
|
||||
return;
|
||||
}
|
||||
setVpnSaving(true);
|
||||
setVpnError("");
|
||||
try {
|
||||
const resp = await fetch(`/api/device/vpn_config/${encodeURIComponent(vpnAgentId)}`, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ allowed_ports: uniquePorts }),
|
||||
});
|
||||
const data = await resp.json().catch(() => ({}));
|
||||
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
|
||||
const allowedPorts = Array.isArray(data?.allowed_ports) ? data.allowed_ports : uniquePorts;
|
||||
const defaultPorts = Array.isArray(data?.default_ports) ? data.default_ports : vpnDefaultPorts;
|
||||
applyVpnPorts(allowedPorts, defaultPorts, data?.shell_port || vpnShellPort, data?.source || "custom");
|
||||
} catch (err) {
|
||||
setVpnError(String(err.message || err));
|
||||
} finally {
|
||||
setVpnSaving(false);
|
||||
}
|
||||
}, [applyVpnPorts, vpnAgentId, vpnCustomPorts, vpnDefaultPorts, vpnPortGroups, vpnShellPort, vpnToggles]);
|
||||
|
||||
const resetVpnConfig = useCallback(() => {
|
||||
if (!vpnDefaultPorts.length) {
|
||||
setVpnError("No default ports available to reset.");
|
||||
return;
|
||||
}
|
||||
setVpnError("");
|
||||
applyVpnPorts(vpnDefaultPorts, vpnDefaultPorts, vpnShellPort, "default");
|
||||
}, [applyVpnPorts, vpnDefaultPorts, vpnShellPort]);
|
||||
|
||||
useEffect(() => {
|
||||
const advancedIndex = TOP_TABS.findIndex((item) => item.key === "advanced");
|
||||
if (advancedIndex < 0) return;
|
||||
if (tab !== advancedIndex) return;
|
||||
if (!vpnAgentId) return;
|
||||
if (vpnLoadedFor === vpnAgentId) return;
|
||||
loadVpnConfig();
|
||||
}, [loadVpnConfig, tab, vpnAgentId, vpnLoadedFor]);
|
||||
|
||||
// No explicit live recap tab; recaps are recorded into Activity History
|
||||
|
||||
const clearHistory = async () => {
|
||||
@@ -739,6 +888,19 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
);
|
||||
|
||||
const summary = details.summary || {};
|
||||
const vpnAgentId = useMemo(() => {
|
||||
return (
|
||||
meta.agentId ||
|
||||
summary.agent_id ||
|
||||
agent?.agent_id ||
|
||||
agent?.id ||
|
||||
device?.agent_id ||
|
||||
device?.agent_guid ||
|
||||
device?.id ||
|
||||
""
|
||||
);
|
||||
}, [agent?.agent_id, agent?.id, device?.agent_guid, device?.agent_id, device?.id, meta.agentId, summary.agent_id]);
|
||||
const vpnPortGroups = useMemo(() => buildVpnGroups(vpnShellPort), [vpnShellPort]);
|
||||
const tunnelDevice = useMemo(
|
||||
() => ({
|
||||
...(device || {}),
|
||||
@@ -876,7 +1038,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
const formatScriptType = useCallback((raw) => {
|
||||
const value = String(raw || "").toLowerCase();
|
||||
if (value === "ansible") return "Ansible Playbook";
|
||||
if (value === "reverse_tunnel") return "Reverse Tunnel";
|
||||
if (value === "reverse_tunnel" || value === "vpn_tunnel") return "Reverse VPN Tunnel";
|
||||
return "Script";
|
||||
}, []);
|
||||
|
||||
@@ -1368,6 +1530,150 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
</Box>
|
||||
);
|
||||
|
||||
const handleVpnToggle = useCallback((key, checked) => {
|
||||
setVpnToggles((prev) => ({ ...(prev || {}), [key]: checked }));
|
||||
setVpnSource("custom");
|
||||
}, []);
|
||||
|
||||
const renderAdvancedConfigTab = () => {
|
||||
const sourceLabel = vpnSource === "custom" ? "Custom overrides" : "Defaults";
|
||||
const showProgress = vpnLoading || vpnSaving;
|
||||
return (
|
||||
<Box sx={{ display: "flex", flexDirection: "column", gap: 2, flexGrow: 1, minHeight: 0 }}>
|
||||
<Box
|
||||
sx={{
|
||||
borderRadius: 3,
|
||||
border: `1px solid ${MAGIC_UI.panelBorder}`,
|
||||
background:
|
||||
"linear-gradient(160deg, rgba(8,12,24,0.94), rgba(10,16,30,0.9)), radial-gradient(circle at 20% 10%, rgba(125,211,252,0.08), transparent 40%)",
|
||||
boxShadow: "0 25px 80px rgba(2,6,23,0.65)",
|
||||
p: { xs: 2, md: 3 },
|
||||
}}
|
||||
>
|
||||
{showProgress ? <LinearProgress color="info" sx={{ height: 3, mb: 2 }} /> : null}
|
||||
<Stack direction={{ xs: "column", md: "row" }} spacing={2} alignItems={{ xs: "flex-start", md: "center" }}>
|
||||
<Box sx={{ flexGrow: 1 }}>
|
||||
<Typography variant="h6" sx={{ color: MAGIC_UI.textBright, fontWeight: 700 }}>
|
||||
Reverse VPN Tunnel - Allowed Ports
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted, mt: 0.5 }}>
|
||||
Toggle which services the Engine can reach over the WireGuard tunnel for this device.
|
||||
</Typography>
|
||||
</Box>
|
||||
<Chip
|
||||
label={sourceLabel}
|
||||
sx={{
|
||||
borderRadius: 999,
|
||||
fontWeight: 600,
|
||||
letterSpacing: 0.2,
|
||||
color: vpnSource === "custom" ? MAGIC_UI.accentA : MAGIC_UI.textMuted,
|
||||
border: `1px solid ${MAGIC_UI.panelBorder}`,
|
||||
backgroundColor: "rgba(8,12,24,0.75)",
|
||||
}}
|
||||
/>
|
||||
</Stack>
|
||||
<Divider sx={{ my: 2, borderColor: "rgba(148,163,184,0.2)" }} />
|
||||
<Stack spacing={1.5}>
|
||||
{vpnPortGroups.map((group) => (
|
||||
<Box
|
||||
key={group.key}
|
||||
sx={{
|
||||
display: "flex",
|
||||
alignItems: { xs: "flex-start", md: "center" },
|
||||
justifyContent: "space-between",
|
||||
gap: 2,
|
||||
p: 2,
|
||||
borderRadius: 2,
|
||||
border: `1px solid ${MAGIC_UI.panelBorder}`,
|
||||
background: "rgba(6,10,20,0.7)",
|
||||
}}
|
||||
>
|
||||
<Box sx={{ flexGrow: 1 }}>
|
||||
<Typography sx={{ color: MAGIC_UI.textBright, fontWeight: 600 }}>
|
||||
{group.label}
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted, mt: 0.35 }}>
|
||||
{group.description}
|
||||
</Typography>
|
||||
<Stack direction="row" spacing={0.75} sx={{ mt: 0.8, flexWrap: "wrap" }}>
|
||||
{group.ports.map((port) => (
|
||||
<Chip
|
||||
key={`${group.key}-${port}`}
|
||||
label={`TCP ${port}`}
|
||||
size="small"
|
||||
sx={{
|
||||
borderRadius: 999,
|
||||
backgroundColor: "rgba(15,23,42,0.65)",
|
||||
color: MAGIC_UI.textMuted,
|
||||
border: `1px solid rgba(148,163,184,0.25)`,
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
<Switch
|
||||
checked={Boolean(vpnToggles[group.key])}
|
||||
onChange={(event) => handleVpnToggle(group.key, event.target.checked)}
|
||||
color="info"
|
||||
disabled={vpnLoading || vpnSaving}
|
||||
/>
|
||||
</Box>
|
||||
))}
|
||||
</Stack>
|
||||
{vpnCustomPorts.length ? (
|
||||
<Box sx={{ mt: 2 }}>
|
||||
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted }}>
|
||||
Custom ports preserved: {vpnCustomPorts.join(", ")}
|
||||
</Typography>
|
||||
</Box>
|
||||
) : null}
|
||||
{vpnError ? (
|
||||
<Typography variant="body2" sx={{ color: "#ff7b89", mt: 1 }}>
|
||||
{vpnError}
|
||||
</Typography>
|
||||
) : null}
|
||||
<Stack direction="row" spacing={1.25} sx={{ mt: 2 }}>
|
||||
<Button
|
||||
size="small"
|
||||
disabled={!vpnAgentId || vpnSaving || vpnLoading}
|
||||
onClick={saveVpnConfig}
|
||||
sx={{
|
||||
backgroundImage: "linear-gradient(135deg,#7dd3fc,#c084fc)",
|
||||
color: "#0b1220",
|
||||
borderRadius: 999,
|
||||
textTransform: "none",
|
||||
px: 2.4,
|
||||
"&:hover": {
|
||||
backgroundImage: "linear-gradient(135deg,#86e1ff,#d1a6ff)",
|
||||
},
|
||||
}}
|
||||
>
|
||||
Save Config
|
||||
</Button>
|
||||
<Button
|
||||
size="small"
|
||||
disabled={!vpnDefaultPorts.length || vpnSaving || vpnLoading}
|
||||
onClick={resetVpnConfig}
|
||||
sx={{
|
||||
borderRadius: 999,
|
||||
textTransform: "none",
|
||||
px: 2.4,
|
||||
color: MAGIC_UI.textBright,
|
||||
border: `1px solid ${MAGIC_UI.panelBorder}`,
|
||||
backgroundColor: "rgba(8,12,24,0.6)",
|
||||
"&:hover": {
|
||||
backgroundColor: "rgba(12,18,35,0.8)",
|
||||
},
|
||||
}}
|
||||
>
|
||||
Reset Defaults
|
||||
</Button>
|
||||
</Stack>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
const memoryRows = useMemo(
|
||||
() =>
|
||||
(details.memory || []).map((m, idx) => ({
|
||||
@@ -1618,6 +1924,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
renderNetworkTab,
|
||||
renderSoftware,
|
||||
renderHistory,
|
||||
renderAdvancedConfigTab,
|
||||
renderRemoteShellTab,
|
||||
];
|
||||
const tabContent = (topTabRenderers[tab] || renderDeviceSummaryTab)();
|
||||
@@ -1742,7 +2049,7 @@ export default function DeviceDetails({ device, onBack, onQuickJobLaunch, onPage
|
||||
>
|
||||
{TOP_TABS.map((tabDef) => (
|
||||
<Tab
|
||||
key={tabDef.label}
|
||||
key={tabDef.key || tabDef.label}
|
||||
label={tabDef.label}
|
||||
icon={<tabDef.icon sx={{ fontSize: 18 }} />}
|
||||
iconPosition="start"
|
||||
|
||||
@@ -5,17 +5,17 @@ import {
|
||||
Button,
|
||||
Stack,
|
||||
TextField,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
Tooltip,
|
||||
LinearProgress,
|
||||
Chip,
|
||||
} from "@mui/material";
|
||||
import {
|
||||
PlayArrowRounded as PlayIcon,
|
||||
StopRounded as StopIcon,
|
||||
ContentCopy as CopyIcon,
|
||||
RefreshRounded as RefreshIcon,
|
||||
LanRounded as PortIcon,
|
||||
LanRounded as IpIcon,
|
||||
LinkRounded as LinkIcon,
|
||||
} from "@mui/icons-material";
|
||||
import { io } from "socket.io-client";
|
||||
@@ -24,18 +24,7 @@ import "prismjs/components/prism-powershell";
|
||||
import "prismjs/themes/prism-okaidia.css";
|
||||
import Editor from "react-simple-code-editor";
|
||||
|
||||
// Console diagnostics for troubleshooting the connect/disconnect flow.
|
||||
const debugLog = (...args) => {
|
||||
try {
|
||||
// eslint-disable-next-line no-console
|
||||
console.error("[ReverseTunnel][PS]", ...args);
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
};
|
||||
|
||||
const MAGIC_UI = {
|
||||
panelBg: "rgba(7,11,24,0.92)",
|
||||
panelBorder: "rgba(148, 163, 184, 0.35)",
|
||||
textMuted: "#94a3b8",
|
||||
textBright: "#e2e8f0",
|
||||
@@ -56,13 +45,25 @@ const gradientButtonSx = {
|
||||
},
|
||||
};
|
||||
|
||||
const FRAME_HEADER_BYTES = 12; // version, msg_type, flags, reserved, channel_id(u32), length(u32)
|
||||
const MSG_CLOSE = 0x08;
|
||||
const CLOSE_AGENT_SHUTDOWN = 6;
|
||||
|
||||
const fontFamilyMono =
|
||||
'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace';
|
||||
|
||||
const emitAsync = (socket, event, payload, timeoutMs = 4000) =>
|
||||
new Promise((resolve) => {
|
||||
let settled = false;
|
||||
const timer = setTimeout(() => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
resolve({ error: "timeout" });
|
||||
}, timeoutMs);
|
||||
socket.emit(event, payload, (resp) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
clearTimeout(timer);
|
||||
resolve(resp || {});
|
||||
});
|
||||
});
|
||||
|
||||
function normalizeText(value) {
|
||||
if (value == null) return "";
|
||||
try {
|
||||
@@ -72,28 +73,6 @@ function normalizeText(value) {
|
||||
}
|
||||
}
|
||||
|
||||
function base64FromBytes(bytes) {
|
||||
let binary = "";
|
||||
bytes.forEach((b) => {
|
||||
binary += String.fromCharCode(b);
|
||||
});
|
||||
return btoa(binary);
|
||||
}
|
||||
|
||||
function buildCloseFrame(channelId = 1, code = CLOSE_AGENT_SHUTDOWN, reason = "operator_close") {
|
||||
const payload = new TextEncoder().encode(JSON.stringify({ code, reason }));
|
||||
const buffer = new ArrayBuffer(FRAME_HEADER_BYTES + payload.length);
|
||||
const view = new DataView(buffer);
|
||||
view.setUint8(0, 1); // version
|
||||
view.setUint8(1, MSG_CLOSE);
|
||||
view.setUint8(2, 0); // flags
|
||||
view.setUint8(3, 0); // reserved
|
||||
view.setUint32(4, channelId >>> 0, true);
|
||||
view.setUint32(8, payload.length >>> 0, true);
|
||||
new Uint8Array(buffer, FRAME_HEADER_BYTES).set(payload);
|
||||
return base64FromBytes(new Uint8Array(buffer));
|
||||
}
|
||||
|
||||
function highlightPs(code) {
|
||||
try {
|
||||
return Prism.highlight(code || "", Prism.languages.powershell, "powershell");
|
||||
@@ -102,52 +81,18 @@ function highlightPs(code) {
|
||||
}
|
||||
}
|
||||
|
||||
const INITIAL_MILESTONES = {
|
||||
tunnelReady: false,
|
||||
operatorAttached: false,
|
||||
shellEstablished: false,
|
||||
};
|
||||
const INITIAL_STATUS_CHAIN = ["Offline"];
|
||||
|
||||
export default function ReverseTunnelPowershell({ device }) {
|
||||
const [connectionType, setConnectionType] = useState("ps");
|
||||
const [tunnel, setTunnel] = useState(null);
|
||||
const [sessionState, setSessionState] = useState("idle");
|
||||
const [, setStatusMessage] = useState("");
|
||||
const [, setStatusSeverity] = useState("info");
|
||||
const [shellState, setShellState] = useState("idle");
|
||||
const [tunnel, setTunnel] = useState(null);
|
||||
const [output, setOutput] = useState("");
|
||||
const [input, setInput] = useState("");
|
||||
const [statusMessage, setStatusMessage] = useState("");
|
||||
const [copyFlash, setCopyFlash] = useState(false);
|
||||
const [, setPolling] = useState(false);
|
||||
const [psStatus, setPsStatus] = useState({});
|
||||
const [milestones, setMilestones] = useState(() => ({ ...INITIAL_MILESTONES }));
|
||||
const [tunnelSteps, setTunnelSteps] = useState(() => [...INITIAL_STATUS_CHAIN]);
|
||||
const [websocketSteps, setWebsocketSteps] = useState(() => [...INITIAL_STATUS_CHAIN]);
|
||||
const [shellSteps, setShellSteps] = useState(() => [...INITIAL_STATUS_CHAIN]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const socketRef = useRef(null);
|
||||
const pollTimerRef = useRef(null);
|
||||
const resizeTimerRef = useRef(null);
|
||||
const localSocketRef = useRef(false);
|
||||
const terminalRef = useRef(null);
|
||||
const joinRetryRef = useRef(null);
|
||||
const joinAttemptsRef = useRef(0);
|
||||
const tunnelRef = useRef(null);
|
||||
const shellFlagsRef = useRef({ openSent: false, ack: false });
|
||||
const DOMAIN_REMOTE_SHELL = "remote-interactive-shell";
|
||||
|
||||
useEffect(() => {
|
||||
debugLog("component mount", { hostname: device?.hostname, agentId });
|
||||
return () => debugLog("component unmount");
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
const hostname = useMemo(() => {
|
||||
return (
|
||||
normalizeText(device?.hostname) ||
|
||||
normalizeText(device?.summary?.hostname) ||
|
||||
normalizeText(device?.agent_hostname) ||
|
||||
""
|
||||
);
|
||||
}, [device]);
|
||||
|
||||
const agentId = useMemo(() => {
|
||||
return (
|
||||
@@ -162,78 +107,18 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
);
|
||||
}, [device]);
|
||||
|
||||
const appendStatus = useCallback((setter, label) => {
|
||||
if (!label) return;
|
||||
setter((prev) => {
|
||||
const next = [...prev, label];
|
||||
const cap = 6;
|
||||
return next.length > cap ? next.slice(next.length - cap) : next;
|
||||
});
|
||||
}, []);
|
||||
|
||||
const resetState = useCallback(() => {
|
||||
debugLog("resetState invoked");
|
||||
setTunnel(null);
|
||||
setSessionState("idle");
|
||||
setStatusMessage("");
|
||||
setStatusSeverity("info");
|
||||
setOutput("");
|
||||
setInput("");
|
||||
setPsStatus({});
|
||||
setMilestones({ ...INITIAL_MILESTONES });
|
||||
setTunnelSteps([...INITIAL_STATUS_CHAIN]);
|
||||
setWebsocketSteps([...INITIAL_STATUS_CHAIN]);
|
||||
setShellSteps([...INITIAL_STATUS_CHAIN]);
|
||||
shellFlagsRef.current = { openSent: false, ack: false };
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
tunnelRef.current = tunnel?.tunnel_id || null;
|
||||
}, [tunnel?.tunnel_id]);
|
||||
|
||||
const disconnectSocket = useCallback(() => {
|
||||
const socket = socketRef.current;
|
||||
if (socket) {
|
||||
socket.off();
|
||||
socket.disconnect();
|
||||
const ensureSocket = useCallback(() => {
|
||||
if (socketRef.current) return socketRef.current;
|
||||
const existing = typeof window !== "undefined" ? window.BorealisSocket : null;
|
||||
if (existing) {
|
||||
socketRef.current = existing;
|
||||
localSocketRef.current = false;
|
||||
return existing;
|
||||
}
|
||||
socketRef.current = null;
|
||||
}, []);
|
||||
|
||||
const stopPolling = useCallback(() => {
|
||||
if (pollTimerRef.current) {
|
||||
clearTimeout(pollTimerRef.current);
|
||||
pollTimerRef.current = null;
|
||||
}
|
||||
setPolling(false);
|
||||
}, []);
|
||||
|
||||
const stopTunnel = useCallback(async (reason = "operator_disconnect", tunnelIdOverride = null) => {
|
||||
const tunnelId = tunnelIdOverride || tunnelRef.current;
|
||||
if (!tunnelId) return;
|
||||
try {
|
||||
await fetch(`/api/tunnel/${tunnelId}`, {
|
||||
method: "DELETE",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ reason }),
|
||||
});
|
||||
} catch (err) {
|
||||
// best-effort; socket close frame acts as fallback
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
debugLog("cleanup on unmount", { tunnelId: tunnelRef.current });
|
||||
stopPolling();
|
||||
disconnectSocket();
|
||||
if (joinRetryRef.current) {
|
||||
clearTimeout(joinRetryRef.current);
|
||||
joinRetryRef.current = null;
|
||||
}
|
||||
stopTunnel("component_unmount", tunnelRef.current);
|
||||
};
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
const socket = io(window.location.origin, { transports: ["websocket"] });
|
||||
socketRef.current = socket;
|
||||
localSocketRef.current = true;
|
||||
return socket;
|
||||
}, []);
|
||||
|
||||
const appendOutput = useCallback((text) => {
|
||||
@@ -257,6 +142,137 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
scrollToBottom();
|
||||
}, [output, scrollToBottom]);
|
||||
|
||||
const stopTunnel = useCallback(
|
||||
async (reason = "operator_disconnect") => {
|
||||
if (!agentId) return;
|
||||
try {
|
||||
await fetch("/api/tunnel/disconnect", {
|
||||
method: "DELETE",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ agent_id: agentId, tunnel_id: tunnel?.tunnel_id, reason }),
|
||||
});
|
||||
} catch {
|
||||
// best-effort
|
||||
}
|
||||
},
|
||||
[agentId, tunnel?.tunnel_id]
|
||||
);
|
||||
|
||||
const closeShell = useCallback(async () => {
|
||||
const socket = ensureSocket();
|
||||
await emitAsync(socket, "vpn_shell_close", {});
|
||||
}, [ensureSocket]);
|
||||
|
||||
const handleDisconnect = useCallback(async () => {
|
||||
setLoading(true);
|
||||
setStatusMessage("");
|
||||
try {
|
||||
await closeShell();
|
||||
await stopTunnel("operator_disconnect");
|
||||
} finally {
|
||||
setTunnel(null);
|
||||
setShellState("closed");
|
||||
setSessionState("idle");
|
||||
setLoading(false);
|
||||
}
|
||||
}, [closeShell, stopTunnel]);
|
||||
|
||||
useEffect(() => {
|
||||
const socket = ensureSocket();
|
||||
const handleDisconnectEvent = () => {
|
||||
if (sessionState === "connected") {
|
||||
setShellState("closed");
|
||||
setSessionState("idle");
|
||||
setStatusMessage("Socket disconnected.");
|
||||
}
|
||||
};
|
||||
const handleOutput = (payload) => {
|
||||
appendOutput(payload?.data || "");
|
||||
};
|
||||
const handleClosed = () => {
|
||||
setShellState("closed");
|
||||
setSessionState("idle");
|
||||
setStatusMessage("Shell closed.");
|
||||
};
|
||||
|
||||
socket.on("disconnect", handleDisconnectEvent);
|
||||
socket.on("vpn_shell_output", handleOutput);
|
||||
socket.on("vpn_shell_closed", handleClosed);
|
||||
|
||||
return () => {
|
||||
socket.off("disconnect", handleDisconnectEvent);
|
||||
socket.off("vpn_shell_output", handleOutput);
|
||||
socket.off("vpn_shell_closed", handleClosed);
|
||||
if (localSocketRef.current) {
|
||||
socket.disconnect();
|
||||
}
|
||||
};
|
||||
}, [appendOutput, ensureSocket, sessionState]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
closeShell();
|
||||
stopTunnel("component_unmount");
|
||||
};
|
||||
}, [closeShell, stopTunnel]);
|
||||
|
||||
const requestTunnel = useCallback(async () => {
|
||||
if (!agentId) {
|
||||
setStatusMessage("Agent ID is required to connect.");
|
||||
return;
|
||||
}
|
||||
setLoading(true);
|
||||
setStatusMessage("");
|
||||
setSessionState("connecting");
|
||||
setShellState("opening");
|
||||
try {
|
||||
const resp = await fetch("/api/tunnel/connect", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ agent_id: agentId }),
|
||||
});
|
||||
const data = await resp.json().catch(() => ({}));
|
||||
if (!resp.ok) throw new Error(data?.error || `HTTP ${resp.status}`);
|
||||
const statusResp = await fetch(
|
||||
`/api/tunnel/connect/status?agent_id=${encodeURIComponent(agentId)}&bump=1`
|
||||
);
|
||||
const statusData = await statusResp.json().catch(() => ({}));
|
||||
if (!statusResp.ok || statusData?.status !== "up") {
|
||||
throw new Error(statusData?.error || "Tunnel not ready");
|
||||
}
|
||||
setTunnel({ ...data, ...statusData });
|
||||
|
||||
const socket = ensureSocket();
|
||||
const openResp = await emitAsync(socket, "vpn_shell_open", { agent_id: agentId }, 6000);
|
||||
if (openResp?.error) {
|
||||
throw new Error(openResp.error);
|
||||
}
|
||||
setSessionState("connected");
|
||||
setShellState("connected");
|
||||
} catch (err) {
|
||||
setSessionState("error");
|
||||
setShellState("closed");
|
||||
setStatusMessage(String(err.message || err));
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [agentId, ensureSocket]);
|
||||
|
||||
const handleSend = useCallback(
|
||||
async (text) => {
|
||||
const socket = ensureSocket();
|
||||
if (!socket || sessionState !== "connected") return;
|
||||
const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`;
|
||||
appendOutput(`\nPS> ${text}\n`);
|
||||
setInput("");
|
||||
const resp = await emitAsync(socket, "vpn_shell_send", { data: payload });
|
||||
if (resp?.error) {
|
||||
setStatusMessage("Send failed.");
|
||||
}
|
||||
},
|
||||
[appendOutput, ensureSocket, sessionState]
|
||||
);
|
||||
|
||||
const handleCopy = async () => {
|
||||
try {
|
||||
await navigator.clipboard.writeText(output || "");
|
||||
@@ -267,329 +283,7 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
}
|
||||
};
|
||||
|
||||
const measureTerminal = useCallback(() => {
|
||||
const el = terminalRef.current;
|
||||
if (!el) return { cols: 120, rows: 32 };
|
||||
const width = el.clientWidth || 960;
|
||||
const height = el.clientHeight || 460;
|
||||
const charWidth = 8.2;
|
||||
const charHeight = 18;
|
||||
const cols = Math.max(20, Math.min(Math.floor(width / charWidth), 300));
|
||||
const rows = Math.max(10, Math.min(Math.floor(height / charHeight), 200));
|
||||
return { cols, rows };
|
||||
}, []);
|
||||
|
||||
const emitAsync = useCallback((socket, event, payload, timeoutMs = 4000) => {
|
||||
return new Promise((resolve) => {
|
||||
let settled = false;
|
||||
const timer = setTimeout(() => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
resolve({ error: "timeout" });
|
||||
}, timeoutMs);
|
||||
socket.emit(event, payload, (resp) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
clearTimeout(timer);
|
||||
resolve(resp || {});
|
||||
});
|
||||
});
|
||||
}, []);
|
||||
|
||||
const pollLoop = useCallback(
|
||||
(socket, tunnelId) => {
|
||||
if (!socket || !tunnelId) return;
|
||||
debugLog("pollLoop tick", { tunnelId });
|
||||
setPolling(true);
|
||||
pollTimerRef.current = setTimeout(async () => {
|
||||
const resp = await emitAsync(socket, "ps_poll", {});
|
||||
if (resp?.error) {
|
||||
debugLog("pollLoop error", resp);
|
||||
stopPolling();
|
||||
disconnectSocket();
|
||||
setPsStatus({});
|
||||
setTunnel(null);
|
||||
setSessionState("error");
|
||||
return;
|
||||
}
|
||||
if (Array.isArray(resp?.output) && resp.output.length) {
|
||||
appendOutput(resp.output.join(""));
|
||||
}
|
||||
if (resp?.status) {
|
||||
setPsStatus(resp.status);
|
||||
debugLog("pollLoop status", resp.status);
|
||||
if (resp.status.closed) {
|
||||
setSessionState("closed");
|
||||
setTunnel(null);
|
||||
setMilestones({ ...INITIAL_MILESTONES });
|
||||
appendStatus(setShellSteps, "Shell closed");
|
||||
appendStatus(setTunnelSteps, "Stopped");
|
||||
appendStatus(setWebsocketSteps, "Relay stopped");
|
||||
shellFlagsRef.current = { openSent: false, ack: false };
|
||||
stopPolling();
|
||||
return;
|
||||
}
|
||||
if (resp.status.open_sent && !shellFlagsRef.current.openSent) {
|
||||
appendStatus(setShellSteps, "Opening remote shell");
|
||||
shellFlagsRef.current.openSent = true;
|
||||
}
|
||||
if (resp.status.ack && !shellFlagsRef.current.ack) {
|
||||
setSessionState("connected");
|
||||
setMilestones((prev) => ({ ...prev, shellEstablished: true }));
|
||||
appendStatus(setShellSteps, "Remote shell established");
|
||||
shellFlagsRef.current.ack = true;
|
||||
}
|
||||
}
|
||||
pollLoop(socket, tunnelId);
|
||||
}, 520);
|
||||
},
|
||||
[appendOutput, emitAsync, stopPolling, disconnectSocket, appendStatus]
|
||||
);
|
||||
|
||||
const handleDisconnect = useCallback(
|
||||
async (reason = "operator_disconnect") => {
|
||||
debugLog("handleDisconnect begin", { reason, tunnelId: tunnel?.tunnel_id, psStatus, sessionState });
|
||||
setPsStatus({});
|
||||
const socket = socketRef.current;
|
||||
const tunnelId = tunnel?.tunnel_id;
|
||||
if (joinRetryRef.current) {
|
||||
clearTimeout(joinRetryRef.current);
|
||||
joinRetryRef.current = null;
|
||||
}
|
||||
joinAttemptsRef.current = 0;
|
||||
if (socket && tunnelId) {
|
||||
const frame = buildCloseFrame(1, CLOSE_AGENT_SHUTDOWN, "operator_close");
|
||||
debugLog("emit CLOSE", { tunnelId });
|
||||
socket.emit("send", { frame });
|
||||
}
|
||||
await stopTunnel(reason);
|
||||
debugLog("stopTunnel issued", { tunnelId });
|
||||
stopPolling();
|
||||
disconnectSocket();
|
||||
setTunnel(null);
|
||||
setSessionState("closed");
|
||||
setMilestones({ ...INITIAL_MILESTONES });
|
||||
appendStatus(setTunnelSteps, "Stopped");
|
||||
appendStatus(setWebsocketSteps, "Relay closed");
|
||||
appendStatus(setShellSteps, "Shell closed");
|
||||
shellFlagsRef.current = { openSent: false, ack: false };
|
||||
debugLog("handleDisconnect finished", { tunnelId });
|
||||
},
|
||||
[appendStatus, disconnectSocket, stopPolling, stopTunnel, tunnel?.tunnel_id]
|
||||
);
|
||||
|
||||
const handleResize = useCallback(() => {
|
||||
if (!socketRef.current || sessionState === "idle") return;
|
||||
const dims = measureTerminal();
|
||||
socketRef.current.emit("ps_resize", dims);
|
||||
}, [measureTerminal, sessionState]);
|
||||
|
||||
useEffect(() => {
|
||||
const observer =
|
||||
typeof ResizeObserver !== "undefined"
|
||||
? new ResizeObserver(() => {
|
||||
if (resizeTimerRef.current) clearTimeout(resizeTimerRef.current);
|
||||
resizeTimerRef.current = setTimeout(() => handleResize(), 200);
|
||||
})
|
||||
: null;
|
||||
const el = terminalRef.current;
|
||||
if (observer && el) observer.observe(el);
|
||||
const onWinResize = () => handleResize();
|
||||
window.addEventListener("resize", onWinResize);
|
||||
return () => {
|
||||
window.removeEventListener("resize", onWinResize);
|
||||
if (observer && el) observer.unobserve(el);
|
||||
};
|
||||
}, [handleResize]);
|
||||
|
||||
const connectSocket = useCallback(
|
||||
(lease, { isRetry = false } = {}) => {
|
||||
if (!lease?.tunnel_id) return;
|
||||
if (joinRetryRef.current) {
|
||||
clearTimeout(joinRetryRef.current);
|
||||
joinRetryRef.current = null;
|
||||
}
|
||||
if (!isRetry) {
|
||||
joinAttemptsRef.current = 0;
|
||||
}
|
||||
disconnectSocket();
|
||||
stopPolling();
|
||||
setSessionState("waiting");
|
||||
const socket = io(`${window.location.origin}/tunnel`, { transports: ["websocket", "polling"] });
|
||||
socketRef.current = socket;
|
||||
|
||||
socket.on("connect_error", () => {
|
||||
debugLog("socket connect_error");
|
||||
setStatusSeverity("warning");
|
||||
setStatusMessage("Tunnel namespace unavailable.");
|
||||
setTunnel(null);
|
||||
setSessionState("error");
|
||||
appendStatus(setWebsocketSteps, "Relay connect error");
|
||||
});
|
||||
|
||||
socket.on("disconnect", () => {
|
||||
debugLog("socket disconnect", { tunnelId: tunnel?.tunnel_id });
|
||||
stopPolling();
|
||||
if (sessionState !== "closed") {
|
||||
setSessionState("disconnected");
|
||||
setStatusSeverity("warning");
|
||||
setStatusMessage("Socket disconnected.");
|
||||
setTunnel(null);
|
||||
appendStatus(setWebsocketSteps, "Relay disconnected");
|
||||
}
|
||||
});
|
||||
|
||||
socket.on("connect", async () => {
|
||||
debugLog("socket connect", { tunnelId: lease.tunnel_id });
|
||||
setMilestones((prev) => ({ ...prev, operatorAttached: true }));
|
||||
setStatusSeverity("info");
|
||||
setStatusMessage("Joining tunnel...");
|
||||
appendStatus(setWebsocketSteps, "Relay connected");
|
||||
const joinResp = await emitAsync(socket, "join", { tunnel_id: lease.tunnel_id }, 5000);
|
||||
if (joinResp?.error) {
|
||||
const attempt = (joinAttemptsRef.current += 1);
|
||||
const isTimeout = joinResp.error === "timeout";
|
||||
if (joinResp.error === "unknown_tunnel") {
|
||||
setSessionState("waiting_agent");
|
||||
setStatusSeverity("info");
|
||||
setStatusMessage("Waiting for agent to establish tunnel...");
|
||||
appendStatus(setWebsocketSteps, "Waiting for agent");
|
||||
} else if (isTimeout || joinResp.error === "attach_failed") {
|
||||
setSessionState("waiting_agent");
|
||||
setStatusSeverity("warning");
|
||||
setStatusMessage("Tunnel join timed out. Retrying...");
|
||||
appendStatus(setWebsocketSteps, `Join retry ${attempt}`);
|
||||
} else {
|
||||
debugLog("join error", joinResp);
|
||||
setSessionState("error");
|
||||
setStatusSeverity("error");
|
||||
setStatusMessage(joinResp.error);
|
||||
appendStatus(setWebsocketSteps, `Join failed: ${joinResp.error}`);
|
||||
return;
|
||||
}
|
||||
if (attempt <= 5) {
|
||||
joinRetryRef.current = setTimeout(() => connectSocket(lease, { isRetry: true }), 800);
|
||||
} else {
|
||||
setSessionState("error");
|
||||
setTunnel(null);
|
||||
setStatusSeverity("warning");
|
||||
setStatusMessage("Operator could not attach to tunnel. Try Connect again.");
|
||||
appendStatus(setWebsocketSteps, "Join failed after retries");
|
||||
}
|
||||
return;
|
||||
}
|
||||
appendStatus(setWebsocketSteps, "Relay joined");
|
||||
const dims = measureTerminal();
|
||||
debugLog("ps_open emit", { tunnelId: lease.tunnel_id, dims });
|
||||
const openResp = await emitAsync(socket, "ps_open", dims, 5000);
|
||||
if (openResp?.error && openResp.error === "ps_unsupported") {
|
||||
// Suppress warming message; channel will settle once agent attaches.
|
||||
}
|
||||
if (!shellFlagsRef.current.openSent) {
|
||||
appendStatus(setShellSteps, "Opening remote shell");
|
||||
shellFlagsRef.current.openSent = true;
|
||||
}
|
||||
appendOutput("");
|
||||
setSessionState("waiting_agent");
|
||||
pollLoop(socket, lease.tunnel_id);
|
||||
handleResize();
|
||||
});
|
||||
},
|
||||
[appendOutput, appendStatus, disconnectSocket, emitAsync, handleResize, measureTerminal, pollLoop, sessionState, stopPolling]
|
||||
);
|
||||
|
||||
const requestTunnel = useCallback(async () => {
|
||||
if (tunnel && sessionState !== "closed" && sessionState !== "idle") {
|
||||
setStatusSeverity("info");
|
||||
setStatusMessage("");
|
||||
connectSocket(tunnel);
|
||||
return;
|
||||
}
|
||||
debugLog("requestTunnel", { agentId, connectionType });
|
||||
if (!agentId) {
|
||||
setStatusSeverity("warning");
|
||||
setStatusMessage("Agent ID is required to request a tunnel.");
|
||||
return;
|
||||
}
|
||||
if (connectionType !== "ps") {
|
||||
setStatusSeverity("warning");
|
||||
setStatusMessage("Only PowerShell is supported right now.");
|
||||
return;
|
||||
}
|
||||
resetState();
|
||||
setSessionState("requesting");
|
||||
setStatusSeverity("info");
|
||||
setStatusMessage("");
|
||||
appendStatus(setTunnelSteps, "Requesting lease");
|
||||
try {
|
||||
const resp = await fetch("/api/tunnel/request", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ agent_id: agentId, protocol: "ps", domain: DOMAIN_REMOTE_SHELL }),
|
||||
});
|
||||
const data = await resp.json().catch(() => ({}));
|
||||
if (!resp.ok) {
|
||||
const err = data?.error || `HTTP ${resp.status}`;
|
||||
setSessionState("error");
|
||||
setStatusSeverity(err === "domain_limit" ? "warning" : "error");
|
||||
setStatusMessage("");
|
||||
return;
|
||||
}
|
||||
setMilestones((prev) => ({ ...prev, tunnelReady: true }));
|
||||
setTunnel(data);
|
||||
setStatusMessage("");
|
||||
setSessionState("lease_issued");
|
||||
appendStatus(
|
||||
setTunnelSteps,
|
||||
data?.tunnel_id
|
||||
? `Lease issued (${data.tunnel_id.slice(0, 8)} @ Port ${data.port || "-"})`
|
||||
: "Lease issued"
|
||||
);
|
||||
connectSocket(data);
|
||||
} catch (e) {
|
||||
setSessionState("error");
|
||||
setStatusSeverity("error");
|
||||
setStatusMessage("");
|
||||
appendStatus(setTunnelSteps, "Lease request failed");
|
||||
}
|
||||
}, [DOMAIN_REMOTE_SHELL, agentId, appendStatus, connectSocket, connectionType, resetState]);
|
||||
|
||||
const handleSend = useCallback(
|
||||
async (text) => {
|
||||
const socket = socketRef.current;
|
||||
if (!socket) return;
|
||||
const payload = `${text}${text.endsWith("\n") ? "" : "\r\n"}`;
|
||||
appendOutput(`\nPS> ${text}\n`);
|
||||
setInput("");
|
||||
const resp = await emitAsync(socket, "ps_send", { data: payload });
|
||||
if (resp?.error) {
|
||||
setStatusSeverity("warning");
|
||||
setStatusMessage("");
|
||||
}
|
||||
},
|
||||
[appendOutput, emitAsync]
|
||||
);
|
||||
|
||||
const isConnected = sessionState === "connected" || (psStatus?.ack && !psStatus?.closed);
|
||||
const isClosed = sessionState === "closed" || psStatus?.closed;
|
||||
const isBusy =
|
||||
sessionState === "requesting" ||
|
||||
sessionState === "waiting" ||
|
||||
sessionState === "waiting_agent" ||
|
||||
sessionState === "lease_issued";
|
||||
const canStart = Boolean(agentId) && !isBusy;
|
||||
|
||||
useEffect(() => {
|
||||
const handleUnload = () => {
|
||||
stopTunnel("window_unload");
|
||||
};
|
||||
if (tunnel?.tunnel_id) {
|
||||
window.addEventListener("beforeunload", handleUnload);
|
||||
return () => window.removeEventListener("beforeunload", handleUnload);
|
||||
}
|
||||
return undefined;
|
||||
}, [stopTunnel, tunnel?.tunnel_id]);
|
||||
|
||||
const isConnected = sessionState === "connected";
|
||||
const sessionChips = [
|
||||
tunnel?.tunnel_id
|
||||
? {
|
||||
@@ -598,58 +292,43 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
icon: <LinkIcon sx={{ fontSize: 18 }} />,
|
||||
}
|
||||
: null,
|
||||
tunnel?.port
|
||||
tunnel?.virtual_ip
|
||||
? {
|
||||
label: `Port ${tunnel.port}`,
|
||||
label: `IP ${String(tunnel.virtual_ip).split("/")[0]}`,
|
||||
color: MAGIC_UI.accentA,
|
||||
icon: <PortIcon sx={{ fontSize: 18 }} />,
|
||||
icon: <IpIcon sx={{ fontSize: 18 }} />,
|
||||
}
|
||||
: null,
|
||||
].filter(Boolean);
|
||||
|
||||
return (
|
||||
<Box sx={{ display: "flex", flexDirection: "column", gap: 1.5, flexGrow: 1, minHeight: 0 }}>
|
||||
<Box>
|
||||
<Stack
|
||||
direction={{ xs: "column", sm: "row" }}
|
||||
spacing={1.5}
|
||||
alignItems={{ xs: "flex-start", sm: "center" }}
|
||||
justifyContent={{ xs: "flex-start", sm: "flex-end" }}
|
||||
<Stack direction={{ xs: "column", sm: "row" }} spacing={1.5} alignItems={{ xs: "flex-start", sm: "center" }}>
|
||||
<Button
|
||||
size="small"
|
||||
startIcon={isConnected ? <StopIcon /> : <PlayIcon />}
|
||||
sx={gradientButtonSx}
|
||||
disabled={loading || (!isConnected && !agentId)}
|
||||
onClick={isConnected ? handleDisconnect : requestTunnel}
|
||||
>
|
||||
<TextField
|
||||
select
|
||||
label="Connection Protocol"
|
||||
size="small"
|
||||
value={connectionType}
|
||||
onChange={(e) => setConnectionType(e.target.value)}
|
||||
sx={{
|
||||
minWidth: 180,
|
||||
"& .MuiInputBase-root": {
|
||||
backgroundColor: "rgba(12,18,35,0.85)",
|
||||
color: MAGIC_UI.textBright,
|
||||
borderRadius: 1.5,
|
||||
},
|
||||
"& fieldset": { borderColor: MAGIC_UI.panelBorder },
|
||||
"&:hover fieldset": { borderColor: MAGIC_UI.accentA },
|
||||
}}
|
||||
>
|
||||
<MenuItem value="ps">PowerShell</MenuItem>
|
||||
</TextField>
|
||||
<Tooltip title={isConnected ? "Disconnect session" : "Connect to agent"}>
|
||||
<span>
|
||||
<Button
|
||||
size="small"
|
||||
startIcon={isConnected ? <StopIcon /> : <PlayIcon />}
|
||||
sx={gradientButtonSx}
|
||||
disabled={!isConnected && !canStart}
|
||||
onClick={isConnected ? handleDisconnect : requestTunnel}
|
||||
>
|
||||
{isConnected ? "Disconnect" : "Connect"}
|
||||
</Button>
|
||||
</span>
|
||||
</Tooltip>
|
||||
{isConnected ? "Disconnect" : "Connect"}
|
||||
</Button>
|
||||
<Stack direction="row" spacing={1}>
|
||||
{sessionChips.map((chip) => (
|
||||
<Chip
|
||||
key={chip.label}
|
||||
icon={chip.icon}
|
||||
label={chip.label}
|
||||
sx={{
|
||||
borderRadius: 999,
|
||||
color: chip.color,
|
||||
border: `1px solid ${MAGIC_UI.panelBorder}`,
|
||||
backgroundColor: "rgba(8,12,24,0.65)",
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
</Stack>
|
||||
|
||||
<Box
|
||||
sx={{
|
||||
@@ -665,7 +344,7 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
overflow: "hidden",
|
||||
}}
|
||||
>
|
||||
{isBusy ? <LinearProgress color="info" sx={{ height: 3 }} /> : null}
|
||||
{loading ? <LinearProgress color="info" sx={{ height: 3 }} /> : null}
|
||||
<Box
|
||||
ref={terminalRef}
|
||||
sx={{
|
||||
@@ -728,11 +407,7 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
size="small"
|
||||
value={input}
|
||||
disabled={!isConnected}
|
||||
placeholder={
|
||||
isConnected
|
||||
? "Enter PowerShell command and press Enter"
|
||||
: "Connect to start sending commands"
|
||||
}
|
||||
placeholder={isConnected ? "Enter PowerShell command and press Enter" : "Connect to start sending commands"}
|
||||
onChange={(e) => setInput(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
@@ -753,43 +428,19 @@ export default function ReverseTunnelPowershell({ device }) {
|
||||
/>
|
||||
</Box>
|
||||
</Box>
|
||||
<Stack spacing={0.3} sx={{ mt: 1.25 }}>
|
||||
<Typography
|
||||
variant="body2"
|
||||
sx={{
|
||||
color: milestones.tunnelReady ? MAGIC_UI.accentC : MAGIC_UI.textMuted,
|
||||
fontWeight: 700,
|
||||
}}
|
||||
>
|
||||
Tunnel:{" "}
|
||||
<Typography component="span" variant="body2" sx={{ color: MAGIC_UI.textMuted, fontWeight: 500 }}>
|
||||
{tunnelSteps.join(" > ")}
|
||||
</Typography>
|
||||
|
||||
<Stack spacing={0.3} sx={{ mt: 1 }}>
|
||||
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted }}>
|
||||
Tunnel: {sessionState === "connected" ? "Active" : sessionState}
|
||||
</Typography>
|
||||
<Typography
|
||||
variant="body2"
|
||||
sx={{
|
||||
color: milestones.operatorAttached ? MAGIC_UI.accentC : MAGIC_UI.textMuted,
|
||||
fontWeight: 700,
|
||||
}}
|
||||
>
|
||||
Websocket:{" "}
|
||||
<Typography component="span" variant="body2" sx={{ color: MAGIC_UI.textMuted, fontWeight: 500 }}>
|
||||
{websocketSteps.join(" > ")}
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: MAGIC_UI.textMuted }}>
|
||||
Shell: {shellState === "connected" ? "Ready" : shellState}
|
||||
</Typography>
|
||||
<Typography
|
||||
variant="body2"
|
||||
sx={{
|
||||
color: milestones.shellEstablished ? MAGIC_UI.accentC : MAGIC_UI.textMuted,
|
||||
fontWeight: 700,
|
||||
}}
|
||||
>
|
||||
Remote Shell:{" "}
|
||||
<Typography component="span" variant="body2" sx={{ color: MAGIC_UI.textMuted, fontWeight: 500 }}>
|
||||
{shellSteps.join(" > ")}
|
||||
{statusMessage ? (
|
||||
<Typography variant="body2" sx={{ color: "#ff7b89" }}>
|
||||
{statusMessage}
|
||||
</Typography>
|
||||
</Typography>
|
||||
) : null}
|
||||
</Stack>
|
||||
</Box>
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user