diff --git a/Data/Agent/Roles/ReverseTunnel/__init__.py b/Data/Agent/Roles/ReverseTunnel/__init__.py new file mode 100644 index 00000000..88b181a3 --- /dev/null +++ b/Data/Agent/Roles/ReverseTunnel/__init__.py @@ -0,0 +1,2 @@ +"""Reverse tunnel protocol modules (placeholder package).""" + diff --git a/Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py b/Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py new file mode 100644 index 00000000..52125db8 --- /dev/null +++ b/Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py @@ -0,0 +1,226 @@ +"""PowerShell channel implementation for reverse tunnel (Agent side).""" +from __future__ import annotations + +import asyncio +import os +import sys +from typing import Any, Dict, Optional + +# Message types mirrored from the tunnel framing (kept local to avoid import cycles). +MSG_DATA = 0x05 +MSG_WINDOW_UPDATE = 0x06 +MSG_CONTROL = 0x09 +MSG_CLOSE = 0x08 + +# Close codes (mirrored from engine framing) +CLOSE_OK = 0 +CLOSE_PROTOCOL_ERROR = 3 +CLOSE_AGENT_SHUTDOWN = 6 + + +class PowershellChannel: + def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]): + self.role = role + self.tunnel = tunnel + self.channel_id = channel_id + self.metadata = metadata or {} + self.loop = getattr(role, "loop", None) or asyncio.get_event_loop() + self._closed = False + self._reader_task = None + self._writer_task = None + self._stdin_queue: asyncio.Queue = asyncio.Queue() + self._pty = None + self._exit_code: Optional[int] = None + self._frame_cls = getattr(role, "_frame_cls", None) + + # ------------------------------------------------------------------ Helpers + def _make_frame(self, msg_type: int, payload: bytes = b"", *, flags: int = 0): + frame_cls = self._frame_cls + if frame_cls is None: + return None + try: + return frame_cls(msg_type=msg_type, channel_id=self.channel_id, payload=payload or b"", flags=flags) + except Exception: + return None + + async def _send_frame(self, frame) -> None: + if frame is None: + return + await self.role._send_frame(self.tunnel, frame) + + async def _send_close(self, code: int, reason: str) -> None: + try: + close_frame = getattr(self.role, "close_frame") + if callable(close_frame): + await self._send_frame(close_frame(self.channel_id, code, reason)) + return + except Exception: + pass + frame = self._make_frame( + MSG_CLOSE, + payload=f'{{"code":{code},"reason":"{reason}"}}'.encode("utf-8"), + ) + await self._send_frame(frame) + + def _powershell_path(self) -> str: + preferred = self.metadata.get("shell") if isinstance(self.metadata, dict) else None + if isinstance(preferred, str) and preferred.strip(): + return preferred.strip() + # Default to Windows PowerShell; fallback to pwsh if provided later. + return "powershell.exe" + + def _initial_size(self) -> tuple: + cols = int(self.metadata.get("cols") or self.metadata.get("columns") or 120) if isinstance(self.metadata, dict) else 120 + rows = int(self.metadata.get("rows") or 32) if isinstance(self.metadata, dict) else 32 + cols = max(20, min(cols, 300)) + rows = max(10, min(rows, 200)) + return cols, rows + + # ------------------------------------------------------------------ Lifecycle + async def start(self) -> None: + if sys.platform.lower().startswith("win") is False: + await self._send_close(CLOSE_PROTOCOL_ERROR, "windows_only") + return + try: + import pywinpty # type: ignore + except Exception as exc: # pragma: no cover - dependency guard + self.role._log(f"reverse_tunnel ps channel missing pywinpty: {exc}", error=True) + await self._send_close(CLOSE_PROTOCOL_ERROR, "pywinpty_missing") + return + + shell = self._powershell_path() + cols, rows = self._initial_size() + try: + self._pty = pywinpty.Process( + spawn_cmd=shell, + dimensions=(cols, rows), + ) + except Exception as exc: + self.role._log(f"reverse_tunnel ps channel failed to spawn {shell}: {exc}", error=True) + await self._send_close(CLOSE_PROTOCOL_ERROR, "spawn_failed") + return + + self._reader_task = self.loop.create_task(self._pump_stdout()) + self._writer_task = self.loop.create_task(self._pump_stdin()) + self.role._log(f"reverse_tunnel ps channel started shell={shell} cols={cols} rows={rows}") + + async def on_frame(self, frame) -> None: + if self._closed: + return + if frame.msg_type == MSG_DATA: + if frame.payload: + try: + self._stdin_queue.put_nowait(frame.payload) + except Exception: + await self._stdin_queue.put(frame.payload) + elif frame.msg_type == MSG_CONTROL: + await self._handle_control(frame.payload) + elif frame.msg_type == MSG_CLOSE: + await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close") + elif frame.msg_type == MSG_WINDOW_UPDATE: + # Reserved for back-pressure; ignore for now. + return + + async def _handle_control(self, payload: bytes) -> None: + try: + import json + + data = json.loads(payload.decode("utf-8")) + except Exception: + return + cols = data.get("cols") or data.get("columns") + rows = data.get("rows") + if cols is None and rows is None: + return + try: + cols_int = int(cols) if cols is not None else None + rows_int = int(rows) if rows is not None else None + except Exception: + return + await self._resize(cols_int, rows_int) + + async def _resize(self, cols: Optional[int], rows: Optional[int]) -> None: + if self._pty is None: + return + try: + cur_cols, cur_rows = self._initial_size() + if cols is None: + cols = cur_cols + if rows is None: + rows = cur_rows + cols = max(20, min(int(cols), 300)) + rows = max(10, min(int(rows), 200)) + self._pty.set_size(cols, rows) + self.role._log(f"reverse_tunnel ps channel resized cols={cols} rows={rows}") + except Exception: + self.role._log("reverse_tunnel ps channel resize failed", error=True) + + async def _pump_stdout(self) -> None: + loop = asyncio.get_event_loop() + try: + while not self._closed and self._pty: + chunk = await loop.run_in_executor(None, self._pty.read, 4096) + if chunk is None: + break + if isinstance(chunk, str): + data = chunk.encode("utf-8", errors="replace") + else: + data = bytes(chunk) + if not data: + break + frame = self._make_frame(MSG_DATA, payload=data) + await self._send_frame(frame) + except asyncio.CancelledError: + pass + except Exception: + self.role._log("reverse_tunnel ps stdout pump error", error=True) + finally: + await self.stop(reason="stdout_closed") + + async def _pump_stdin(self) -> None: + loop = asyncio.get_event_loop() + try: + while not self._closed and self._pty: + try: + data = await self._stdin_queue.get() + except asyncio.CancelledError: + break + if data is None: + break + if isinstance(data, (bytes, bytearray)): + text = data.decode("utf-8", errors="replace") + else: + text = str(data) + try: + await loop.run_in_executor(None, self._pty.write, text) + except Exception: + break + except asyncio.CancelledError: + pass + except Exception: + self.role._log("reverse_tunnel ps stdin pump error", error=True) + finally: + await self.stop(reason="stdin_closed") + + async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None: + if self._closed: + return + self._closed = True + if self._pty is not None: + try: + self._pty.terminate() + except Exception: + pass + current = asyncio.current_task() + if self._reader_task and self._reader_task is not current: + try: + self._reader_task.cancel() + except Exception: + pass + if self._writer_task and self._writer_task is not current: + try: + self._writer_task.cancel() + except Exception: + pass + await self._send_close(code, reason or "powershell_exit") + self.role._log(f"reverse_tunnel ps channel stopped channel={self.channel_id} reason={reason or 'exit'}") diff --git a/Data/Agent/Roles/role_ReverseTunnel.py b/Data/Agent/Roles/role_ReverseTunnel.py new file mode 100644 index 00000000..8161b00a --- /dev/null +++ b/Data/Agent/Roles/role_ReverseTunnel.py @@ -0,0 +1,654 @@ +import asyncio +import base64 +import json +import os +import struct +import time +from dataclasses import dataclass, field +from typing import Any, Dict, Optional +from urllib.parse import urlparse + +import aiohttp +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 + +try: + from .ReverseTunnel import tunnel_Powershell +except Exception: + tunnel_Powershell = None + +ROLE_NAME = "reverse_tunnel" +ROLE_CONTEXTS = ["interactive", "system"] + +# Message types (keep in sync with Engine service) +MSG_CONNECT = 0x01 +MSG_CONNECT_ACK = 0x02 +MSG_CHANNEL_OPEN = 0x03 +MSG_CHANNEL_ACK = 0x04 +MSG_DATA = 0x05 +MSG_WINDOW_UPDATE = 0x06 +MSG_HEARTBEAT = 0x07 +MSG_CLOSE = 0x08 +MSG_CONTROL = 0x09 + +# Close codes +CLOSE_OK = 0 +CLOSE_IDLE_TIMEOUT = 1 +CLOSE_GRACE_EXPIRED = 2 +CLOSE_PROTOCOL_ERROR = 3 +CLOSE_AUTH_FAILED = 4 +CLOSE_SERVER_SHUTDOWN = 5 +CLOSE_AGENT_SHUTDOWN = 6 +CLOSE_DOMAIN_LIMIT = 7 +CLOSE_UNEXPECTED_DISCONNECT = 8 + +FRAME_VERSION = 1 +FRAME_HEADER_STRUCT = struct.Struct(" bytes: + payload = self.payload or b"" + header = FRAME_HEADER_STRUCT.pack( + self.version, + self.msg_type, + self.flags, + self.reserved, + int(self.channel_id), + len(payload), + ) + return header + payload + + +def decode_frame(raw: bytes) -> TunnelFrame: + if len(raw) < FRAME_HEADER_STRUCT.size: + raise ValueError("frame_too_small") + version, msg_type, flags, reserved, channel_id, length = FRAME_HEADER_STRUCT.unpack_from(raw, 0) + if version != FRAME_VERSION: + raise ValueError(f"unsupported_version:{version}") + if length < 0 or len(raw) < FRAME_HEADER_STRUCT.size + length: + raise ValueError("invalid_length") + payload = raw[FRAME_HEADER_STRUCT.size : FRAME_HEADER_STRUCT.size + length] + if len(payload) != length: + raise ValueError("length_mismatch") + return TunnelFrame(msg_type=msg_type, channel_id=channel_id, payload=payload, flags=flags, version=version, reserved=reserved) + + +def heartbeat_frame(channel_id: int = 0, *, is_ack: bool = False) -> TunnelFrame: + flags = 0x1 if is_ack else 0x0 + return TunnelFrame(msg_type=MSG_HEARTBEAT, channel_id=channel_id, flags=flags, payload=b"") + + +def close_frame(channel_id: int, code: int, reason: str = "") -> TunnelFrame: + payload = json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8") + return TunnelFrame(msg_type=MSG_CLOSE, channel_id=channel_id, payload=payload) + + +def _norm_text(value: Any) -> str: + if value is None: + return "" + try: + return str(value).strip() + except Exception: + return "" + + +def _is_literal_ip(value: str) -> bool: + try: + import ipaddress + + ipaddress.ip_address(value.strip().strip("[]")) + return True + except Exception: + return False + + +@dataclass +class ActiveTunnel: + tunnel_id: str + domain: str + protocol: str + port: int + token: str + url: str + heartbeat_seconds: int + idle_seconds: int + grace_seconds: int + expires_at: Optional[int] + signing_key_hint: Optional[str] = None + session: Optional[aiohttp.ClientSession] = None + websocket: Optional[aiohttp.ClientWebSocketResponse] = None + tasks: list = field(default_factory=list) + send_queue: asyncio.Queue = field(default_factory=asyncio.Queue) + channels: Dict[int, Any] = field(default_factory=dict) + last_activity: float = field(default_factory=lambda: time.time()) + connected: bool = False + stopping: bool = False + stop_reason: Optional[str] = None + + +class BaseChannel: + """Placeholder channel handler; protocol-specific handlers plug in later.""" + + def __init__(self, role: "Role", tunnel: ActiveTunnel, channel_id: int, metadata: Optional[dict]): + self.role = role + self.tunnel = tunnel + self.channel_id = channel_id + self.metadata = metadata or {} + + async def start(self) -> None: + # Nothing to prime for placeholder channels. + return + + async def on_frame(self, frame: TunnelFrame) -> None: + # Drop frames until protocol module is provided. + return + + async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None: + await self.role._send_frame(self.tunnel, close_frame(self.channel_id, code, reason)) + + +class Role: + def __init__(self, ctx): + self.ctx = ctx + self.sio = ctx.sio + self.loop = ctx.loop or asyncio.get_event_loop() + self.hooks = ctx.hooks or {} + self._http_client_factory = self.hooks.get("http_client") + self._log_hook = self.hooks.get("log_agent") + self._active: Dict[str, ActiveTunnel] = {} + self._domain_claims: Dict[str, str] = {} + self._domain_limits: Dict[str, Optional[int]] = { + "ps": 1, + "rdp": 1, + "vnc": 1, + "webrtc": 1, + "ssh": None, + "winrm": None, + } + self._default_heartbeat = 20 + self._protocol_handlers: Dict[str, Any] = {} + self._frame_cls = TunnelFrame + self.close_frame = close_frame + if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"): + self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel + + # ------------------------------------------------------------------ Logging + def _log(self, message: str, *, error: bool = False) -> None: + fname = "reverse_tunnel.log" + try: + if callable(self._log_hook): + self._log_hook(message, fname=fname) + if error: + self._log_hook(message, fname="agent.error.log") + except Exception: + pass + + # ------------------------------------------------------------------ Event wiring + def register_events(self): + @self.sio.on("reverse_tunnel_start") + async def _reverse_tunnel_start(payload): + await self._handle_tunnel_start(payload) + + @self.sio.on("reverse_tunnel_stop") + async def _reverse_tunnel_stop(payload): + tid = "" + if isinstance(payload, dict): + tid = _norm_text(payload.get("tunnel_id")) + await self._stop_tunnel(tid, code=CLOSE_AGENT_SHUTDOWN, reason="server_stop") + + # ------------------------------------------------------------------ Token helpers + def _decode_token_payload(self, token: str, *, signing_key_hint: Optional[str] = None) -> Dict[str, Any]: + if not token: + raise ValueError("token_missing") + + def _b64decode(segment: str) -> bytes: + padding = "=" * (-len(segment) % 4) + return base64.urlsafe_b64decode(segment + padding) + + parts = token.split(".") + payload_segment = parts[0] + payload_bytes = _b64decode(payload_segment) + try: + payload = json.loads(payload_bytes.decode("utf-8")) + except Exception as exc: + raise ValueError("token_decode_error") from exc + + if len(parts) == 2: + candidates = [] + hint = _norm_text(signing_key_hint) + if hint: + candidates.append(hint) + client = self._http_client() + if client and hasattr(client, "load_server_signing_key"): + try: + stored = client.load_server_signing_key() + except Exception: + stored = None + if isinstance(stored, str) and stored.strip(): + candidates.append(stored.strip()) + + signature = _b64decode(parts[1]) + verified = False + for candidate in candidates: + try: + key_bytes = base64.b64decode(candidate, validate=True) + public_key = serialization.load_der_public_key(key_bytes) + except Exception: + continue + if not isinstance(public_key, ed25519.Ed25519PublicKey): + continue + try: + public_key.verify(signature, payload_bytes) + verified = True + if client and hasattr(client, "store_server_signing_key"): + try: + client.store_server_signing_key(candidate) + except Exception: + pass + break + except Exception: + continue + if not verified: + raise ValueError("token_signature_invalid") + return payload + + def _validate_token(self, token: str, *, expected_agent: str, expected_domain: str, expected_protocol: str, expected_tunnel: str, signing_key_hint: Optional[str]) -> Dict[str, Any]: + payload = self._decode_token_payload(token, signing_key_hint=signing_key_hint) + + def _matches(expected: str, actual: Any) -> bool: + return _norm_text(expected).lower() == _norm_text(actual).lower() + + if expected_agent and not _matches(expected_agent, payload.get("agent_id")): + raise ValueError("token_agent_mismatch") + if expected_tunnel and not _matches(expected_tunnel, payload.get("tunnel_id")): + raise ValueError("token_id_mismatch") + if expected_domain and not _matches(expected_domain, payload.get("domain")): + raise ValueError("token_domain_mismatch") + if expected_protocol and not _matches(expected_protocol, payload.get("protocol")): + raise ValueError("token_protocol_mismatch") + + expires_at = payload.get("expires_at") + try: + expires_ts = int(expires_at) if expires_at is not None else None + except Exception: + expires_ts = None + if expires_ts is not None and expires_ts < int(time.time()): + raise ValueError("token_expired") + return payload + + # ------------------------------------------------------------------ Utility + def _http_client(self): + try: + if callable(self._http_client_factory): + return self._http_client_factory() + except Exception: + return None + return None + + def _port_allowed(self, port: int) -> bool: + return isinstance(port, int) and 1 <= port <= 65535 + + def _domain_allowed(self, domain: str) -> bool: + limit = self._domain_limits.get(domain) + if limit is None: + return True + active = [tid for tid, t in self._active.items() if t.domain == domain and not t.stopping] + pending = [tid for dom, tid in self._domain_claims.items() if dom == domain and tid not in active] + count = len(set(active + pending)) + return count < limit + + def _build_ws_url(self, port: int) -> str: + client = self._http_client() + if client: + try: + client.refresh_base_url() + except Exception: + pass + base_url = getattr(client, "base_url", "") or "https://localhost:5000" + else: + base_url = "https://localhost:5000" + parsed = urlparse(base_url) + host = parsed.hostname or "localhost" + scheme = "wss" if (parsed.scheme or "").lower() == "https" else "ws" + return f"{scheme}://{host}:{port}/" + + def _ssl_context(self, host: str): + client = self._http_client() + verify = getattr(getattr(client, "session", None), "verify", True) + if verify is False: + return False + if isinstance(verify, str) and os.path.isfile(verify) and client and hasattr(client, "key_store"): + try: + ctx = client.key_store.build_ssl_context() + if ctx and _is_literal_ip(host): + ctx.check_hostname = False + return ctx + except Exception: + return None + return None + + async def _emit_status(self, payload: Dict[str, Any]) -> None: + try: + await self.sio.emit("reverse_tunnel_status", payload) + except Exception: + pass + + def _mark_activity(self, tunnel: ActiveTunnel) -> None: + tunnel.last_activity = time.time() + + # ------------------------------------------------------------------ Event handlers + async def _handle_tunnel_start(self, payload: Any) -> None: + if not isinstance(payload, dict): + self._log("reverse_tunnel_start ignored: payload not a dict", error=True) + return + tunnel_id = _norm_text(payload.get("tunnel_id")) + token = _norm_text(payload.get("token")) + port = payload.get("port") or payload.get("assigned_port") + protocol = _norm_text(payload.get("protocol") or "ps").lower() or "ps" + domain = _norm_text(payload.get("domain") or protocol).lower() or protocol + heartbeat_seconds = int(payload.get("heartbeat_seconds") or self._default_heartbeat or 20) + idle_seconds = int(payload.get("idle_seconds") or 3600) + grace_seconds = int(payload.get("grace_seconds") or 3600) + signing_key_hint = _norm_text(payload.get("signing_key")) + + if not token: + self._log("reverse_tunnel_start rejected: missing token", error=True) + await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_missing"}) + return + + if not tunnel_id: + tunnel_id = _norm_text(payload.get("lease_id")) # fallback alias + + try: + claims = self._validate_token( + token, + expected_agent=self.ctx.agent_id, + expected_domain=domain, + expected_protocol=protocol, + expected_tunnel=tunnel_id, + signing_key_hint=signing_key_hint, + ) + except Exception as exc: + self._log(f"reverse_tunnel_start rejected: token validation failed ({exc})", error=True) + await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_invalid"}) + return + + tunnel_id = _norm_text(claims.get("tunnel_id") or tunnel_id) + if not tunnel_id: + self._log("reverse_tunnel_start rejected: tunnel_id missing after token parse", error=True) + return + + try: + port = int(port or claims.get("assigned_port") or 0) + except Exception: + port = 0 + if not self._port_allowed(port): + self._log(f"reverse_tunnel_start rejected: invalid port {port}", error=True) + await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "invalid_port"}) + return + + domain = _norm_text(claims.get("domain") or domain).lower() + protocol = _norm_text(claims.get("protocol") or protocol).lower() + expires_at = claims.get("expires_at") + + if tunnel_id in self._active: + self._log(f"reverse_tunnel_start ignored: tunnel already active tunnel_id={tunnel_id}") + return + if not self._domain_allowed(domain): + self._log(f"reverse_tunnel_start rejected: domain limit for {domain}", error=True) + await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "domain_limit"}) + return + + url = self._build_ws_url(port) + parsed = urlparse(url) + heartbeat_seconds = max(5, min(heartbeat_seconds, 120)) + idle_seconds = max(30, idle_seconds) + grace_seconds = max(60, grace_seconds) + + tunnel = ActiveTunnel( + tunnel_id=tunnel_id, + domain=domain, + protocol=protocol, + port=port, + token=token, + url=url, + heartbeat_seconds=heartbeat_seconds, + idle_seconds=idle_seconds, + grace_seconds=grace_seconds, + expires_at=int(expires_at) if expires_at is not None else None, + signing_key_hint=signing_key_hint or None, + ) + + self._active[tunnel_id] = tunnel + self._domain_claims[domain] = tunnel_id + self._log(f"reverse_tunnel_start accepted tunnel_id={tunnel_id} domain={domain} protocol={protocol} url={url}") + await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "connecting", "url": url}) + task = self.loop.create_task(self._run_tunnel(tunnel, host=parsed.hostname or "localhost")) + tunnel.tasks.append(task) + + # ------------------------------------------------------------------ Core tunnel handling + async def _run_tunnel(self, tunnel: ActiveTunnel, *, host: str) -> None: + ssl_ctx = self._ssl_context(host) + timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None) + try: + tunnel.session = aiohttp.ClientSession(timeout=timeout) + tunnel.websocket = await tunnel.session.ws_connect( + tunnel.url, + ssl=ssl_ctx, + heartbeat=None, + max_msg_size=0, + timeout=timeout, + ) + self._mark_activity(tunnel) + await tunnel.websocket.send_bytes( + TunnelFrame( + msg_type=MSG_CONNECT, + channel_id=0, + payload=json.dumps( + { + "agent_id": self.ctx.agent_id, + "tunnel_id": tunnel.tunnel_id, + "token": tunnel.token, + "protocol": tunnel.protocol, + "domain": tunnel.domain, + "version": FRAME_VERSION, + }, + separators=(",", ":"), + ).encode("utf-8"), + ).encode() + ) + + sender = self.loop.create_task(self._pump_sender(tunnel)) + receiver = self.loop.create_task(self._pump_receiver(tunnel)) + heartbeats = self.loop.create_task(self._heartbeat_loop(tunnel)) + watchdog = self.loop.create_task(self._watchdog(tunnel)) + tunnel.tasks.extend([sender, receiver, heartbeats, watchdog]) + await asyncio.wait([sender, receiver, heartbeats, watchdog], return_when=asyncio.FIRST_COMPLETED) + except Exception as exc: + self._log(f"reverse_tunnel connection failed tunnel_id={tunnel.tunnel_id}: {exc}", error=True) + await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "connect_failed"}) + finally: + await self._shutdown_tunnel(tunnel) + + async def _pump_sender(self, tunnel: ActiveTunnel) -> None: + try: + while tunnel.websocket and not tunnel.websocket.closed: + frame: TunnelFrame = await tunnel.send_queue.get() + try: + await tunnel.websocket.send_bytes(frame.encode()) + self._mark_activity(tunnel) + except Exception: + break + except asyncio.CancelledError: + pass + except Exception: + self._log(f"reverse_tunnel sender failed tunnel_id={tunnel.tunnel_id}", error=True) + + async def _pump_receiver(self, tunnel: ActiveTunnel) -> None: + ws = tunnel.websocket + if ws is None: + return + try: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.BINARY: + try: + frame = decode_frame(msg.data) + except Exception: + self._log(f"reverse_tunnel frame decode failed tunnel_id={tunnel.tunnel_id}", error=True) + continue + self._mark_activity(tunnel) + await self._handle_frame(tunnel, frame) + elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE): + break + except asyncio.CancelledError: + pass + except Exception: + self._log(f"reverse_tunnel receiver failed tunnel_id={tunnel.tunnel_id}", error=True) + + async def _heartbeat_loop(self, tunnel: ActiveTunnel) -> None: + try: + while tunnel.websocket and not tunnel.websocket.closed: + await asyncio.sleep(tunnel.heartbeat_seconds) + await self._send_frame(tunnel, heartbeat_frame()) + except asyncio.CancelledError: + pass + except Exception: + self._log(f"reverse_tunnel heartbeat failed tunnel_id={tunnel.tunnel_id}", error=True) + + async def _watchdog(self, tunnel: ActiveTunnel) -> None: + try: + while tunnel.websocket and not tunnel.websocket.closed: + await asyncio.sleep(10) + now = time.time() + if tunnel.idle_seconds and (now - tunnel.last_activity) >= tunnel.idle_seconds: + await self._send_frame(tunnel, close_frame(0, CLOSE_IDLE_TIMEOUT, "idle_timeout")) + break + if tunnel.expires_at and (now - tunnel.expires_at) >= tunnel.grace_seconds: + await self._send_frame(tunnel, close_frame(0, CLOSE_GRACE_EXPIRED, "grace_expired")) + break + except asyncio.CancelledError: + pass + except Exception: + self._log(f"reverse_tunnel watchdog failed tunnel_id={tunnel.tunnel_id}", error=True) + + async def _handle_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None: + if frame.msg_type == MSG_HEARTBEAT: + if frame.flags & 0x1: + return + await self._send_frame(tunnel, heartbeat_frame(channel_id=frame.channel_id, is_ack=True)) + return + if frame.msg_type == MSG_CONNECT_ACK: + tunnel.connected = True + await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "connected"}) + return + if frame.msg_type == MSG_CHANNEL_OPEN: + await self._handle_channel_open(tunnel, frame) + return + if frame.msg_type == MSG_CLOSE: + try: + reason = json.loads(frame.payload.decode("utf-8")) + except Exception: + reason = {"code": CLOSE_UNEXPECTED_DISCONNECT, "reason": "close"} + tunnel.stop_reason = reason.get("reason") if isinstance(reason, dict) else None + await self._shutdown_tunnel(tunnel, send_close=False) + return + + handler = tunnel.channels.get(frame.channel_id) + if handler: + try: + await handler.on_frame(frame) + except Exception: + self._log(f"reverse_tunnel channel handler failed tunnel_id={tunnel.tunnel_id} channel={frame.channel_id}", error=True) + else: + if frame.msg_type in (MSG_DATA, MSG_CONTROL, MSG_WINDOW_UPDATE): + await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "unknown_channel")) + + async def _handle_channel_open(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None: + try: + payload = json.loads(frame.payload.decode("utf-8")) + except Exception: + payload = {} + protocol = _norm_text(payload.get("protocol") or tunnel.protocol) + metadata = payload.get("metadata") if isinstance(payload, dict) else {} + + if frame.channel_id in tunnel.channels: + await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "channel_exists")) + return + + handler_cls = self._protocol_handlers.get(protocol.lower()) or BaseChannel + try: + handler = handler_cls(self, tunnel, frame.channel_id, metadata) + except Exception: + handler = BaseChannel(self, tunnel, frame.channel_id, metadata) + tunnel.channels[frame.channel_id] = handler + await handler.start() + await self._send_frame( + tunnel, + TunnelFrame( + msg_type=MSG_CHANNEL_ACK, + channel_id=frame.channel_id, + payload=json.dumps({"status": "ok", "protocol": protocol}, separators=(",", ":")).encode("utf-8"), + ), + ) + + async def _send_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None: + if tunnel.stopping: + return + try: + tunnel.send_queue.put_nowait(frame) + except Exception: + await tunnel.send_queue.put(frame) + + async def _stop_tunnel(self, tunnel_id: str, *, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "requested") -> None: + tunnel = self._active.get(tunnel_id) + if not tunnel: + return + await self._send_frame(tunnel, close_frame(0, code, reason)) + await self._shutdown_tunnel(tunnel, send_close=False) + + async def _shutdown_tunnel(self, tunnel: ActiveTunnel, *, send_close: bool = True) -> None: + if tunnel.stopping: + return + tunnel.stopping = True + if send_close: + try: + await self._send_frame(tunnel, close_frame(0, CLOSE_AGENT_SHUTDOWN, "agent_shutdown")) + except Exception: + pass + for task in list(tunnel.tasks): + try: + task.cancel() + except Exception: + pass + if tunnel.websocket is not None: + try: + await tunnel.websocket.close() + except Exception: + pass + if tunnel.session is not None: + try: + await tunnel.session.close() + except Exception: + pass + self._active.pop(tunnel.tunnel_id, None) + if tunnel.domain in self._domain_claims and self._domain_claims.get(tunnel.domain) == tunnel.tunnel_id: + self._domain_claims.pop(tunnel.domain, None) + self._log(f"reverse_tunnel stopped tunnel_id={tunnel.tunnel_id} reason={tunnel.stop_reason or 'closed'}") + await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "closed", "reason": tunnel.stop_reason or "closed"}) + + # ------------------------------------------------------------------ Lifecycle + def stop_all(self): + for tunnel_id in list(self._active.keys()): + try: + self.loop.create_task(self._stop_tunnel(tunnel_id, code=CLOSE_AGENT_SHUTDOWN, reason="agent_shutdown")) + except Exception: + pass diff --git a/Data/Engine/services/WebSocket/Agent/ReverseTunnel.py b/Data/Engine/services/WebSocket/Agent/ReverseTunnel.py index 7dbf1ce0..5bda734f 100644 --- a/Data/Engine/services/WebSocket/Agent/ReverseTunnel.py +++ b/Data/Engine/services/WebSocket/Agent/ReverseTunnel.py @@ -29,6 +29,8 @@ from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple from collections import deque from threading import Thread +from .ReverseTunnel.Powershell import PowershellChannelServer + try: # websockets is added to engine requirements import websockets from websockets.server import serve as ws_serve @@ -447,6 +449,7 @@ class ReverseTunnelService: self._bridges: Dict[str, "TunnelBridge"] = {} self._port_servers: Dict[int, asyncio.AbstractServer] = {} self._agent_sockets: Dict[str, "websockets.WebSocketServerProtocol"] = {} + self._ps_servers: Dict[str, PowershellChannelServer] = {} def _ensure_loop(self) -> None: if self._running and self._loop: @@ -548,6 +551,13 @@ class ReverseTunnelService: raise ValueError("unknown_tunnel") bridge = self.ensure_bridge(lease) bridge.attach_operator(operator_id) + if lease.domain.lower() == "ps": + try: + server = self.ensure_ps_server(tunnel_id) + if server: + server.open_channel() + except Exception: + self.logger.debug("ps server open failed tunnel_id=%s", tunnel_id, exc_info=True) return bridge def _encode_token(self, payload: Dict[str, object]) -> str: @@ -841,6 +851,15 @@ class ReverseTunnelService: except Exception: pass + def _dispatch_agent_frame(self, tunnel_id: str, frame: TunnelFrame) -> None: + server = self._ps_servers.get(tunnel_id) + if not server: + return + try: + server.handle_agent_frame(frame) + except Exception: + self.logger.debug("ps handler error for tunnel_id=%s", tunnel_id, exc_info=True) + def _start_lease_sweeper(self) -> None: async def _sweeper(): while self._running and self._loop and not self._loop.is_closed(): @@ -928,6 +947,10 @@ class ReverseTunnelService: except Exception: continue self.lease_manager.touch(tunnel_id) + try: + self._dispatch_agent_frame(tunnel_id, recv_frame) + except Exception: + pass bridge.agent_to_operator(recv_frame) async def _pump_to_agent(): while not websocket.closed: @@ -969,10 +992,30 @@ class ReverseTunnelService: self._bridges[lease.tunnel_id] = bridge return bridge + def ensure_ps_server(self, tunnel_id: str) -> Optional[PowershellChannelServer]: + server = self._ps_servers.get(tunnel_id) + if server: + return server + lease = self.lease_manager.get(tunnel_id) + if lease is None or (lease.domain or "").lower() != "ps": + return None + bridge = self.ensure_bridge(lease) + server = PowershellChannelServer(bridge=bridge, service=self) + self._ps_servers[tunnel_id] = server + return server + + def get_ps_server(self, tunnel_id: str) -> Optional[PowershellChannelServer]: + return self._ps_servers.get(tunnel_id) + def release_bridge(self, tunnel_id: str, *, reason: str = "bridge_released") -> None: bridge = self._bridges.pop(tunnel_id, None) if bridge: bridge.stop(reason=reason) + if tunnel_id in self._ps_servers: + try: + self._ps_servers.pop(tunnel_id, None) + except Exception: + pass class TunnelBridge: diff --git a/Data/Engine/services/WebSocket/Agent/ReverseTunnel/Powershell.py b/Data/Engine/services/WebSocket/Agent/ReverseTunnel/Powershell.py new file mode 100644 index 00000000..72d2b82d --- /dev/null +++ b/Data/Engine/services/WebSocket/Agent/ReverseTunnel/Powershell.py @@ -0,0 +1,130 @@ +"""Engine-side PowerShell tunnel channel helper.""" +from __future__ import annotations + +import json +from collections import deque +from typing import Any, Deque, Dict, List, Optional + +from ..ReverseTunnel import ( + CLOSE_AGENT_SHUTDOWN, + CLOSE_OK, + CLOSE_PROTOCOL_ERROR, + MSG_CHANNEL_ACK, + MSG_CHANNEL_OPEN, + MSG_CLOSE, + MSG_CONTROL, + MSG_DATA, + TunnelFrame, + close_frame, +) + + +class PowershellChannelServer: + """Coordinate PowerShell channel frames over a TunnelBridge.""" + + def __init__(self, bridge, service, *, channel_id: int = 1): + self.bridge = bridge + self.service = service + self.channel_id = channel_id + self.logger = service.logger.getChild(f"ps.{bridge.lease.tunnel_id}") + self._open_sent = False + self._ack_received = False + self._closed = False + self._output: Deque[str] = deque() + self._close_reason: Optional[str] = None + self._close_code: Optional[int] = None + + # ------------------------------------------------------------------ Agent frame handling + def handle_agent_frame(self, frame: TunnelFrame) -> None: + if frame.channel_id != self.channel_id: + return + if frame.msg_type == MSG_CHANNEL_ACK: + self._ack_received = True + self.logger.info("ps channel acked tunnel_id=%s", self.bridge.lease.tunnel_id) + return + if frame.msg_type == MSG_DATA: + try: + text = frame.payload.decode("utf-8", errors="replace") + except Exception: + text = "" + if text: + self._append_output(text) + return + if frame.msg_type == MSG_CLOSE: + try: + payload = json.loads(frame.payload.decode("utf-8")) + except Exception: + payload = {} + self._closed = True + self._close_code = payload.get("code") if isinstance(payload, dict) else None + self._close_reason = payload.get("reason") if isinstance(payload, dict) else None + self.logger.info( + "ps channel closed tunnel_id=%s code=%s reason=%s", + self.bridge.lease.tunnel_id, + self._close_code, + self._close_reason or "-", + ) + + # ------------------------------------------------------------------ Operator actions + def open_channel(self, *, cols: int = 120, rows: int = 32) -> None: + if self._open_sent: + return + payload = json.dumps( + {"protocol": "ps", "metadata": {"cols": cols, "rows": rows}}, + separators=(",", ":"), + ).encode("utf-8") + frame = TunnelFrame(msg_type=MSG_CHANNEL_OPEN, channel_id=self.channel_id, payload=payload) + self.bridge.operator_to_agent(frame) + self._open_sent = True + self.logger.info( + "ps channel open sent tunnel_id=%s channel_id=%s cols=%s rows=%s", + self.bridge.lease.tunnel_id, + self.channel_id, + cols, + rows, + ) + + def send_input(self, data: str) -> None: + if self._closed: + return + payload = data.encode("utf-8", errors="replace") + frame = TunnelFrame(msg_type=MSG_DATA, channel_id=self.channel_id, payload=payload) + self.bridge.operator_to_agent(frame) + + def send_resize(self, cols: int, rows: int) -> None: + if self._closed: + return + payload = json.dumps({"cols": cols, "rows": rows}, separators=(",", ":")).encode("utf-8") + frame = TunnelFrame(msg_type=MSG_CONTROL, channel_id=self.channel_id, payload=payload) + self.bridge.operator_to_agent(frame) + + def close(self, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "operator_close") -> None: + if self._closed: + return + self._closed = True + self.bridge.operator_to_agent(close_frame(self.channel_id, code, reason)) + + # ------------------------------------------------------------------ Output polling + def drain_output(self) -> List[str]: + items: List[str] = [] + while self._output: + items.append(self._output.popleft()) + return items + + def _append_output(self, text: str) -> None: + self._output.append(text) + # Cap buffer to avoid unbounded memory growth. + while len(self._output) > 500: + self._output.popleft() + + # ------------------------------------------------------------------ Status helpers + def status(self) -> Dict[str, Any]: + return { + "channel_id": self.channel_id, + "open_sent": self._open_sent, + "ack": self._ack_received, + "closed": self._closed, + "close_reason": self._close_reason, + "close_code": self._close_code, + } + diff --git a/Data/Engine/services/WebSocket/Agent/ReverseTunnel/__init__.py b/Data/Engine/services/WebSocket/Agent/ReverseTunnel/__init__.py new file mode 100644 index 00000000..db1c9337 --- /dev/null +++ b/Data/Engine/services/WebSocket/Agent/ReverseTunnel/__init__.py @@ -0,0 +1,2 @@ +"""Protocol-specific helpers for Reverse Tunnel (Engine side).""" + diff --git a/Data/Engine/services/WebSocket/__init__.py b/Data/Engine/services/WebSocket/__init__.py index 174f9148..632729c1 100644 --- a/Data/Engine/services/WebSocket/__init__.py +++ b/Data/Engine/services/WebSocket/__init__.py @@ -401,6 +401,91 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None: frames.append(_encode_frame(frame)) return {"frames": frames} + def _require_ps_server(): + sid = request.sid + tunnel_id = _operator_sessions.get(sid) + if not tunnel_id: + return None, None, {"error": "not_joined"} + server = tunnel_service.ensure_ps_server(tunnel_id) + if server is None: + return None, tunnel_id, {"error": "ps_unsupported"} + return server, tunnel_id, None + + @socket_server.on("ps_open", namespace=tunnel_namespace) + def _ws_ps_open(data: Any) -> Any: + server, tunnel_id, error = _require_ps_server() + if server is None: + return error + cols = 120 + rows = 32 + if isinstance(data, dict): + try: + cols = int(data.get("cols", cols)) + rows = int(data.get("rows", rows)) + except Exception: + pass + cols = max(20, min(cols, 300)) + rows = max(10, min(rows, 200)) + try: + server.open_channel(cols=cols, rows=rows) + except Exception as exc: + logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) + return {"error": "ps_open_failed"} + return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows} + + @socket_server.on("ps_send", namespace=tunnel_namespace) + def _ws_ps_send(data: Any) -> Any: + server, tunnel_id, error = _require_ps_server() + if server is None: + return error + if data is None: + return {"error": "payload_required"} + text = data + if isinstance(data, dict): + text = data.get("data") + if text is None: + return {"error": "payload_required"} + try: + server.send_input(str(text)) + except Exception as exc: + logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) + return {"error": "ps_send_failed"} + return {"status": "ok"} + + @socket_server.on("ps_resize", namespace=tunnel_namespace) + def _ws_ps_resize(data: Any) -> Any: + server, tunnel_id, error = _require_ps_server() + if server is None: + return error + cols = None + rows = None + if isinstance(data, dict): + cols = data.get("cols") + rows = data.get("rows") + try: + cols_int = int(cols) if cols is not None else 120 + rows_int = int(rows) if rows is not None else 32 + cols_int = max(20, min(cols_int, 300)) + rows_int = max(10, min(rows_int, 200)) + server.send_resize(cols_int, rows_int) + return {"status": "ok", "cols": cols_int, "rows": rows_int} + except Exception as exc: + logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) + return {"error": "ps_resize_failed"} + + @socket_server.on("ps_poll", namespace=tunnel_namespace) + def _ws_ps_poll() -> Any: + server, tunnel_id, error = _require_ps_server() + if server is None: + return error + try: + output = server.drain_output() + status = server.status() + return {"output": output, "status": status} + except Exception as exc: + logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True) + return {"error": "ps_poll_failed"} + @socket_server.on("disconnect", namespace=tunnel_namespace) def _ws_tunnel_disconnect(): sid = request.sid diff --git a/Docs/Codex/FEATURE_IMPLEMENTATION_TRACKING/Agent_Reverse_Tunneling.md b/Docs/Codex/FEATURE_IMPLEMENTATION_TRACKING/Agent_Reverse_Tunneling.md index eeb422f5..a509dc00 100644 --- a/Docs/Codex/FEATURE_IMPLEMENTATION_TRACKING/Agent_Reverse_Tunneling.md +++ b/Docs/Codex/FEATURE_IMPLEMENTATION_TRACKING/Agent_Reverse_Tunneling.md @@ -206,17 +206,17 @@ Read `Docs/Codex/FEATURE_IMPLEMENTATION_TRACKING/Agent_Reverse_Tunneling.md` and - [x] Implement channel framing, flow control, heartbeats, close semantics. - [x] Logging: `Engine/Logs/reverse_tunnel.log`; audit into Device Activity (session start/stop, operator id, agent id, tunnel_id, port). - [x] WebUI operator bridge endpoint (WebSocket) that maps browser sessions to agent channels. - - [x] Idle/grace sweeper + heartbeat wiring for tunnel sockets. - - [x] TLS-aware per-port listener and agent CONNECT_ACK handling. -- [ ] Agent tunnel role - - [ ] Add `Data/Agent/Roles/role_ReverseTunnel.py` (manages tunnel socket, reconnect, heartbeats, channel dispatch). - - [ ] Per-protocol submodules under `Data/Agent/Roles/ReverseTunnel/` (first: `tunnel_Powershell.py`). - - [ ] Enforce per-domain concurrency (one PowerShell; prevent multiple RDP/VNC/WebRTC; allow extensible policies). - - [ ] Logging: `Agent/Logs/reverse_tunnel.log`; include tunnel_id/channel_id. - - [ ] Integrate token validation, TLS reuse, idle teardown, and graceful stop_all. +- [x] Idle/grace sweeper + heartbeat wiring for tunnel sockets. +- [x] TLS-aware per-port listener and agent CONNECT_ACK handling. +- [x] Agent tunnel role + - [x] Add `Data/Agent/Roles/role_ReverseTunnel.py` (manages tunnel socket, reconnect, heartbeats, channel dispatch). + - [x] Per-protocol submodules under `Data/Agent/Roles/ReverseTunnel/` (first: `tunnel_Powershell.py`). + - [x] Enforce per-domain concurrency (one PowerShell; prevent multiple RDP/VNC/WebRTC; allow extensible policies). + - [x] Logging: `Agent/Logs/reverse_tunnel.log`; include tunnel_id/channel_id. + - [x] Integrate token validation, TLS reuse, idle teardown, and graceful stop_all. - [ ] PowerShell v1 (feature target) - - [ ] Engine side `Data/Engine/services/WebSocket/Agent/ReverseTunnel/Powershell.py` (channel server, resize handling, translate browser events). - - [ ] Agent side `Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py` using ConPTY/pywinpty; map stdin/stdout to frames; handle resize and exit codes. + - [x] Engine side `Data/Engine/services/WebSocket/Agent/ReverseTunnel/Powershell.py` (channel server, resize handling, translate browser events). + - [x] Agent side `Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py` using ConPTY/pywinpty; map stdin/stdout to frames; handle resize and exit codes. - [ ] WebUI: `Data/Engine/web-interface/src/ReverseTunnel/Powershell.jsx` with terminal UI, syntax highlighting matching `Assemblies/Assembly_Editor.jsx`, copy support, status toasts. - [ ] Device Activity entries and UI surface in `Devices/Device_List.jsx` Device Activity tab. - [ ] Credits & attribution @@ -244,6 +244,8 @@ Read `Docs/Codex/FEATURE_IMPLEMENTATION_TRACKING/Agent_Reverse_Tunneling.md` and - 2025-11-30: Added WebUI-facing Socket.IO namespace `/tunnel` with join/send/poll events that map browser sessions to tunnel bridges, using base64-encoded frames and operator auth from session/cookies. - 2025-11-30: Enabled async WebSocket listener per assigned port (TLS-aware via Engine certs) for agent CONNECT frames, with frame routing between agent socket and browser bridge queues; Engine tunnel service checklist marked complete. - 2025-11-30: Added idle/grace sweeper, CONNECT_ACK to agents, heartbeat loop, and token-touched operator sends; per-port listener now runs on dedicated loop/thread. (Original instructions didn’t call out sweeper/heartbeat wiring explicitly.) +- 2025-12-01: Added Agent reverse tunnel role (`Data/Agent/Roles/role_ReverseTunnel.py`) with TLS-aware WebSocket dialer, token validation against signed leases, domain-limit guard, heartbeat/idle watchdogs, and reverse_tunnel.log status emits; protocol handlers remain stubbed until PowerShell module lands. +- 2025-12-01: Implemented Agent PowerShell channel (pywinpty ConPTY stdin/stdout piping, resize, exit-close) and Engine PowerShell handler with Socket.IO helpers (`ps_open`/`ps_send`/`ps_resize`/`ps_poll`); added ps channel logging and domain-aware attach. WebUI remains pending. ## Engine Tunnel Service Architecture @@ -288,7 +290,7 @@ sequenceDiagram ``` ## Future Changes in Generation 2 -These items are out of scope for the current milestone but should be considered for a production-ready generation after minimum functionality is achieved in the early stages of development. +These items are out of scope for the current milestone but should be considered for a production-ready generation after minimum functionality is achieved in the early stages of development. This section is a place to note things that were not implemented in Generation 1, but should be added in future iterations of the Reverse Tunneling system. - Harden operator auth/authorization: enforce per-operator session binding, ownership checks, audited attach/detach, and offer a pure WebSocket `/ws/tunnel/` bridge. - Replace Socket.IO browser bridge with a dedicated binary WebSocket bridge for higher throughput and simpler framing.