mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-19 05:15:48 -07:00
Overhaul of VPN Codebase
This commit is contained in:
@@ -1,2 +0,0 @@
|
||||
"""Reverse tunnel protocol modules (placeholder package)."""
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
"""PowerShell channel implementation for reverse tunnel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import subprocess
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
# Message types mirrored from the tunnel framing (kept local to avoid import cycles).
|
||||
MSG_DATA = 0x05
|
||||
MSG_WINDOW_UPDATE = 0x06
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
|
||||
# Close codes (mirrored from engine framing)
|
||||
CLOSE_OK = 0
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class PowershellChannel:
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
self._reader_task = None
|
||||
self._writer_task = None
|
||||
self._stdin_queue: asyncio.Queue = asyncio.Queue()
|
||||
self._proc: Optional[asyncio.subprocess.Process] = None
|
||||
self._exit_code: Optional[int] = None
|
||||
self._frame_cls = getattr(role, "_frame_cls", None)
|
||||
|
||||
# ------------------------------------------------------------------ Helpers
|
||||
def _make_frame(self, msg_type: int, payload: bytes = b"", *, flags: int = 0):
|
||||
frame_cls = self._frame_cls
|
||||
if frame_cls is None:
|
||||
return None
|
||||
try:
|
||||
return frame_cls(msg_type=msg_type, channel_id=self.channel_id, payload=payload or b"", flags=flags)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _send_frame(self, frame) -> None:
|
||||
if frame is None:
|
||||
return
|
||||
await self.role._send_frame(self.tunnel, frame)
|
||||
|
||||
async def _send_close(self, code: int, reason: str) -> None:
|
||||
try:
|
||||
close_frame = getattr(self.role, "close_frame")
|
||||
if callable(close_frame):
|
||||
await self._send_frame(close_frame(self.channel_id, code, reason))
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
frame = self._make_frame(
|
||||
MSG_CLOSE,
|
||||
payload=f'{{"code":{code},"reason":"{reason}"}}'.encode("utf-8"),
|
||||
)
|
||||
await self._send_frame(frame)
|
||||
|
||||
def _powershell_argv(self) -> list:
|
||||
preferred = self.metadata.get("shell") if isinstance(self.metadata, dict) else None
|
||||
shell = preferred.strip() if isinstance(preferred, str) and preferred.strip() else "powershell.exe"
|
||||
# Keep the process alive and read commands from stdin; -Command - tells PS to consume stdin.
|
||||
return [shell, "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"]
|
||||
|
||||
# ------------------------------------------------------------------ Lifecycle
|
||||
async def start(self) -> None:
|
||||
if sys.platform.lower().startswith("win") is False:
|
||||
self.role._log("reverse_tunnel ps start aborted: non-windows platform", error=True)
|
||||
await self._send_close(CLOSE_PROTOCOL_ERROR, "windows_only")
|
||||
return
|
||||
|
||||
argv = self._powershell_argv()
|
||||
self.role._log(f"reverse_tunnel ps start channel={self.channel_id} argv={' '.join(argv)} mode=pipes")
|
||||
|
||||
# Pipes (no PTY).
|
||||
try:
|
||||
self._proc = await asyncio.create_subprocess_exec(
|
||||
*argv,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
|
||||
)
|
||||
except Exception as exc:
|
||||
self.role._log(f"reverse_tunnel ps channel spawn failed argv={' '.join(argv)}: {exc}", error=True)
|
||||
await self._send_close(CLOSE_PROTOCOL_ERROR, "spawn_failed")
|
||||
return
|
||||
|
||||
self._reader_task = self.loop.create_task(self._pump_proc_stdout())
|
||||
self._writer_task = self.loop.create_task(self._pump_proc_stdin())
|
||||
self.role._log(f"reverse_tunnel ps channel started (pipes) argv={' '.join(argv)}")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type == MSG_DATA:
|
||||
if frame.payload:
|
||||
try:
|
||||
self._stdin_queue.put_nowait(frame.payload)
|
||||
except Exception:
|
||||
await self._stdin_queue.put(frame.payload)
|
||||
elif frame.msg_type == MSG_CONTROL:
|
||||
await self._handle_control(frame.payload)
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
elif frame.msg_type == MSG_WINDOW_UPDATE:
|
||||
# Reserved for back-pressure; ignore for now.
|
||||
return
|
||||
|
||||
async def _handle_control(self, payload: bytes) -> None:
|
||||
# No-op for pipe mode; resize is not supported here.
|
||||
return
|
||||
|
||||
# -------------------- Pipe fallback pumps --------------------
|
||||
async def _pump_proc_stdout(self) -> None:
|
||||
try:
|
||||
while self._proc and not self._closed:
|
||||
chunk = await self._proc.stdout.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
frame = self._make_frame(MSG_DATA, payload=bytes(chunk))
|
||||
await self._send_frame(frame)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
self.role._log("reverse_tunnel ps pipe stdout pump error", error=True)
|
||||
finally:
|
||||
if self._proc and not self._closed:
|
||||
try:
|
||||
self._exit_code = await self._proc.wait()
|
||||
except Exception:
|
||||
pass
|
||||
await self.stop(reason="stdout_closed")
|
||||
|
||||
async def _pump_proc_stdin(self) -> None:
|
||||
try:
|
||||
while self._proc and not self._closed:
|
||||
data = await self._stdin_queue.get()
|
||||
if self._closed or not self._proc or not self._proc.stdin:
|
||||
break
|
||||
try:
|
||||
self._proc.stdin.write(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||
await self._proc.stdin.drain()
|
||||
except Exception:
|
||||
self.role._log("reverse_tunnel ps pipe stdin pump error", error=True)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
self.role._log("reverse_tunnel ps pipe stdin pump error", error=True)
|
||||
finally:
|
||||
await self.stop(reason="stdin_closed")
|
||||
|
||||
async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
if self._proc is not None:
|
||||
try:
|
||||
self._proc.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
current = asyncio.current_task()
|
||||
if self._reader_task and self._reader_task is not current:
|
||||
try:
|
||||
self._reader_task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
if self._writer_task and self._writer_task is not current:
|
||||
try:
|
||||
self._writer_task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
# Include exit code in the close reason for debugging.
|
||||
exit_suffix = f" (exit={self._exit_code})" if self._exit_code is not None else ""
|
||||
close_reason = (reason or "powershell_exit") + exit_suffix
|
||||
# Always send CLOSE before socket teardown so engine/UI see the reason.
|
||||
try:
|
||||
await self._send_close(code, close_reason)
|
||||
except Exception:
|
||||
self.role._log("reverse_tunnel ps close send failed", error=True)
|
||||
self.role._log(
|
||||
f"reverse_tunnel ps channel stopped channel={self.channel_id} reason={close_reason}"
|
||||
)
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Namespace package for reverse tunnel domains (Agent side)."""
|
||||
|
||||
__all__ = ["remote_interactive_shell", "remote_management", "remote_video"]
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Placeholder Bash channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class BashChannel:
|
||||
"""Stub Bash handler that immediately reports unsupported."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
# Until Bash support is implemented, close the channel to free resources.
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="bash_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
# Ignore payloads but acknowledge by stopping the channel to avoid leaks.
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="bash_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "bash_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel bash channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["BashChannel"]
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Expose the PowerShell channel under the domain path, with file-based import fallback."""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
powershell_module = None
|
||||
|
||||
# Attempt package-relative import first
|
||||
try: # pragma: no cover - best effort
|
||||
from ....ReverseTunnel import tunnel_Powershell as powershell_module # type: ignore
|
||||
except Exception:
|
||||
powershell_module = None
|
||||
|
||||
# Fallback: load directly from file path to survive non-package runtimes
|
||||
if powershell_module is None:
|
||||
try:
|
||||
base = Path(__file__).resolve().parents[3] / "ReverseTunnel" / "tunnel_Powershell.py"
|
||||
spec = importlib.util.spec_from_file_location("tunnel_Powershell", base)
|
||||
if spec and spec.loader:
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
powershell_module = module
|
||||
except Exception:
|
||||
powershell_module = None
|
||||
|
||||
if powershell_module and hasattr(powershell_module, "PowershellChannel"):
|
||||
PowershellChannel = powershell_module.PowershellChannel # type: ignore
|
||||
else: # pragma: no cover - safety guard
|
||||
class PowershellChannel: # type: ignore
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError("PowerShell channel unavailable")
|
||||
|
||||
|
||||
__all__ = ["PowershellChannel"]
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Protocol handlers for interactive shell tunnels (Agent side)."""
|
||||
|
||||
from .Powershell import PowershellChannel
|
||||
from .Bash import BashChannel
|
||||
|
||||
__all__ = ["PowershellChannel", "BashChannel"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Interactive shell domain (PowerShell/Bash) handlers."""
|
||||
|
||||
__all__ = ["tunnel", "Protocols"]
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Placeholder module for remote interactive shell tunnel domain (Agent side)."""
|
||||
|
||||
DOMAIN_NAME = "remote-interactive-shell"
|
||||
|
||||
__all__ = ["DOMAIN_NAME"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Placeholder SSH channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class SSHChannel:
|
||||
"""Stub SSH handler that marks the channel unsupported for now."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="ssh_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="ssh_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "ssh_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel ssh channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["SSHChannel"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Placeholder WinRM channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class WinRMChannel:
|
||||
"""Stub WinRM handler that marks the channel unsupported for now."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="winrm_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="winrm_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "winrm_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel winrm channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["WinRMChannel"]
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Protocol handlers for remote management tunnels (Agent side)."""
|
||||
|
||||
from .SSH import SSHChannel
|
||||
from .WinRM import WinRMChannel
|
||||
|
||||
__all__ = ["SSHChannel", "WinRMChannel"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Remote management domain (SSH/WinRM) handlers."""
|
||||
|
||||
__all__ = ["tunnel", "Protocols"]
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Placeholder module for remote management domain (Agent side)."""
|
||||
|
||||
DOMAIN_NAME = "remote-management"
|
||||
|
||||
__all__ = ["DOMAIN_NAME"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Placeholder RDP channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class RDPChannel:
|
||||
"""Stub RDP handler that marks the channel unsupported for now."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="rdp_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="rdp_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "rdp_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel rdp channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["RDPChannel"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Placeholder VNC channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class VNCChannel:
|
||||
"""Stub VNC handler that marks the channel unsupported for now."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="vnc_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="vnc_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "vnc_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel vnc channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["VNCChannel"]
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Placeholder WebRTC channel (Agent side)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
MSG_DATA = 0x05
|
||||
MSG_CONTROL = 0x09
|
||||
MSG_CLOSE = 0x08
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
|
||||
|
||||
class WebRTCChannel:
|
||||
"""Stub WebRTC handler that marks the channel unsupported for now."""
|
||||
|
||||
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
|
||||
self._closed = False
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="webrtc_unsupported")
|
||||
|
||||
async def on_frame(self, frame) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL):
|
||||
await self.stop(code=CLOSE_PROTOCOL_ERROR, reason="webrtc_unsupported")
|
||||
elif frame.msg_type == MSG_CLOSE:
|
||||
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
|
||||
|
||||
async def stop(self, code: int = CLOSE_PROTOCOL_ERROR, reason: str = "") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
try:
|
||||
await self.role._send_frame(self.tunnel, self.role.close_frame(self.channel_id, code, reason or "webrtc_closed"))
|
||||
except Exception:
|
||||
pass
|
||||
self.role._log(f"reverse_tunnel webrtc channel stopped channel={self.channel_id} reason={reason or 'closed'}")
|
||||
|
||||
|
||||
__all__ = ["WebRTCChannel"]
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Protocol handlers for remote video tunnels (Agent side)."""
|
||||
|
||||
from .WebRTC import WebRTCChannel
|
||||
from .RDP import RDPChannel
|
||||
from .VNC import VNCChannel
|
||||
|
||||
__all__ = ["WebRTCChannel", "RDPChannel", "VNCChannel"]
|
||||
@@ -1,3 +0,0 @@
|
||||
"""Remote video/desktop domain (RDP/VNC/WebRTC) handlers."""
|
||||
|
||||
__all__ = ["tunnel", "Protocols"]
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Placeholder module for remote video domain (Agent side)."""
|
||||
|
||||
DOMAIN_NAME = "remote-video"
|
||||
|
||||
__all__ = ["DOMAIN_NAME"]
|
||||
@@ -1,939 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
# Capture import errors for protocol handlers so we can report why they are missing.
|
||||
PS_IMPORT_ERROR: Optional[str] = None
|
||||
BASH_IMPORT_ERROR: Optional[str] = None
|
||||
tunnel_SSH = None
|
||||
tunnel_WinRM = None
|
||||
tunnel_VNC = None
|
||||
tunnel_RDP = None
|
||||
tunnel_WebRTC = None
|
||||
tunnel_Powershell = None
|
||||
tunnel_Bash = None
|
||||
|
||||
|
||||
def _load_protocol_module(module_name: str, rel_parts: list[str]) -> tuple[Optional[object], Optional[str]]:
|
||||
"""Load a protocol handler directly from a file path to survive non-package runtimes."""
|
||||
|
||||
base = Path(__file__).parent
|
||||
path = base
|
||||
for part in rel_parts:
|
||||
path = path / part
|
||||
if not path.exists():
|
||||
return None, f"path_missing:{path}"
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(module_name, path)
|
||||
if not spec or not spec.loader:
|
||||
return None, "spec_failed"
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
return module, None
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
return None, repr(exc)
|
||||
|
||||
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_interactive_shell.Protocols import Powershell as tunnel_Powershell # type: ignore
|
||||
except Exception as exc: # pragma: no cover - best-effort logging only
|
||||
PS_IMPORT_ERROR = repr(exc)
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_Powershell",
|
||||
["Reverse_Tunnels", "remote_interactive_shell", "Protocols", "Powershell.py"],
|
||||
)
|
||||
if _module:
|
||||
tunnel_Powershell = _module
|
||||
PS_IMPORT_ERROR = None
|
||||
else:
|
||||
try:
|
||||
from .ReverseTunnel import tunnel_Powershell # type: ignore # legacy fallback
|
||||
PS_IMPORT_ERROR = None
|
||||
except Exception as exc2: # pragma: no cover - diagnostic only
|
||||
PS_IMPORT_ERROR = f"{PS_IMPORT_ERROR} | legacy_fallback={exc2!r} | file_load_failed={_err}"
|
||||
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_interactive_shell.Protocols import Bash as tunnel_Bash # type: ignore
|
||||
except Exception as exc: # pragma: no cover - best-effort logging only
|
||||
BASH_IMPORT_ERROR = repr(exc)
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_Bash",
|
||||
["Reverse_Tunnels", "remote_interactive_shell", "Protocols", "Bash.py"],
|
||||
)
|
||||
if _module:
|
||||
tunnel_Bash = _module
|
||||
BASH_IMPORT_ERROR = None
|
||||
else:
|
||||
BASH_IMPORT_ERROR = f"{BASH_IMPORT_ERROR} | file_load_failed={_err}"
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_management.Protocols import SSH as tunnel_SSH # type: ignore
|
||||
except Exception:
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_SSH",
|
||||
["Reverse_Tunnels", "remote_management", "Protocols", "SSH.py"],
|
||||
)
|
||||
tunnel_SSH = _module
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_management.Protocols import WinRM as tunnel_WinRM # type: ignore
|
||||
except Exception:
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_WinRM",
|
||||
["Reverse_Tunnels", "remote_management", "Protocols", "WinRM.py"],
|
||||
)
|
||||
tunnel_WinRM = _module
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_video.Protocols import VNC as tunnel_VNC # type: ignore
|
||||
except Exception:
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_VNC",
|
||||
["Reverse_Tunnels", "remote_video", "Protocols", "VNC.py"],
|
||||
)
|
||||
tunnel_VNC = _module
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_video.Protocols import RDP as tunnel_RDP # type: ignore
|
||||
except Exception:
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_RDP",
|
||||
["Reverse_Tunnels", "remote_video", "Protocols", "RDP.py"],
|
||||
)
|
||||
tunnel_RDP = _module
|
||||
try:
|
||||
from .Reverse_Tunnels.remote_video.Protocols import WebRTC as tunnel_WebRTC # type: ignore
|
||||
except Exception:
|
||||
_module, _err = _load_protocol_module(
|
||||
"tunnel_WebRTC",
|
||||
["Reverse_Tunnels", "remote_video", "Protocols", "WebRTC.py"],
|
||||
)
|
||||
tunnel_WebRTC = _module
|
||||
|
||||
ROLE_NAME = "reverse_tunnel"
|
||||
ROLE_CONTEXTS = ["interactive", "system"]
|
||||
|
||||
# Message types (keep in sync with Engine service)
|
||||
MSG_CONNECT = 0x01
|
||||
MSG_CONNECT_ACK = 0x02
|
||||
MSG_CHANNEL_OPEN = 0x03
|
||||
MSG_CHANNEL_ACK = 0x04
|
||||
MSG_DATA = 0x05
|
||||
MSG_WINDOW_UPDATE = 0x06
|
||||
MSG_HEARTBEAT = 0x07
|
||||
MSG_CLOSE = 0x08
|
||||
MSG_CONTROL = 0x09
|
||||
|
||||
# Close codes
|
||||
CLOSE_OK = 0
|
||||
CLOSE_IDLE_TIMEOUT = 1
|
||||
CLOSE_GRACE_EXPIRED = 2
|
||||
CLOSE_PROTOCOL_ERROR = 3
|
||||
CLOSE_AUTH_FAILED = 4
|
||||
CLOSE_SERVER_SHUTDOWN = 5
|
||||
CLOSE_AGENT_SHUTDOWN = 6
|
||||
CLOSE_DOMAIN_LIMIT = 7
|
||||
CLOSE_UNEXPECTED_DISCONNECT = 8
|
||||
|
||||
FRAME_VERSION = 1
|
||||
FRAME_HEADER_STRUCT = struct.Struct("<BBBBII") # version, msg_type, flags, reserved, channel_id, length
|
||||
|
||||
|
||||
@dataclass
|
||||
class TunnelFrame:
|
||||
msg_type: int
|
||||
channel_id: int
|
||||
payload: bytes = field(default_factory=bytes)
|
||||
flags: int = 0
|
||||
version: int = FRAME_VERSION
|
||||
reserved: int = 0
|
||||
|
||||
def encode(self) -> bytes:
|
||||
payload = self.payload or b""
|
||||
header = FRAME_HEADER_STRUCT.pack(
|
||||
self.version,
|
||||
self.msg_type,
|
||||
self.flags,
|
||||
self.reserved,
|
||||
int(self.channel_id),
|
||||
len(payload),
|
||||
)
|
||||
return header + payload
|
||||
|
||||
|
||||
def decode_frame(raw: bytes) -> TunnelFrame:
|
||||
if len(raw) < FRAME_HEADER_STRUCT.size:
|
||||
raise ValueError("frame_too_small")
|
||||
version, msg_type, flags, reserved, channel_id, length = FRAME_HEADER_STRUCT.unpack_from(raw, 0)
|
||||
if version != FRAME_VERSION:
|
||||
raise ValueError(f"unsupported_version:{version}")
|
||||
if length < 0 or len(raw) < FRAME_HEADER_STRUCT.size + length:
|
||||
raise ValueError("invalid_length")
|
||||
payload = raw[FRAME_HEADER_STRUCT.size : FRAME_HEADER_STRUCT.size + length]
|
||||
if len(payload) != length:
|
||||
raise ValueError("length_mismatch")
|
||||
return TunnelFrame(msg_type=msg_type, channel_id=channel_id, payload=payload, flags=flags, version=version, reserved=reserved)
|
||||
|
||||
|
||||
def heartbeat_frame(channel_id: int = 0, *, is_ack: bool = False) -> TunnelFrame:
|
||||
flags = 0x1 if is_ack else 0x0
|
||||
return TunnelFrame(msg_type=MSG_HEARTBEAT, channel_id=channel_id, flags=flags, payload=b"")
|
||||
|
||||
|
||||
def close_frame(channel_id: int, code: int, reason: str = "") -> TunnelFrame:
|
||||
payload = json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8")
|
||||
return TunnelFrame(msg_type=MSG_CLOSE, channel_id=channel_id, payload=payload)
|
||||
|
||||
|
||||
def _norm_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
try:
|
||||
return str(value).strip()
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _is_literal_ip(value: str) -> bool:
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
ipaddress.ip_address(value.strip().strip("[]"))
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTunnel:
|
||||
tunnel_id: str
|
||||
domain: str
|
||||
protocol: str
|
||||
port: int
|
||||
token: str
|
||||
url: str
|
||||
heartbeat_seconds: int
|
||||
idle_seconds: int
|
||||
grace_seconds: int
|
||||
expires_at: Optional[int]
|
||||
signing_key_hint: Optional[str] = None
|
||||
session: Optional[aiohttp.ClientSession] = None
|
||||
websocket: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
tasks: list = field(default_factory=list)
|
||||
send_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
|
||||
channels: Dict[int, Any] = field(default_factory=dict)
|
||||
last_activity: float = field(default_factory=lambda: time.time())
|
||||
connected: bool = False
|
||||
stopping: bool = False
|
||||
stop_reason: Optional[str] = None
|
||||
stop_origin: Optional[str] = None
|
||||
|
||||
|
||||
class BaseChannel:
|
||||
"""Placeholder channel handler; protocol-specific handlers plug in later."""
|
||||
|
||||
def __init__(self, role: "Role", tunnel: ActiveTunnel, channel_id: int, metadata: Optional[dict]):
|
||||
self.role = role
|
||||
self.tunnel = tunnel
|
||||
self.channel_id = channel_id
|
||||
self.metadata = metadata or {}
|
||||
|
||||
async def start(self) -> None:
|
||||
# Nothing to prime for placeholder channels.
|
||||
return
|
||||
|
||||
async def on_frame(self, frame: TunnelFrame) -> None:
|
||||
# Drop frames until protocol module is provided.
|
||||
return
|
||||
|
||||
async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None:
|
||||
await self.role._send_frame(self.tunnel, close_frame(self.channel_id, code, reason))
|
||||
|
||||
|
||||
class Role:
|
||||
def __init__(self, ctx):
|
||||
self.ctx = ctx
|
||||
self.sio = ctx.sio
|
||||
self.loop = ctx.loop or asyncio.get_event_loop()
|
||||
self.hooks = ctx.hooks or {}
|
||||
self._http_client_factory = self.hooks.get("http_client")
|
||||
self._log_hook = self.hooks.get("log_agent")
|
||||
self._active: Dict[str, ActiveTunnel] = {}
|
||||
self._domain_claims: Dict[str, str] = {}
|
||||
self._domain_limits: Dict[str, Optional[int]] = {
|
||||
"remote-interactive-shell": 2,
|
||||
"remote-management": 1,
|
||||
"remote-video": 2,
|
||||
# Legacy / protocol fallbacks
|
||||
"ps": 2,
|
||||
"rdp": 1,
|
||||
"vnc": 1,
|
||||
"webrtc": 2,
|
||||
"ssh": None,
|
||||
"winrm": None,
|
||||
}
|
||||
self._default_heartbeat = 20
|
||||
self._protocol_handlers: Dict[str, Any] = {}
|
||||
self._frame_cls = TunnelFrame
|
||||
self.close_frame = close_frame
|
||||
try:
|
||||
if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"):
|
||||
self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel
|
||||
module_path = getattr(tunnel_Powershell, "__file__", None)
|
||||
self._log(f"reverse_tunnel ps handler registered (PowershellChannel) module={module_path}")
|
||||
else:
|
||||
hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
|
||||
module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
|
||||
exists_hint = f" exists={module_path.exists()}"
|
||||
self._log(
|
||||
f"reverse_tunnel ps handler NOT registered (missing module/class){hint}{exists_hint}",
|
||||
error=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._log(f"reverse_tunnel ps handler registration failed: {exc}", error=True)
|
||||
try:
|
||||
if tunnel_Bash and hasattr(tunnel_Bash, "BashChannel"):
|
||||
self._protocol_handlers["bash"] = tunnel_Bash.BashChannel
|
||||
module_path = getattr(tunnel_Bash, "__file__", None)
|
||||
self._log(f"reverse_tunnel bash handler registered (BashChannel) module={module_path}")
|
||||
elif BASH_IMPORT_ERROR:
|
||||
self._log(f"reverse_tunnel bash handler NOT registered (missing module/class) import_error={BASH_IMPORT_ERROR}", error=True)
|
||||
except Exception as exc:
|
||||
self._log(f"reverse_tunnel bash handler registration failed: {exc}", error=True)
|
||||
try:
|
||||
if tunnel_SSH and hasattr(tunnel_SSH, "SSHChannel"):
|
||||
self._protocol_handlers["ssh"] = tunnel_SSH.SSHChannel
|
||||
if tunnel_WinRM and hasattr(tunnel_WinRM, "WinRMChannel"):
|
||||
self._protocol_handlers["winrm"] = tunnel_WinRM.WinRMChannel
|
||||
if tunnel_VNC and hasattr(tunnel_VNC, "VNCChannel"):
|
||||
self._protocol_handlers["vnc"] = tunnel_VNC.VNCChannel
|
||||
if tunnel_RDP and hasattr(tunnel_RDP, "RDPChannel"):
|
||||
self._protocol_handlers["rdp"] = tunnel_RDP.RDPChannel
|
||||
if tunnel_WebRTC and hasattr(tunnel_WebRTC, "WebRTCChannel"):
|
||||
self._protocol_handlers["webrtc"] = tunnel_WebRTC.WebRTCChannel
|
||||
except Exception as exc:
|
||||
self._log(f"reverse_tunnel protocol handler registration failed: {exc}", error=True)
|
||||
|
||||
# ------------------------------------------------------------------ Logging
|
||||
def _log(self, message: str, *, error: bool = False) -> None:
|
||||
fname = "reverse_tunnel.log"
|
||||
try:
|
||||
if callable(self._log_hook):
|
||||
self._log_hook(message, fname=fname)
|
||||
if error:
|
||||
self._log_hook(message, fname="agent.error.log")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------ Event wiring
|
||||
def register_events(self):
|
||||
@self.sio.on("reverse_tunnel_start")
|
||||
async def _reverse_tunnel_start(payload):
|
||||
await self._handle_tunnel_start(payload)
|
||||
|
||||
@self.sio.on("reverse_tunnel_stop")
|
||||
async def _reverse_tunnel_stop(payload):
|
||||
tid = ""
|
||||
if isinstance(payload, dict):
|
||||
tid = _norm_text(payload.get("tunnel_id"))
|
||||
await self._stop_tunnel(tid, code=CLOSE_AGENT_SHUTDOWN, reason="server_stop")
|
||||
|
||||
# ------------------------------------------------------------------ Token helpers
|
||||
def _decode_token_payload(self, token: str, *, signing_key_hint: Optional[str] = None) -> Dict[str, Any]:
|
||||
if not token:
|
||||
raise ValueError("token_missing")
|
||||
|
||||
def _b64decode(segment: str) -> bytes:
|
||||
padding = "=" * (-len(segment) % 4)
|
||||
return base64.urlsafe_b64decode(segment + padding)
|
||||
|
||||
parts = token.split(".")
|
||||
payload_segment = parts[0]
|
||||
payload_bytes = _b64decode(payload_segment)
|
||||
try:
|
||||
payload = json.loads(payload_bytes.decode("utf-8"))
|
||||
except Exception as exc:
|
||||
raise ValueError("token_decode_error") from exc
|
||||
|
||||
if len(parts) == 2:
|
||||
candidates = []
|
||||
hint = _norm_text(signing_key_hint)
|
||||
if hint:
|
||||
candidates.append(hint)
|
||||
client = self._http_client()
|
||||
if client and hasattr(client, "load_server_signing_key"):
|
||||
try:
|
||||
stored = client.load_server_signing_key()
|
||||
except Exception:
|
||||
stored = None
|
||||
if isinstance(stored, str) and stored.strip():
|
||||
candidates.append(stored.strip())
|
||||
|
||||
signature = _b64decode(parts[1])
|
||||
verified = False
|
||||
for candidate in candidates:
|
||||
try:
|
||||
key_bytes = base64.b64decode(candidate, validate=True)
|
||||
public_key = serialization.load_der_public_key(key_bytes)
|
||||
except Exception:
|
||||
continue
|
||||
if not isinstance(public_key, ed25519.Ed25519PublicKey):
|
||||
continue
|
||||
try:
|
||||
public_key.verify(signature, payload_bytes)
|
||||
verified = True
|
||||
if client and hasattr(client, "store_server_signing_key"):
|
||||
try:
|
||||
client.store_server_signing_key(candidate)
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if not verified:
|
||||
raise ValueError("token_signature_invalid")
|
||||
return payload
|
||||
|
||||
def _validate_token(self, token: str, *, expected_agent: str, expected_domain: str, expected_protocol: str, expected_tunnel: str, signing_key_hint: Optional[str]) -> Dict[str, Any]:
|
||||
payload = self._decode_token_payload(token, signing_key_hint=signing_key_hint)
|
||||
|
||||
def _matches(expected: str, actual: Any) -> bool:
|
||||
return _norm_text(expected).lower() == _norm_text(actual).lower()
|
||||
|
||||
if expected_agent and not _matches(expected_agent, payload.get("agent_id")):
|
||||
raise ValueError("token_agent_mismatch")
|
||||
if expected_tunnel and not _matches(expected_tunnel, payload.get("tunnel_id")):
|
||||
raise ValueError("token_id_mismatch")
|
||||
if expected_domain and not _matches(expected_domain, payload.get("domain")):
|
||||
raise ValueError("token_domain_mismatch")
|
||||
if expected_protocol and not _matches(expected_protocol, payload.get("protocol")):
|
||||
raise ValueError("token_protocol_mismatch")
|
||||
|
||||
expires_at = payload.get("expires_at")
|
||||
try:
|
||||
expires_ts = int(expires_at) if expires_at is not None else None
|
||||
except Exception:
|
||||
expires_ts = None
|
||||
if expires_ts is not None and expires_ts < int(time.time()):
|
||||
raise ValueError("token_expired")
|
||||
return payload
|
||||
|
||||
# ------------------------------------------------------------------ Utility
|
||||
def _http_client(self):
|
||||
try:
|
||||
if callable(self._http_client_factory):
|
||||
return self._http_client_factory()
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _port_allowed(self, port: int) -> bool:
|
||||
return isinstance(port, int) and 1 <= port <= 65535
|
||||
|
||||
def _domain_allowed(self, domain: str) -> bool:
|
||||
limit = self._domain_limits.get(domain)
|
||||
if limit is None:
|
||||
return True
|
||||
active = [tid for tid, t in self._active.items() if t.domain == domain and not t.stopping]
|
||||
pending = [tid for dom, tid in self._domain_claims.items() if dom == domain and tid not in active]
|
||||
count = len(set(active + pending))
|
||||
return count < limit
|
||||
|
||||
def _build_ws_url(self, port: int) -> str:
|
||||
client = self._http_client()
|
||||
if client:
|
||||
try:
|
||||
client.refresh_base_url()
|
||||
except Exception:
|
||||
pass
|
||||
base_url = getattr(client, "base_url", "") or "https://localhost:5000"
|
||||
else:
|
||||
base_url = "https://localhost:5000"
|
||||
parsed = urlparse(base_url)
|
||||
host = parsed.hostname or "localhost"
|
||||
scheme = "wss" if (parsed.scheme or "").lower() == "https" else "ws"
|
||||
return f"{scheme}://{host}:{port}/"
|
||||
|
||||
def _ssl_context(self, host: str):
|
||||
client = self._http_client()
|
||||
verify = getattr(getattr(client, "session", None), "verify", True)
|
||||
if verify is False:
|
||||
return False
|
||||
if isinstance(verify, str) and os.path.isfile(verify) and client and hasattr(client, "key_store"):
|
||||
try:
|
||||
ctx = client.key_store.build_ssl_context()
|
||||
if ctx and _is_literal_ip(host):
|
||||
ctx.check_hostname = False
|
||||
return ctx
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def _emit_status(self, payload: Dict[str, Any]) -> None:
|
||||
try:
|
||||
await self.sio.emit("reverse_tunnel_status", payload)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _mark_activity(self, tunnel: ActiveTunnel) -> None:
|
||||
tunnel.last_activity = time.time()
|
||||
|
||||
# ------------------------------------------------------------------ Event handlers
|
||||
async def _handle_tunnel_start(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
self._log("reverse_tunnel_start ignored: payload not a dict", error=True)
|
||||
return
|
||||
tunnel_id = _norm_text(payload.get("tunnel_id"))
|
||||
token = _norm_text(payload.get("token"))
|
||||
port = payload.get("port") or payload.get("assigned_port")
|
||||
protocol = _norm_text(payload.get("protocol") or "ps").lower() or "ps"
|
||||
domain = _norm_text(payload.get("domain") or protocol).lower() or protocol
|
||||
heartbeat_seconds = int(payload.get("heartbeat_seconds") or self._default_heartbeat or 20)
|
||||
idle_seconds = int(payload.get("idle_seconds") or 3600)
|
||||
grace_seconds = int(payload.get("grace_seconds") or 3600)
|
||||
signing_key_hint = _norm_text(payload.get("signing_key"))
|
||||
agent_hint = _norm_text(payload.get("agent_id"))
|
||||
|
||||
# Ignore broadcasts targeting other agents (Socket.IO fanout sends to both contexts).
|
||||
if agent_hint and agent_hint.lower() != _norm_text(self.ctx.agent_id).lower():
|
||||
return
|
||||
|
||||
if not token:
|
||||
self._log("reverse_tunnel_start rejected: missing token", error=True)
|
||||
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_missing"})
|
||||
return
|
||||
|
||||
if not tunnel_id:
|
||||
tunnel_id = _norm_text(payload.get("lease_id")) # fallback alias
|
||||
|
||||
try:
|
||||
claims = self._validate_token(
|
||||
token,
|
||||
expected_agent=self.ctx.agent_id,
|
||||
expected_domain=domain,
|
||||
expected_protocol=protocol,
|
||||
expected_tunnel=tunnel_id,
|
||||
signing_key_hint=signing_key_hint,
|
||||
)
|
||||
except Exception as exc:
|
||||
if str(exc) == "token_agent_mismatch":
|
||||
# Broadcast hit the wrong agent context; ignore quietly.
|
||||
return
|
||||
self._log(f"reverse_tunnel_start rejected: token validation failed ({exc})", error=True)
|
||||
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_invalid"})
|
||||
return
|
||||
|
||||
tunnel_id = _norm_text(claims.get("tunnel_id") or tunnel_id)
|
||||
if not tunnel_id:
|
||||
self._log("reverse_tunnel_start rejected: tunnel_id missing after token parse", error=True)
|
||||
return
|
||||
|
||||
try:
|
||||
port = int(port or claims.get("assigned_port") or 0)
|
||||
except Exception:
|
||||
port = 0
|
||||
if not self._port_allowed(port):
|
||||
self._log(f"reverse_tunnel_start rejected: invalid port {port}", error=True)
|
||||
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "invalid_port"})
|
||||
return
|
||||
|
||||
domain = _norm_text(claims.get("domain") or domain).lower()
|
||||
protocol = _norm_text(claims.get("protocol") or protocol).lower()
|
||||
expires_at = claims.get("expires_at")
|
||||
|
||||
if tunnel_id in self._active:
|
||||
self._log(f"reverse_tunnel_start ignored: tunnel already active tunnel_id={tunnel_id}")
|
||||
return
|
||||
if not self._domain_allowed(domain):
|
||||
self._log(f"reverse_tunnel_start rejected: domain limit for {domain}", error=True)
|
||||
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "domain_limit"})
|
||||
return
|
||||
|
||||
url = self._build_ws_url(port)
|
||||
parsed = urlparse(url)
|
||||
heartbeat_seconds = max(5, min(heartbeat_seconds, 120))
|
||||
idle_seconds = max(30, idle_seconds)
|
||||
grace_seconds = max(60, grace_seconds)
|
||||
|
||||
tunnel = ActiveTunnel(
|
||||
tunnel_id=tunnel_id,
|
||||
domain=domain,
|
||||
protocol=protocol,
|
||||
port=port,
|
||||
token=token,
|
||||
url=url,
|
||||
heartbeat_seconds=heartbeat_seconds,
|
||||
idle_seconds=idle_seconds,
|
||||
grace_seconds=grace_seconds,
|
||||
expires_at=int(expires_at) if expires_at is not None else None,
|
||||
signing_key_hint=signing_key_hint or None,
|
||||
)
|
||||
|
||||
self._active[tunnel_id] = tunnel
|
||||
self._domain_claims[domain] = tunnel_id
|
||||
self._log(f"reverse_tunnel_start accepted tunnel_id={tunnel_id} domain={domain} protocol={protocol} url={url}")
|
||||
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "connecting", "url": url})
|
||||
task = self.loop.create_task(self._run_tunnel(tunnel, host=parsed.hostname or "localhost"))
|
||||
tunnel.tasks.append(task)
|
||||
|
||||
# ------------------------------------------------------------------ Core tunnel handling
|
||||
async def _run_tunnel(self, tunnel: ActiveTunnel, *, host: str) -> None:
|
||||
ssl_ctx = self._ssl_context(host)
|
||||
timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None)
|
||||
self._log(
|
||||
f"reverse_tunnel dialing ws url={tunnel.url} tunnel_id={tunnel.tunnel_id} "
|
||||
f"agent_id={self.ctx.agent_id} ssl={'yes' if ssl_ctx else 'no'}"
|
||||
)
|
||||
try:
|
||||
tunnel.session = aiohttp.ClientSession(timeout=timeout)
|
||||
tunnel.websocket = await tunnel.session.ws_connect(
|
||||
tunnel.url,
|
||||
ssl=ssl_ctx,
|
||||
heartbeat=None,
|
||||
max_msg_size=0,
|
||||
timeout=timeout,
|
||||
)
|
||||
self._mark_activity(tunnel)
|
||||
self._log(f"reverse_tunnel connected ws tunnel_id={tunnel.tunnel_id} peer={getattr(tunnel.websocket, 'remote', None)}")
|
||||
await tunnel.websocket.send_bytes(
|
||||
TunnelFrame(
|
||||
msg_type=MSG_CONNECT,
|
||||
channel_id=0,
|
||||
payload=json.dumps(
|
||||
{
|
||||
"agent_id": self.ctx.agent_id,
|
||||
"tunnel_id": tunnel.tunnel_id,
|
||||
"token": tunnel.token,
|
||||
"protocol": tunnel.protocol,
|
||||
"domain": tunnel.domain,
|
||||
"version": FRAME_VERSION,
|
||||
},
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8"),
|
||||
).encode()
|
||||
)
|
||||
self._log(f"reverse_tunnel CONNECT sent tunnel_id={tunnel.tunnel_id}")
|
||||
|
||||
sender = self.loop.create_task(self._pump_sender(tunnel))
|
||||
receiver = self.loop.create_task(self._pump_receiver(tunnel))
|
||||
heartbeats = self.loop.create_task(self._heartbeat_loop(tunnel))
|
||||
watchdog = self.loop.create_task(self._watchdog(tunnel))
|
||||
tunnel.tasks.extend([sender, receiver, heartbeats, watchdog])
|
||||
task_labels = {
|
||||
sender: "sender",
|
||||
receiver: "receiver",
|
||||
heartbeats: "heartbeat",
|
||||
watchdog: "watchdog",
|
||||
}
|
||||
done, pending = await asyncio.wait(task_labels.keys(), return_when=asyncio.FIRST_COMPLETED)
|
||||
for finished in done:
|
||||
label = task_labels.get(finished) or "unknown"
|
||||
exc_text = ""
|
||||
try:
|
||||
exc_obj = finished.exception()
|
||||
except asyncio.CancelledError:
|
||||
exc_obj = None
|
||||
exc_text = " (cancelled)"
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
exc_obj = exc
|
||||
if exc_obj:
|
||||
exc_text = f" (exc={exc_obj!r})"
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = f"{label}_stopped{exc_text}"
|
||||
if not tunnel.stop_origin:
|
||||
tunnel.stop_origin = label
|
||||
self._log(
|
||||
f"reverse_tunnel task completed tunnel_id={tunnel.tunnel_id} task={label} stop_reason={tunnel.stop_reason}{exc_text}"
|
||||
)
|
||||
if pending:
|
||||
try:
|
||||
self._log(
|
||||
"reverse_tunnel pending tasks after first completion tunnel_id=%s pending=%s",
|
||||
# Represent pending tasks by label for debugging.
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
self._log(f"reverse_tunnel connection failed tunnel_id={tunnel.tunnel_id}: {exc}", error=True)
|
||||
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "connect_failed"})
|
||||
finally:
|
||||
try:
|
||||
ws = tunnel.websocket
|
||||
self._log(
|
||||
f"reverse_tunnel ws closing tunnel_id={tunnel.tunnel_id} "
|
||||
f"code={getattr(ws, 'close_code', None)} reason={getattr(ws, 'close_reason', None)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
await self._shutdown_tunnel(tunnel)
|
||||
|
||||
async def _pump_sender(self, tunnel: ActiveTunnel) -> None:
|
||||
try:
|
||||
while tunnel.websocket and not tunnel.websocket.closed:
|
||||
frame: TunnelFrame = await tunnel.send_queue.get()
|
||||
try:
|
||||
await tunnel.websocket.send_bytes(frame.encode())
|
||||
self._mark_activity(tunnel)
|
||||
self._log(
|
||||
f"reverse_tunnel send frame tunnel_id={tunnel.tunnel_id} "
|
||||
f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
|
||||
)
|
||||
except Exception:
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = "sender_error"
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = "sender_failed"
|
||||
self._log(f"reverse_tunnel sender failed tunnel_id={tunnel.tunnel_id}", error=True)
|
||||
finally:
|
||||
self._log(f"reverse_tunnel sender stopped tunnel_id={tunnel.tunnel_id}")
|
||||
|
||||
async def _pump_receiver(self, tunnel: ActiveTunnel) -> None:
|
||||
ws = tunnel.websocket
|
||||
if ws is None:
|
||||
return
|
||||
try:
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.BINARY:
|
||||
try:
|
||||
frame = decode_frame(msg.data)
|
||||
except Exception:
|
||||
self._log(f"reverse_tunnel frame decode failed tunnel_id={tunnel.tunnel_id}", error=True)
|
||||
continue
|
||||
self._mark_activity(tunnel)
|
||||
await self._handle_frame(tunnel, frame)
|
||||
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE):
|
||||
self._log(
|
||||
f"reverse_tunnel websocket closed tunnel_id={tunnel.tunnel_id} "
|
||||
f"code={ws.close_code} reason={ws.close_reason}"
|
||||
)
|
||||
tunnel.stop_reason = ws.close_reason or "ws_closed"
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = "receiver_failed"
|
||||
self._log(f"reverse_tunnel receiver failed tunnel_id={tunnel.tunnel_id}", error=True)
|
||||
finally:
|
||||
self._log(f"reverse_tunnel receiver stopped tunnel_id={tunnel.tunnel_id}")
|
||||
# If no stop_reason was set, emit a CLOSE so engine/UI see a reason.
|
||||
if not tunnel.stop_reason:
|
||||
try:
|
||||
await self._send_frame(tunnel, close_frame(0, CLOSE_UNEXPECTED_DISCONNECT, "receiver_stop"))
|
||||
tunnel.stop_reason = "receiver_stop"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _heartbeat_loop(self, tunnel: ActiveTunnel) -> None:
|
||||
try:
|
||||
while tunnel.websocket and not tunnel.websocket.closed:
|
||||
await asyncio.sleep(tunnel.heartbeat_seconds)
|
||||
await self._send_frame(tunnel, heartbeat_frame())
|
||||
self._log(f"reverse_tunnel heartbeat sent tunnel_id={tunnel.tunnel_id}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = "heartbeat_failed"
|
||||
self._log(f"reverse_tunnel heartbeat failed tunnel_id={tunnel.tunnel_id}", error=True)
|
||||
finally:
|
||||
self._log(f"reverse_tunnel heartbeat loop stopped tunnel_id={tunnel.tunnel_id}")
|
||||
|
||||
async def _watchdog(self, tunnel: ActiveTunnel) -> None:
|
||||
try:
|
||||
while tunnel.websocket and not tunnel.websocket.closed:
|
||||
await asyncio.sleep(10)
|
||||
now = time.time()
|
||||
if tunnel.idle_seconds and (now - tunnel.last_activity) >= tunnel.idle_seconds:
|
||||
await self._send_frame(tunnel, close_frame(0, CLOSE_IDLE_TIMEOUT, "idle_timeout"))
|
||||
tunnel.stop_reason = tunnel.stop_reason or "idle_timeout"
|
||||
self._log(f"reverse_tunnel watchdog idle_timeout tunnel_id={tunnel.tunnel_id}")
|
||||
break
|
||||
if tunnel.expires_at and (now - tunnel.expires_at) >= tunnel.grace_seconds:
|
||||
await self._send_frame(tunnel, close_frame(0, CLOSE_GRACE_EXPIRED, "grace_expired"))
|
||||
tunnel.stop_reason = tunnel.stop_reason or "grace_expired"
|
||||
self._log(f"reverse_tunnel watchdog grace_expired tunnel_id={tunnel.tunnel_id}")
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = "watchdog_failed"
|
||||
self._log(f"reverse_tunnel watchdog failed tunnel_id={tunnel.tunnel_id}", error=True)
|
||||
finally:
|
||||
self._log(f"reverse_tunnel watchdog stopped tunnel_id={tunnel.tunnel_id}")
|
||||
|
||||
async def _handle_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
|
||||
self._log(
|
||||
f"reverse_tunnel recv frame tunnel_id={tunnel.tunnel_id} "
|
||||
f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
|
||||
)
|
||||
if frame.msg_type == MSG_HEARTBEAT:
|
||||
if frame.flags & 0x1:
|
||||
self._log(f"reverse_tunnel heartbeat ack tunnel_id={tunnel.tunnel_id}")
|
||||
return
|
||||
await self._send_frame(tunnel, heartbeat_frame(channel_id=frame.channel_id, is_ack=True))
|
||||
return
|
||||
if frame.msg_type == MSG_CONNECT_ACK:
|
||||
tunnel.connected = True
|
||||
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "connected"})
|
||||
self._log(f"reverse_tunnel CONNECT_ACK tunnel_id={tunnel.tunnel_id}")
|
||||
return
|
||||
if frame.msg_type == MSG_CHANNEL_OPEN:
|
||||
await self._handle_channel_open(tunnel, frame)
|
||||
return
|
||||
if frame.msg_type == MSG_CLOSE:
|
||||
try:
|
||||
reason = json.loads(frame.payload.decode("utf-8"))
|
||||
except Exception:
|
||||
reason = {"code": CLOSE_UNEXPECTED_DISCONNECT, "reason": "close"}
|
||||
tunnel.stop_reason = reason.get("reason") if isinstance(reason, dict) else None
|
||||
await self._shutdown_tunnel(tunnel, send_close=False)
|
||||
return
|
||||
|
||||
handler = tunnel.channels.get(frame.channel_id)
|
||||
if handler:
|
||||
try:
|
||||
await handler.on_frame(frame)
|
||||
except Exception:
|
||||
self._log(f"reverse_tunnel channel handler failed tunnel_id={tunnel.tunnel_id} channel={frame.channel_id}", error=True)
|
||||
else:
|
||||
if frame.msg_type in (MSG_DATA, MSG_CONTROL, MSG_WINDOW_UPDATE):
|
||||
await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "unknown_channel"))
|
||||
|
||||
async def _handle_channel_open(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
|
||||
try:
|
||||
payload = json.loads(frame.payload.decode("utf-8"))
|
||||
except Exception:
|
||||
payload = {}
|
||||
protocol = _norm_text(payload.get("protocol") or tunnel.protocol)
|
||||
metadata = payload.get("metadata") if isinstance(payload, dict) else {}
|
||||
|
||||
if frame.channel_id in tunnel.channels:
|
||||
await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "channel_exists"))
|
||||
return
|
||||
|
||||
if protocol.lower() == "ps" and "ps" not in self._protocol_handlers:
|
||||
hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
|
||||
module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
|
||||
exists_hint = f" exists={module_path.exists()}"
|
||||
self._log(
|
||||
f"reverse_tunnel ps handler missing; falling back to BaseChannel{hint}{exists_hint}",
|
||||
error=True,
|
||||
)
|
||||
|
||||
handler_cls = self._protocol_handlers.get(protocol.lower()) or BaseChannel
|
||||
try:
|
||||
handler = handler_cls(self, tunnel, frame.channel_id, metadata)
|
||||
except Exception:
|
||||
self._log(f"reverse_tunnel channel handler fallback to BaseChannel protocol={protocol}", error=True)
|
||||
handler = BaseChannel(self, tunnel, frame.channel_id, metadata)
|
||||
tunnel.channels[frame.channel_id] = handler
|
||||
await handler.start()
|
||||
await self._send_frame(
|
||||
tunnel,
|
||||
TunnelFrame(
|
||||
msg_type=MSG_CHANNEL_ACK,
|
||||
channel_id=frame.channel_id,
|
||||
payload=json.dumps({"status": "ok", "protocol": protocol}, separators=(",", ":")).encode("utf-8"),
|
||||
),
|
||||
)
|
||||
self._log(
|
||||
f"reverse_tunnel channel_opened tunnel_id={tunnel.tunnel_id} channel={frame.channel_id} "
|
||||
f"protocol={protocol} handler={handler.__class__.__name__} metadata={metadata}"
|
||||
)
|
||||
|
||||
async def _send_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
|
||||
if tunnel.stopping and getattr(frame, "msg_type", None) != MSG_CLOSE:
|
||||
return
|
||||
try:
|
||||
tunnel.send_queue.put_nowait(frame)
|
||||
except Exception:
|
||||
await tunnel.send_queue.put(frame)
|
||||
|
||||
async def _stop_tunnel(self, tunnel_id: str, *, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "requested") -> None:
|
||||
tunnel = self._active.get(tunnel_id)
|
||||
if not tunnel:
|
||||
return
|
||||
if not tunnel.stop_origin:
|
||||
tunnel.stop_origin = "stop_tunnel"
|
||||
self._log(f"reverse_tunnel stop_tunnel requested tunnel_id={tunnel_id} code={code} reason={reason}")
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = reason or "requested"
|
||||
await self._send_frame(tunnel, close_frame(0, code, reason))
|
||||
await self._shutdown_tunnel(tunnel, send_close=False)
|
||||
|
||||
async def _shutdown_tunnel(self, tunnel: ActiveTunnel, *, send_close: bool = True) -> None:
|
||||
if tunnel.stopping:
|
||||
return
|
||||
tunnel.stopping = True
|
||||
reason_text = tunnel.stop_reason or "closed"
|
||||
if not tunnel.stop_reason:
|
||||
tunnel.stop_reason = reason_text
|
||||
if not tunnel.stop_origin:
|
||||
tunnel.stop_origin = "shutdown"
|
||||
self._log(
|
||||
f"reverse_tunnel shutdown start tunnel_id={tunnel.tunnel_id} stop_reason={tunnel.stop_reason} "
|
||||
f"stop_origin={tunnel.stop_origin} ws_closed={getattr(tunnel.websocket, 'closed', None)}"
|
||||
)
|
||||
# Stop all channels first so CLOSE frames (with reasons) are sent upstream.
|
||||
for handler in list(tunnel.channels.values()):
|
||||
try:
|
||||
await handler.stop(code=CLOSE_UNEXPECTED_DISCONNECT, reason=reason_text or "tunnel_shutdown")
|
||||
except Exception:
|
||||
pass
|
||||
if send_close:
|
||||
close_payload = close_frame(0, CLOSE_AGENT_SHUTDOWN, reason_text or "agent_shutdown")
|
||||
try:
|
||||
await self._send_frame(tunnel, close_payload)
|
||||
# Give the sender loop a brief window to flush the CLOSE upstream.
|
||||
await asyncio.sleep(0.05)
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback: if sender task died, try sending directly on the websocket.
|
||||
try:
|
||||
if tunnel.websocket and not tunnel.websocket.closed:
|
||||
await tunnel.websocket.send_bytes(close_payload.encode())
|
||||
except Exception:
|
||||
pass
|
||||
for task in list(tunnel.tasks):
|
||||
try:
|
||||
task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
if tunnel.websocket is not None:
|
||||
try:
|
||||
message = (reason_text or "agent_shutdown").encode("utf-8", "ignore")[:120]
|
||||
await tunnel.websocket.close(message=message)
|
||||
except Exception:
|
||||
pass
|
||||
if tunnel.session is not None:
|
||||
try:
|
||||
await tunnel.session.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._active.pop(tunnel.tunnel_id, None)
|
||||
if tunnel.domain in self._domain_claims and self._domain_claims.get(tunnel.domain) == tunnel.tunnel_id:
|
||||
self._domain_claims.pop(tunnel.domain, None)
|
||||
self._log(f"reverse_tunnel stopped tunnel_id={tunnel.tunnel_id} reason={tunnel.stop_reason or 'closed'}")
|
||||
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "closed", "reason": tunnel.stop_reason or "closed"})
|
||||
|
||||
# ------------------------------------------------------------------ Lifecycle
|
||||
def stop_all(self):
|
||||
for tunnel_id in list(self._active.keys()):
|
||||
try:
|
||||
self.loop.create_task(self._stop_tunnel(tunnel_id, code=CLOSE_AGENT_SHUTDOWN, reason="agent_shutdown"))
|
||||
except Exception:
|
||||
pass
|
||||
167
Data/Agent/Roles/role_VpnShell.py
Normal file
167
Data/Agent/Roles/role_VpnShell.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# ======================================================
|
||||
# Data\Agent\Roles\role_VpnShell.py
|
||||
# Description: PowerShell TCP server for VPN shell access (Engine connects over WireGuard /32).
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""VPN PowerShell server for the WireGuard tunnel."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import socket
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
import os
|
||||
|
||||
ROLE_NAME = "VpnShell"
|
||||
ROLE_CONTEXTS = ["system"]
|
||||
|
||||
|
||||
def _log_path() -> Path:
|
||||
root = Path(__file__).resolve().parents[2] / "Logs"
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root / "reverse_tunnel.log"
|
||||
|
||||
|
||||
def _write_log(message: str) -> None:
|
||||
ts = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
|
||||
try:
|
||||
_log_path().open("a", encoding="utf-8").write(f"[{ts}] [vpn-shell] {message}\n")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _b64encode(data: bytes) -> str:
|
||||
return base64.b64encode(data).decode("ascii").strip()
|
||||
|
||||
|
||||
def _b64decode(value: str) -> bytes:
|
||||
return base64.b64decode(value.encode("ascii"))
|
||||
|
||||
def _resolve_shell_port() -> int:
|
||||
raw = os.environ.get("BOREALIS_WIREGUARD_SHELL_PORT")
|
||||
try:
|
||||
value = int(raw) if raw is not None else 47001
|
||||
except Exception:
|
||||
value = 47001
|
||||
if value < 1 or value > 65535:
|
||||
return 47001
|
||||
return value
|
||||
|
||||
|
||||
class ShellSession:
|
||||
def __init__(self, conn: socket.socket, address: tuple[str, int]) -> None:
|
||||
self.conn = conn
|
||||
self.address = address
|
||||
self.proc: Optional[subprocess.Popen] = None
|
||||
self._stop = threading.Event()
|
||||
|
||||
def start(self) -> None:
|
||||
self.proc = subprocess.Popen(
|
||||
["powershell.exe", "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
|
||||
bufsize=0,
|
||||
)
|
||||
threading.Thread(target=self._reader_loop, daemon=True).start()
|
||||
self._writer_loop()
|
||||
|
||||
def _reader_loop(self) -> None:
|
||||
if not self.proc or not self.proc.stdout:
|
||||
return
|
||||
try:
|
||||
while not self._stop.is_set():
|
||||
chunk = self.proc.stdout.readline()
|
||||
if not chunk:
|
||||
break
|
||||
payload = json.dumps({"type": "stdout", "data": _b64encode(chunk)})
|
||||
self.conn.sendall(payload.encode("utf-8") + b"\n")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _writer_loop(self) -> None:
|
||||
buffer = b""
|
||||
try:
|
||||
while not self._stop.is_set():
|
||||
data = self.conn.recv(4096)
|
||||
if not data:
|
||||
break
|
||||
buffer += data
|
||||
while b"\n" in buffer:
|
||||
line, buffer = buffer.split(b"\n", 1)
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
msg = json.loads(line.decode("utf-8"))
|
||||
except Exception:
|
||||
continue
|
||||
if msg.get("type") == "stdin":
|
||||
payload = msg.get("data") or ""
|
||||
if self.proc and self.proc.stdin:
|
||||
try:
|
||||
self.proc.stdin.write(_b64decode(str(payload)))
|
||||
self.proc.stdin.flush()
|
||||
except Exception:
|
||||
pass
|
||||
if msg.get("type") == "close":
|
||||
self._stop.set()
|
||||
break
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self._stop.set()
|
||||
try:
|
||||
self.conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
if self.proc:
|
||||
try:
|
||||
self.proc.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class ShellServer:
|
||||
def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None) -> None:
|
||||
self.host = host
|
||||
self.port = port or _resolve_shell_port()
|
||||
self._thread = threading.Thread(target=self._serve, daemon=True)
|
||||
self._thread.start()
|
||||
_write_log(f"VPN shell server listening on {self.host}:{self.port}")
|
||||
|
||||
def _serve(self) -> None:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
server.bind((self.host, self.port))
|
||||
server.listen(5)
|
||||
while True:
|
||||
conn, addr = server.accept()
|
||||
remote_ip = addr[0]
|
||||
if not remote_ip.startswith("10.255."):
|
||||
_write_log(f"Rejected shell connection from {remote_ip}")
|
||||
conn.close()
|
||||
continue
|
||||
_write_log(f"Accepted shell connection from {remote_ip}")
|
||||
session = ShellSession(conn, addr)
|
||||
threading.Thread(target=session.start, daemon=True).start()
|
||||
|
||||
|
||||
class Role:
|
||||
def __init__(self, ctx) -> None:
|
||||
self.ctx = ctx
|
||||
self.server = ShellServer()
|
||||
|
||||
def register_events(self) -> None:
|
||||
return
|
||||
|
||||
def stop_all(self) -> None:
|
||||
return
|
||||
@@ -9,13 +9,15 @@
|
||||
|
||||
This role prepares the WireGuard client config, manages a single active
|
||||
session, enforces idle teardown, and logs lifecycle events to
|
||||
Agent/Logs/reverse_tunnel.log. It does not yet bind to engine signals; higher
|
||||
layers should call start_session/stop_session with the issued config/token.
|
||||
Agent/Logs/reverse_tunnel.log. It binds to Engine Socket.IO events
|
||||
(`vpn_tunnel_start`, `vpn_tunnel_stop`, `vpn_tunnel_activity`) to start/stop
|
||||
the client session with the issued config/token.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
@@ -26,6 +28,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import x25519
|
||||
from signature_utils import verify_and_store_script_signature
|
||||
|
||||
ROLE_NAME = "WireGuardTunnel"
|
||||
ROLE_CONTEXTS = ["system"]
|
||||
@@ -95,6 +98,8 @@ class SessionConfig:
|
||||
allowed_ports: str
|
||||
idle_seconds: int = 900
|
||||
preshared_key: Optional[str] = None
|
||||
client_private_key: Optional[str] = None
|
||||
client_public_key: Optional[str] = None
|
||||
|
||||
|
||||
class WireGuardClient:
|
||||
@@ -122,17 +127,35 @@ class WireGuardClient:
|
||||
return candidate
|
||||
return "wireguard.exe"
|
||||
|
||||
def _validate_token(self, token: Dict[str, Any]) -> None:
|
||||
required = ("agent_id", "tunnel_id", "expires_at")
|
||||
def _validate_token(self, token: Dict[str, Any], *, signing_client: Optional[Any] = None) -> None:
|
||||
payload = dict(token or {})
|
||||
signature = payload.pop("signature", None)
|
||||
signing_key = payload.pop("signing_key", None)
|
||||
sig_alg = payload.pop("sig_alg", None)
|
||||
|
||||
required = ("agent_id", "tunnel_id", "expires_at", "port")
|
||||
missing = [field for field in required if field not in token or token[field] in ("", None)]
|
||||
if missing:
|
||||
raise ValueError(f"Missing token fields: {', '.join(missing)}")
|
||||
try:
|
||||
exp = float(token["expires_at"])
|
||||
exp = float(payload["expires_at"])
|
||||
except Exception:
|
||||
raise ValueError("Invalid token expiry")
|
||||
if exp <= time.time():
|
||||
raise ValueError("Token expired")
|
||||
try:
|
||||
port = int(payload["port"])
|
||||
except Exception:
|
||||
raise ValueError("Invalid token port")
|
||||
if port < 1 or port > 65535:
|
||||
raise ValueError("Invalid token port")
|
||||
|
||||
if signature:
|
||||
if sig_alg and str(sig_alg).lower() not in ("ed25519", "eddsa"):
|
||||
raise ValueError("Unsupported token signature algorithm")
|
||||
payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
|
||||
if not verify_and_store_script_signature(signing_client, payload_bytes, str(signature), signing_key):
|
||||
raise ValueError("Token signature invalid")
|
||||
|
||||
def _run(self, args: list[str]) -> tuple[int, str, str]:
|
||||
try:
|
||||
@@ -142,9 +165,10 @@ class WireGuardClient:
|
||||
return 1, "", str(exc)
|
||||
|
||||
def _render_config(self, session: SessionConfig) -> str:
|
||||
private_key = session.client_private_key or self._client_keys["private"]
|
||||
lines = [
|
||||
"[Interface]",
|
||||
f"PrivateKey = {self._client_keys['private']}",
|
||||
f"PrivateKey = {private_key}",
|
||||
f"Address = {session.virtual_ip}",
|
||||
"",
|
||||
"[Peer]",
|
||||
@@ -172,13 +196,13 @@ class WireGuardClient:
|
||||
t.start()
|
||||
self._idle_thread = t
|
||||
|
||||
def start_session(self, session: SessionConfig) -> None:
|
||||
def start_session(self, session: SessionConfig, *, signing_client: Optional[Any] = None) -> None:
|
||||
if self.session:
|
||||
_write_log("Rejecting start_session: existing session already active.")
|
||||
return
|
||||
|
||||
try:
|
||||
self._validate_token(session.token)
|
||||
self._validate_token(session.token, signing_client=signing_client)
|
||||
except Exception as exc:
|
||||
_write_log(f"Refusing to start WireGuard session: {exc}")
|
||||
return
|
||||
@@ -243,6 +267,7 @@ class Role:
|
||||
self.client = client
|
||||
hooks = getattr(ctx, "hooks", {}) or {}
|
||||
self._log_hook = hooks.get("log_agent")
|
||||
self._http_client_factory = hooks.get("http_client")
|
||||
|
||||
def _log(self, message: str, *, error: bool = False) -> None:
|
||||
if callable(self._log_hook):
|
||||
@@ -254,6 +279,14 @@ class Role:
|
||||
pass
|
||||
_write_log(message)
|
||||
|
||||
def _http_client(self) -> Optional[Any]:
|
||||
try:
|
||||
if callable(self._http_client_factory):
|
||||
return self._http_client_factory()
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _build_session(self, payload: Any) -> Optional[SessionConfig]:
|
||||
if not isinstance(payload, dict):
|
||||
self._log("WireGuard start payload missing/invalid.", error=True)
|
||||
@@ -299,6 +332,8 @@ class Role:
|
||||
allowed_ports=allowed_ports,
|
||||
idle_seconds=idle_seconds,
|
||||
preshared_key=payload.get("preshared_key"),
|
||||
client_private_key=payload.get("client_private_key"),
|
||||
client_public_key=payload.get("client_public_key"),
|
||||
)
|
||||
|
||||
def register_events(self) -> None:
|
||||
@@ -310,7 +345,7 @@ class Role:
|
||||
if not session:
|
||||
return
|
||||
self._log("WireGuard start request received.")
|
||||
self.client.start_session(session)
|
||||
self.client.start_session(session, signing_client=self._http_client())
|
||||
|
||||
@sio.on("vpn_tunnel_stop")
|
||||
async def _vpn_tunnel_stop(payload):
|
||||
|
||||
Reference in New Issue
Block a user