Agent Reverse Tunneling - Agent Role Implementation

This commit is contained in:
2025-12-01 02:40:23 -07:00
parent 478237f487
commit fe4511ecaf
8 changed files with 1155 additions and 11 deletions

View File

@@ -0,0 +1,2 @@
"""Reverse tunnel protocol modules (placeholder package)."""

View File

@@ -0,0 +1,226 @@
"""PowerShell channel implementation for reverse tunnel (Agent side)."""
from __future__ import annotations
import asyncio
import os
import sys
from typing import Any, Dict, Optional
# Message types mirrored from the tunnel framing (kept local to avoid import cycles).
MSG_DATA = 0x05
MSG_WINDOW_UPDATE = 0x06
MSG_CONTROL = 0x09
MSG_CLOSE = 0x08
# Close codes (mirrored from engine framing)
CLOSE_OK = 0
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AGENT_SHUTDOWN = 6
class PowershellChannel:
def __init__(self, role, tunnel, channel_id: int, metadata: Optional[Dict[str, Any]]):
self.role = role
self.tunnel = tunnel
self.channel_id = channel_id
self.metadata = metadata or {}
self.loop = getattr(role, "loop", None) or asyncio.get_event_loop()
self._closed = False
self._reader_task = None
self._writer_task = None
self._stdin_queue: asyncio.Queue = asyncio.Queue()
self._pty = None
self._exit_code: Optional[int] = None
self._frame_cls = getattr(role, "_frame_cls", None)
# ------------------------------------------------------------------ Helpers
def _make_frame(self, msg_type: int, payload: bytes = b"", *, flags: int = 0):
frame_cls = self._frame_cls
if frame_cls is None:
return None
try:
return frame_cls(msg_type=msg_type, channel_id=self.channel_id, payload=payload or b"", flags=flags)
except Exception:
return None
async def _send_frame(self, frame) -> None:
if frame is None:
return
await self.role._send_frame(self.tunnel, frame)
async def _send_close(self, code: int, reason: str) -> None:
try:
close_frame = getattr(self.role, "close_frame")
if callable(close_frame):
await self._send_frame(close_frame(self.channel_id, code, reason))
return
except Exception:
pass
frame = self._make_frame(
MSG_CLOSE,
payload=f'{{"code":{code},"reason":"{reason}"}}'.encode("utf-8"),
)
await self._send_frame(frame)
def _powershell_path(self) -> str:
preferred = self.metadata.get("shell") if isinstance(self.metadata, dict) else None
if isinstance(preferred, str) and preferred.strip():
return preferred.strip()
# Default to Windows PowerShell; fallback to pwsh if provided later.
return "powershell.exe"
def _initial_size(self) -> tuple:
cols = int(self.metadata.get("cols") or self.metadata.get("columns") or 120) if isinstance(self.metadata, dict) else 120
rows = int(self.metadata.get("rows") or 32) if isinstance(self.metadata, dict) else 32
cols = max(20, min(cols, 300))
rows = max(10, min(rows, 200))
return cols, rows
# ------------------------------------------------------------------ Lifecycle
async def start(self) -> None:
if sys.platform.lower().startswith("win") is False:
await self._send_close(CLOSE_PROTOCOL_ERROR, "windows_only")
return
try:
import pywinpty # type: ignore
except Exception as exc: # pragma: no cover - dependency guard
self.role._log(f"reverse_tunnel ps channel missing pywinpty: {exc}", error=True)
await self._send_close(CLOSE_PROTOCOL_ERROR, "pywinpty_missing")
return
shell = self._powershell_path()
cols, rows = self._initial_size()
try:
self._pty = pywinpty.Process(
spawn_cmd=shell,
dimensions=(cols, rows),
)
except Exception as exc:
self.role._log(f"reverse_tunnel ps channel failed to spawn {shell}: {exc}", error=True)
await self._send_close(CLOSE_PROTOCOL_ERROR, "spawn_failed")
return
self._reader_task = self.loop.create_task(self._pump_stdout())
self._writer_task = self.loop.create_task(self._pump_stdin())
self.role._log(f"reverse_tunnel ps channel started shell={shell} cols={cols} rows={rows}")
async def on_frame(self, frame) -> None:
if self._closed:
return
if frame.msg_type == MSG_DATA:
if frame.payload:
try:
self._stdin_queue.put_nowait(frame.payload)
except Exception:
await self._stdin_queue.put(frame.payload)
elif frame.msg_type == MSG_CONTROL:
await self._handle_control(frame.payload)
elif frame.msg_type == MSG_CLOSE:
await self.stop(code=CLOSE_AGENT_SHUTDOWN, reason="operator_close")
elif frame.msg_type == MSG_WINDOW_UPDATE:
# Reserved for back-pressure; ignore for now.
return
async def _handle_control(self, payload: bytes) -> None:
try:
import json
data = json.loads(payload.decode("utf-8"))
except Exception:
return
cols = data.get("cols") or data.get("columns")
rows = data.get("rows")
if cols is None and rows is None:
return
try:
cols_int = int(cols) if cols is not None else None
rows_int = int(rows) if rows is not None else None
except Exception:
return
await self._resize(cols_int, rows_int)
async def _resize(self, cols: Optional[int], rows: Optional[int]) -> None:
if self._pty is None:
return
try:
cur_cols, cur_rows = self._initial_size()
if cols is None:
cols = cur_cols
if rows is None:
rows = cur_rows
cols = max(20, min(int(cols), 300))
rows = max(10, min(int(rows), 200))
self._pty.set_size(cols, rows)
self.role._log(f"reverse_tunnel ps channel resized cols={cols} rows={rows}")
except Exception:
self.role._log("reverse_tunnel ps channel resize failed", error=True)
async def _pump_stdout(self) -> None:
loop = asyncio.get_event_loop()
try:
while not self._closed and self._pty:
chunk = await loop.run_in_executor(None, self._pty.read, 4096)
if chunk is None:
break
if isinstance(chunk, str):
data = chunk.encode("utf-8", errors="replace")
else:
data = bytes(chunk)
if not data:
break
frame = self._make_frame(MSG_DATA, payload=data)
await self._send_frame(frame)
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps stdout pump error", error=True)
finally:
await self.stop(reason="stdout_closed")
async def _pump_stdin(self) -> None:
loop = asyncio.get_event_loop()
try:
while not self._closed and self._pty:
try:
data = await self._stdin_queue.get()
except asyncio.CancelledError:
break
if data is None:
break
if isinstance(data, (bytes, bytearray)):
text = data.decode("utf-8", errors="replace")
else:
text = str(data)
try:
await loop.run_in_executor(None, self._pty.write, text)
except Exception:
break
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps stdin pump error", error=True)
finally:
await self.stop(reason="stdin_closed")
async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None:
if self._closed:
return
self._closed = True
if self._pty is not None:
try:
self._pty.terminate()
except Exception:
pass
current = asyncio.current_task()
if self._reader_task and self._reader_task is not current:
try:
self._reader_task.cancel()
except Exception:
pass
if self._writer_task and self._writer_task is not current:
try:
self._writer_task.cancel()
except Exception:
pass
await self._send_close(code, reason or "powershell_exit")
self.role._log(f"reverse_tunnel ps channel stopped channel={self.channel_id} reason={reason or 'exit'}")

View File

@@ -0,0 +1,654 @@
import asyncio
import base64
import json
import os
import struct
import time
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from urllib.parse import urlparse
import aiohttp
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
try:
from .ReverseTunnel import tunnel_Powershell
except Exception:
tunnel_Powershell = None
ROLE_NAME = "reverse_tunnel"
ROLE_CONTEXTS = ["interactive", "system"]
# Message types (keep in sync with Engine service)
MSG_CONNECT = 0x01
MSG_CONNECT_ACK = 0x02
MSG_CHANNEL_OPEN = 0x03
MSG_CHANNEL_ACK = 0x04
MSG_DATA = 0x05
MSG_WINDOW_UPDATE = 0x06
MSG_HEARTBEAT = 0x07
MSG_CLOSE = 0x08
MSG_CONTROL = 0x09
# Close codes
CLOSE_OK = 0
CLOSE_IDLE_TIMEOUT = 1
CLOSE_GRACE_EXPIRED = 2
CLOSE_PROTOCOL_ERROR = 3
CLOSE_AUTH_FAILED = 4
CLOSE_SERVER_SHUTDOWN = 5
CLOSE_AGENT_SHUTDOWN = 6
CLOSE_DOMAIN_LIMIT = 7
CLOSE_UNEXPECTED_DISCONNECT = 8
FRAME_VERSION = 1
FRAME_HEADER_STRUCT = struct.Struct("<BBBBII") # version, msg_type, flags, reserved, channel_id, length
@dataclass
class TunnelFrame:
msg_type: int
channel_id: int
payload: bytes = field(default_factory=bytes)
flags: int = 0
version: int = FRAME_VERSION
reserved: int = 0
def encode(self) -> bytes:
payload = self.payload or b""
header = FRAME_HEADER_STRUCT.pack(
self.version,
self.msg_type,
self.flags,
self.reserved,
int(self.channel_id),
len(payload),
)
return header + payload
def decode_frame(raw: bytes) -> TunnelFrame:
if len(raw) < FRAME_HEADER_STRUCT.size:
raise ValueError("frame_too_small")
version, msg_type, flags, reserved, channel_id, length = FRAME_HEADER_STRUCT.unpack_from(raw, 0)
if version != FRAME_VERSION:
raise ValueError(f"unsupported_version:{version}")
if length < 0 or len(raw) < FRAME_HEADER_STRUCT.size + length:
raise ValueError("invalid_length")
payload = raw[FRAME_HEADER_STRUCT.size : FRAME_HEADER_STRUCT.size + length]
if len(payload) != length:
raise ValueError("length_mismatch")
return TunnelFrame(msg_type=msg_type, channel_id=channel_id, payload=payload, flags=flags, version=version, reserved=reserved)
def heartbeat_frame(channel_id: int = 0, *, is_ack: bool = False) -> TunnelFrame:
flags = 0x1 if is_ack else 0x0
return TunnelFrame(msg_type=MSG_HEARTBEAT, channel_id=channel_id, flags=flags, payload=b"")
def close_frame(channel_id: int, code: int, reason: str = "") -> TunnelFrame:
payload = json.dumps({"code": code, "reason": reason}, separators=(",", ":")).encode("utf-8")
return TunnelFrame(msg_type=MSG_CLOSE, channel_id=channel_id, payload=payload)
def _norm_text(value: Any) -> str:
if value is None:
return ""
try:
return str(value).strip()
except Exception:
return ""
def _is_literal_ip(value: str) -> bool:
try:
import ipaddress
ipaddress.ip_address(value.strip().strip("[]"))
return True
except Exception:
return False
@dataclass
class ActiveTunnel:
tunnel_id: str
domain: str
protocol: str
port: int
token: str
url: str
heartbeat_seconds: int
idle_seconds: int
grace_seconds: int
expires_at: Optional[int]
signing_key_hint: Optional[str] = None
session: Optional[aiohttp.ClientSession] = None
websocket: Optional[aiohttp.ClientWebSocketResponse] = None
tasks: list = field(default_factory=list)
send_queue: asyncio.Queue = field(default_factory=asyncio.Queue)
channels: Dict[int, Any] = field(default_factory=dict)
last_activity: float = field(default_factory=lambda: time.time())
connected: bool = False
stopping: bool = False
stop_reason: Optional[str] = None
class BaseChannel:
"""Placeholder channel handler; protocol-specific handlers plug in later."""
def __init__(self, role: "Role", tunnel: ActiveTunnel, channel_id: int, metadata: Optional[dict]):
self.role = role
self.tunnel = tunnel
self.channel_id = channel_id
self.metadata = metadata or {}
async def start(self) -> None:
# Nothing to prime for placeholder channels.
return
async def on_frame(self, frame: TunnelFrame) -> None:
# Drop frames until protocol module is provided.
return
async def stop(self, code: int = CLOSE_OK, reason: str = "") -> None:
await self.role._send_frame(self.tunnel, close_frame(self.channel_id, code, reason))
class Role:
def __init__(self, ctx):
self.ctx = ctx
self.sio = ctx.sio
self.loop = ctx.loop or asyncio.get_event_loop()
self.hooks = ctx.hooks or {}
self._http_client_factory = self.hooks.get("http_client")
self._log_hook = self.hooks.get("log_agent")
self._active: Dict[str, ActiveTunnel] = {}
self._domain_claims: Dict[str, str] = {}
self._domain_limits: Dict[str, Optional[int]] = {
"ps": 1,
"rdp": 1,
"vnc": 1,
"webrtc": 1,
"ssh": None,
"winrm": None,
}
self._default_heartbeat = 20
self._protocol_handlers: Dict[str, Any] = {}
self._frame_cls = TunnelFrame
self.close_frame = close_frame
if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"):
self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel
# ------------------------------------------------------------------ Logging
def _log(self, message: str, *, error: bool = False) -> None:
fname = "reverse_tunnel.log"
try:
if callable(self._log_hook):
self._log_hook(message, fname=fname)
if error:
self._log_hook(message, fname="agent.error.log")
except Exception:
pass
# ------------------------------------------------------------------ Event wiring
def register_events(self):
@self.sio.on("reverse_tunnel_start")
async def _reverse_tunnel_start(payload):
await self._handle_tunnel_start(payload)
@self.sio.on("reverse_tunnel_stop")
async def _reverse_tunnel_stop(payload):
tid = ""
if isinstance(payload, dict):
tid = _norm_text(payload.get("tunnel_id"))
await self._stop_tunnel(tid, code=CLOSE_AGENT_SHUTDOWN, reason="server_stop")
# ------------------------------------------------------------------ Token helpers
def _decode_token_payload(self, token: str, *, signing_key_hint: Optional[str] = None) -> Dict[str, Any]:
if not token:
raise ValueError("token_missing")
def _b64decode(segment: str) -> bytes:
padding = "=" * (-len(segment) % 4)
return base64.urlsafe_b64decode(segment + padding)
parts = token.split(".")
payload_segment = parts[0]
payload_bytes = _b64decode(payload_segment)
try:
payload = json.loads(payload_bytes.decode("utf-8"))
except Exception as exc:
raise ValueError("token_decode_error") from exc
if len(parts) == 2:
candidates = []
hint = _norm_text(signing_key_hint)
if hint:
candidates.append(hint)
client = self._http_client()
if client and hasattr(client, "load_server_signing_key"):
try:
stored = client.load_server_signing_key()
except Exception:
stored = None
if isinstance(stored, str) and stored.strip():
candidates.append(stored.strip())
signature = _b64decode(parts[1])
verified = False
for candidate in candidates:
try:
key_bytes = base64.b64decode(candidate, validate=True)
public_key = serialization.load_der_public_key(key_bytes)
except Exception:
continue
if not isinstance(public_key, ed25519.Ed25519PublicKey):
continue
try:
public_key.verify(signature, payload_bytes)
verified = True
if client and hasattr(client, "store_server_signing_key"):
try:
client.store_server_signing_key(candidate)
except Exception:
pass
break
except Exception:
continue
if not verified:
raise ValueError("token_signature_invalid")
return payload
def _validate_token(self, token: str, *, expected_agent: str, expected_domain: str, expected_protocol: str, expected_tunnel: str, signing_key_hint: Optional[str]) -> Dict[str, Any]:
payload = self._decode_token_payload(token, signing_key_hint=signing_key_hint)
def _matches(expected: str, actual: Any) -> bool:
return _norm_text(expected).lower() == _norm_text(actual).lower()
if expected_agent and not _matches(expected_agent, payload.get("agent_id")):
raise ValueError("token_agent_mismatch")
if expected_tunnel and not _matches(expected_tunnel, payload.get("tunnel_id")):
raise ValueError("token_id_mismatch")
if expected_domain and not _matches(expected_domain, payload.get("domain")):
raise ValueError("token_domain_mismatch")
if expected_protocol and not _matches(expected_protocol, payload.get("protocol")):
raise ValueError("token_protocol_mismatch")
expires_at = payload.get("expires_at")
try:
expires_ts = int(expires_at) if expires_at is not None else None
except Exception:
expires_ts = None
if expires_ts is not None and expires_ts < int(time.time()):
raise ValueError("token_expired")
return payload
# ------------------------------------------------------------------ Utility
def _http_client(self):
try:
if callable(self._http_client_factory):
return self._http_client_factory()
except Exception:
return None
return None
def _port_allowed(self, port: int) -> bool:
return isinstance(port, int) and 1 <= port <= 65535
def _domain_allowed(self, domain: str) -> bool:
limit = self._domain_limits.get(domain)
if limit is None:
return True
active = [tid for tid, t in self._active.items() if t.domain == domain and not t.stopping]
pending = [tid for dom, tid in self._domain_claims.items() if dom == domain and tid not in active]
count = len(set(active + pending))
return count < limit
def _build_ws_url(self, port: int) -> str:
client = self._http_client()
if client:
try:
client.refresh_base_url()
except Exception:
pass
base_url = getattr(client, "base_url", "") or "https://localhost:5000"
else:
base_url = "https://localhost:5000"
parsed = urlparse(base_url)
host = parsed.hostname or "localhost"
scheme = "wss" if (parsed.scheme or "").lower() == "https" else "ws"
return f"{scheme}://{host}:{port}/"
def _ssl_context(self, host: str):
client = self._http_client()
verify = getattr(getattr(client, "session", None), "verify", True)
if verify is False:
return False
if isinstance(verify, str) and os.path.isfile(verify) and client and hasattr(client, "key_store"):
try:
ctx = client.key_store.build_ssl_context()
if ctx and _is_literal_ip(host):
ctx.check_hostname = False
return ctx
except Exception:
return None
return None
async def _emit_status(self, payload: Dict[str, Any]) -> None:
try:
await self.sio.emit("reverse_tunnel_status", payload)
except Exception:
pass
def _mark_activity(self, tunnel: ActiveTunnel) -> None:
tunnel.last_activity = time.time()
# ------------------------------------------------------------------ Event handlers
async def _handle_tunnel_start(self, payload: Any) -> None:
if not isinstance(payload, dict):
self._log("reverse_tunnel_start ignored: payload not a dict", error=True)
return
tunnel_id = _norm_text(payload.get("tunnel_id"))
token = _norm_text(payload.get("token"))
port = payload.get("port") or payload.get("assigned_port")
protocol = _norm_text(payload.get("protocol") or "ps").lower() or "ps"
domain = _norm_text(payload.get("domain") or protocol).lower() or protocol
heartbeat_seconds = int(payload.get("heartbeat_seconds") or self._default_heartbeat or 20)
idle_seconds = int(payload.get("idle_seconds") or 3600)
grace_seconds = int(payload.get("grace_seconds") or 3600)
signing_key_hint = _norm_text(payload.get("signing_key"))
if not token:
self._log("reverse_tunnel_start rejected: missing token", error=True)
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_missing"})
return
if not tunnel_id:
tunnel_id = _norm_text(payload.get("lease_id")) # fallback alias
try:
claims = self._validate_token(
token,
expected_agent=self.ctx.agent_id,
expected_domain=domain,
expected_protocol=protocol,
expected_tunnel=tunnel_id,
signing_key_hint=signing_key_hint,
)
except Exception as exc:
self._log(f"reverse_tunnel_start rejected: token validation failed ({exc})", error=True)
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_invalid"})
return
tunnel_id = _norm_text(claims.get("tunnel_id") or tunnel_id)
if not tunnel_id:
self._log("reverse_tunnel_start rejected: tunnel_id missing after token parse", error=True)
return
try:
port = int(port or claims.get("assigned_port") or 0)
except Exception:
port = 0
if not self._port_allowed(port):
self._log(f"reverse_tunnel_start rejected: invalid port {port}", error=True)
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "invalid_port"})
return
domain = _norm_text(claims.get("domain") or domain).lower()
protocol = _norm_text(claims.get("protocol") or protocol).lower()
expires_at = claims.get("expires_at")
if tunnel_id in self._active:
self._log(f"reverse_tunnel_start ignored: tunnel already active tunnel_id={tunnel_id}")
return
if not self._domain_allowed(domain):
self._log(f"reverse_tunnel_start rejected: domain limit for {domain}", error=True)
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "domain_limit"})
return
url = self._build_ws_url(port)
parsed = urlparse(url)
heartbeat_seconds = max(5, min(heartbeat_seconds, 120))
idle_seconds = max(30, idle_seconds)
grace_seconds = max(60, grace_seconds)
tunnel = ActiveTunnel(
tunnel_id=tunnel_id,
domain=domain,
protocol=protocol,
port=port,
token=token,
url=url,
heartbeat_seconds=heartbeat_seconds,
idle_seconds=idle_seconds,
grace_seconds=grace_seconds,
expires_at=int(expires_at) if expires_at is not None else None,
signing_key_hint=signing_key_hint or None,
)
self._active[tunnel_id] = tunnel
self._domain_claims[domain] = tunnel_id
self._log(f"reverse_tunnel_start accepted tunnel_id={tunnel_id} domain={domain} protocol={protocol} url={url}")
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "connecting", "url": url})
task = self.loop.create_task(self._run_tunnel(tunnel, host=parsed.hostname or "localhost"))
tunnel.tasks.append(task)
# ------------------------------------------------------------------ Core tunnel handling
async def _run_tunnel(self, tunnel: ActiveTunnel, *, host: str) -> None:
ssl_ctx = self._ssl_context(host)
timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None)
try:
tunnel.session = aiohttp.ClientSession(timeout=timeout)
tunnel.websocket = await tunnel.session.ws_connect(
tunnel.url,
ssl=ssl_ctx,
heartbeat=None,
max_msg_size=0,
timeout=timeout,
)
self._mark_activity(tunnel)
await tunnel.websocket.send_bytes(
TunnelFrame(
msg_type=MSG_CONNECT,
channel_id=0,
payload=json.dumps(
{
"agent_id": self.ctx.agent_id,
"tunnel_id": tunnel.tunnel_id,
"token": tunnel.token,
"protocol": tunnel.protocol,
"domain": tunnel.domain,
"version": FRAME_VERSION,
},
separators=(",", ":"),
).encode("utf-8"),
).encode()
)
sender = self.loop.create_task(self._pump_sender(tunnel))
receiver = self.loop.create_task(self._pump_receiver(tunnel))
heartbeats = self.loop.create_task(self._heartbeat_loop(tunnel))
watchdog = self.loop.create_task(self._watchdog(tunnel))
tunnel.tasks.extend([sender, receiver, heartbeats, watchdog])
await asyncio.wait([sender, receiver, heartbeats, watchdog], return_when=asyncio.FIRST_COMPLETED)
except Exception as exc:
self._log(f"reverse_tunnel connection failed tunnel_id={tunnel.tunnel_id}: {exc}", error=True)
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "connect_failed"})
finally:
await self._shutdown_tunnel(tunnel)
async def _pump_sender(self, tunnel: ActiveTunnel) -> None:
try:
while tunnel.websocket and not tunnel.websocket.closed:
frame: TunnelFrame = await tunnel.send_queue.get()
try:
await tunnel.websocket.send_bytes(frame.encode())
self._mark_activity(tunnel)
except Exception:
break
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel sender failed tunnel_id={tunnel.tunnel_id}", error=True)
async def _pump_receiver(self, tunnel: ActiveTunnel) -> None:
ws = tunnel.websocket
if ws is None:
return
try:
async for msg in ws:
if msg.type == aiohttp.WSMsgType.BINARY:
try:
frame = decode_frame(msg.data)
except Exception:
self._log(f"reverse_tunnel frame decode failed tunnel_id={tunnel.tunnel_id}", error=True)
continue
self._mark_activity(tunnel)
await self._handle_frame(tunnel, frame)
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE):
break
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel receiver failed tunnel_id={tunnel.tunnel_id}", error=True)
async def _heartbeat_loop(self, tunnel: ActiveTunnel) -> None:
try:
while tunnel.websocket and not tunnel.websocket.closed:
await asyncio.sleep(tunnel.heartbeat_seconds)
await self._send_frame(tunnel, heartbeat_frame())
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel heartbeat failed tunnel_id={tunnel.tunnel_id}", error=True)
async def _watchdog(self, tunnel: ActiveTunnel) -> None:
try:
while tunnel.websocket and not tunnel.websocket.closed:
await asyncio.sleep(10)
now = time.time()
if tunnel.idle_seconds and (now - tunnel.last_activity) >= tunnel.idle_seconds:
await self._send_frame(tunnel, close_frame(0, CLOSE_IDLE_TIMEOUT, "idle_timeout"))
break
if tunnel.expires_at and (now - tunnel.expires_at) >= tunnel.grace_seconds:
await self._send_frame(tunnel, close_frame(0, CLOSE_GRACE_EXPIRED, "grace_expired"))
break
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel watchdog failed tunnel_id={tunnel.tunnel_id}", error=True)
async def _handle_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
if frame.msg_type == MSG_HEARTBEAT:
if frame.flags & 0x1:
return
await self._send_frame(tunnel, heartbeat_frame(channel_id=frame.channel_id, is_ack=True))
return
if frame.msg_type == MSG_CONNECT_ACK:
tunnel.connected = True
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "connected"})
return
if frame.msg_type == MSG_CHANNEL_OPEN:
await self._handle_channel_open(tunnel, frame)
return
if frame.msg_type == MSG_CLOSE:
try:
reason = json.loads(frame.payload.decode("utf-8"))
except Exception:
reason = {"code": CLOSE_UNEXPECTED_DISCONNECT, "reason": "close"}
tunnel.stop_reason = reason.get("reason") if isinstance(reason, dict) else None
await self._shutdown_tunnel(tunnel, send_close=False)
return
handler = tunnel.channels.get(frame.channel_id)
if handler:
try:
await handler.on_frame(frame)
except Exception:
self._log(f"reverse_tunnel channel handler failed tunnel_id={tunnel.tunnel_id} channel={frame.channel_id}", error=True)
else:
if frame.msg_type in (MSG_DATA, MSG_CONTROL, MSG_WINDOW_UPDATE):
await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "unknown_channel"))
async def _handle_channel_open(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
try:
payload = json.loads(frame.payload.decode("utf-8"))
except Exception:
payload = {}
protocol = _norm_text(payload.get("protocol") or tunnel.protocol)
metadata = payload.get("metadata") if isinstance(payload, dict) else {}
if frame.channel_id in tunnel.channels:
await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "channel_exists"))
return
handler_cls = self._protocol_handlers.get(protocol.lower()) or BaseChannel
try:
handler = handler_cls(self, tunnel, frame.channel_id, metadata)
except Exception:
handler = BaseChannel(self, tunnel, frame.channel_id, metadata)
tunnel.channels[frame.channel_id] = handler
await handler.start()
await self._send_frame(
tunnel,
TunnelFrame(
msg_type=MSG_CHANNEL_ACK,
channel_id=frame.channel_id,
payload=json.dumps({"status": "ok", "protocol": protocol}, separators=(",", ":")).encode("utf-8"),
),
)
async def _send_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
if tunnel.stopping:
return
try:
tunnel.send_queue.put_nowait(frame)
except Exception:
await tunnel.send_queue.put(frame)
async def _stop_tunnel(self, tunnel_id: str, *, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "requested") -> None:
tunnel = self._active.get(tunnel_id)
if not tunnel:
return
await self._send_frame(tunnel, close_frame(0, code, reason))
await self._shutdown_tunnel(tunnel, send_close=False)
async def _shutdown_tunnel(self, tunnel: ActiveTunnel, *, send_close: bool = True) -> None:
if tunnel.stopping:
return
tunnel.stopping = True
if send_close:
try:
await self._send_frame(tunnel, close_frame(0, CLOSE_AGENT_SHUTDOWN, "agent_shutdown"))
except Exception:
pass
for task in list(tunnel.tasks):
try:
task.cancel()
except Exception:
pass
if tunnel.websocket is not None:
try:
await tunnel.websocket.close()
except Exception:
pass
if tunnel.session is not None:
try:
await tunnel.session.close()
except Exception:
pass
self._active.pop(tunnel.tunnel_id, None)
if tunnel.domain in self._domain_claims and self._domain_claims.get(tunnel.domain) == tunnel.tunnel_id:
self._domain_claims.pop(tunnel.domain, None)
self._log(f"reverse_tunnel stopped tunnel_id={tunnel.tunnel_id} reason={tunnel.stop_reason or 'closed'}")
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "closed", "reason": tunnel.stop_reason or "closed"})
# ------------------------------------------------------------------ Lifecycle
def stop_all(self):
for tunnel_id in list(self._active.keys()):
try:
self.loop.create_task(self._stop_tunnel(tunnel_id, code=CLOSE_AGENT_SHUTDOWN, reason="agent_shutdown"))
except Exception:
pass

View File

@@ -29,6 +29,8 @@ from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple
from collections import deque
from threading import Thread
from .ReverseTunnel.Powershell import PowershellChannelServer
try: # websockets is added to engine requirements
import websockets
from websockets.server import serve as ws_serve
@@ -447,6 +449,7 @@ class ReverseTunnelService:
self._bridges: Dict[str, "TunnelBridge"] = {}
self._port_servers: Dict[int, asyncio.AbstractServer] = {}
self._agent_sockets: Dict[str, "websockets.WebSocketServerProtocol"] = {}
self._ps_servers: Dict[str, PowershellChannelServer] = {}
def _ensure_loop(self) -> None:
if self._running and self._loop:
@@ -548,6 +551,13 @@ class ReverseTunnelService:
raise ValueError("unknown_tunnel")
bridge = self.ensure_bridge(lease)
bridge.attach_operator(operator_id)
if lease.domain.lower() == "ps":
try:
server = self.ensure_ps_server(tunnel_id)
if server:
server.open_channel()
except Exception:
self.logger.debug("ps server open failed tunnel_id=%s", tunnel_id, exc_info=True)
return bridge
def _encode_token(self, payload: Dict[str, object]) -> str:
@@ -841,6 +851,15 @@ class ReverseTunnelService:
except Exception:
pass
def _dispatch_agent_frame(self, tunnel_id: str, frame: TunnelFrame) -> None:
server = self._ps_servers.get(tunnel_id)
if not server:
return
try:
server.handle_agent_frame(frame)
except Exception:
self.logger.debug("ps handler error for tunnel_id=%s", tunnel_id, exc_info=True)
def _start_lease_sweeper(self) -> None:
async def _sweeper():
while self._running and self._loop and not self._loop.is_closed():
@@ -928,6 +947,10 @@ class ReverseTunnelService:
except Exception:
continue
self.lease_manager.touch(tunnel_id)
try:
self._dispatch_agent_frame(tunnel_id, recv_frame)
except Exception:
pass
bridge.agent_to_operator(recv_frame)
async def _pump_to_agent():
while not websocket.closed:
@@ -969,10 +992,30 @@ class ReverseTunnelService:
self._bridges[lease.tunnel_id] = bridge
return bridge
def ensure_ps_server(self, tunnel_id: str) -> Optional[PowershellChannelServer]:
server = self._ps_servers.get(tunnel_id)
if server:
return server
lease = self.lease_manager.get(tunnel_id)
if lease is None or (lease.domain or "").lower() != "ps":
return None
bridge = self.ensure_bridge(lease)
server = PowershellChannelServer(bridge=bridge, service=self)
self._ps_servers[tunnel_id] = server
return server
def get_ps_server(self, tunnel_id: str) -> Optional[PowershellChannelServer]:
return self._ps_servers.get(tunnel_id)
def release_bridge(self, tunnel_id: str, *, reason: str = "bridge_released") -> None:
bridge = self._bridges.pop(tunnel_id, None)
if bridge:
bridge.stop(reason=reason)
if tunnel_id in self._ps_servers:
try:
self._ps_servers.pop(tunnel_id, None)
except Exception:
pass
class TunnelBridge:

View File

@@ -0,0 +1,130 @@
"""Engine-side PowerShell tunnel channel helper."""
from __future__ import annotations
import json
from collections import deque
from typing import Any, Deque, Dict, List, Optional
from ..ReverseTunnel import (
CLOSE_AGENT_SHUTDOWN,
CLOSE_OK,
CLOSE_PROTOCOL_ERROR,
MSG_CHANNEL_ACK,
MSG_CHANNEL_OPEN,
MSG_CLOSE,
MSG_CONTROL,
MSG_DATA,
TunnelFrame,
close_frame,
)
class PowershellChannelServer:
"""Coordinate PowerShell channel frames over a TunnelBridge."""
def __init__(self, bridge, service, *, channel_id: int = 1):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"ps.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output: Deque[str] = deque()
self._close_reason: Optional[str] = None
self._close_code: Optional[int] = None
# ------------------------------------------------------------------ Agent frame handling
def handle_agent_frame(self, frame: TunnelFrame) -> None:
if frame.channel_id != self.channel_id:
return
if frame.msg_type == MSG_CHANNEL_ACK:
self._ack_received = True
self.logger.info("ps channel acked tunnel_id=%s", self.bridge.lease.tunnel_id)
return
if frame.msg_type == MSG_DATA:
try:
text = frame.payload.decode("utf-8", errors="replace")
except Exception:
text = ""
if text:
self._append_output(text)
return
if frame.msg_type == MSG_CLOSE:
try:
payload = json.loads(frame.payload.decode("utf-8"))
except Exception:
payload = {}
self._closed = True
self._close_code = payload.get("code") if isinstance(payload, dict) else None
self._close_reason = payload.get("reason") if isinstance(payload, dict) else None
self.logger.info(
"ps channel closed tunnel_id=%s code=%s reason=%s",
self.bridge.lease.tunnel_id,
self._close_code,
self._close_reason or "-",
)
# ------------------------------------------------------------------ Operator actions
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
if self._open_sent:
return
payload = json.dumps(
{"protocol": "ps", "metadata": {"cols": cols, "rows": rows}},
separators=(",", ":"),
).encode("utf-8")
frame = TunnelFrame(msg_type=MSG_CHANNEL_OPEN, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
self._open_sent = True
self.logger.info(
"ps channel open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
if self._closed:
return
payload = data.encode("utf-8", errors="replace")
frame = TunnelFrame(msg_type=MSG_DATA, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
def send_resize(self, cols: int, rows: int) -> None:
if self._closed:
return
payload = json.dumps({"cols": cols, "rows": rows}, separators=(",", ":")).encode("utf-8")
frame = TunnelFrame(msg_type=MSG_CONTROL, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
def close(self, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
self.bridge.operator_to_agent(close_frame(self.channel_id, code, reason))
# ------------------------------------------------------------------ Output polling
def drain_output(self) -> List[str]:
items: List[str] = []
while self._output:
items.append(self._output.popleft())
return items
def _append_output(self, text: str) -> None:
self._output.append(text)
# Cap buffer to avoid unbounded memory growth.
while len(self._output) > 500:
self._output.popleft()
# ------------------------------------------------------------------ Status helpers
def status(self) -> Dict[str, Any]:
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": self._close_reason,
"close_code": self._close_code,
}

View File

@@ -0,0 +1,2 @@
"""Protocol-specific helpers for Reverse Tunnel (Engine side)."""

View File

@@ -401,6 +401,91 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
frames.append(_encode_frame(frame))
return {"frames": frames}
def _require_ps_server():
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return None, None, {"error": "not_joined"}
server = tunnel_service.ensure_ps_server(tunnel_id)
if server is None:
return None, tunnel_id, {"error": "ps_unsupported"}
return server, tunnel_id, None
@socket_server.on("ps_open", namespace=tunnel_namespace)
def _ws_ps_open(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = 120
rows = 32
if isinstance(data, dict):
try:
cols = int(data.get("cols", cols))
rows = int(data.get("rows", rows))
except Exception:
pass
cols = max(20, min(cols, 300))
rows = max(10, min(rows, 200))
try:
server.open_channel(cols=cols, rows=rows)
except Exception as exc:
logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_open_failed"}
return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}
@socket_server.on("ps_send", namespace=tunnel_namespace)
def _ws_ps_send(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
if data is None:
return {"error": "payload_required"}
text = data
if isinstance(data, dict):
text = data.get("data")
if text is None:
return {"error": "payload_required"}
try:
server.send_input(str(text))
except Exception as exc:
logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_send_failed"}
return {"status": "ok"}
@socket_server.on("ps_resize", namespace=tunnel_namespace)
def _ws_ps_resize(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = None
rows = None
if isinstance(data, dict):
cols = data.get("cols")
rows = data.get("rows")
try:
cols_int = int(cols) if cols is not None else 120
rows_int = int(rows) if rows is not None else 32
cols_int = max(20, min(cols_int, 300))
rows_int = max(10, min(rows_int, 200))
server.send_resize(cols_int, rows_int)
return {"status": "ok", "cols": cols_int, "rows": rows_int}
except Exception as exc:
logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_resize_failed"}
@socket_server.on("ps_poll", namespace=tunnel_namespace)
def _ws_ps_poll() -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
try:
output = server.drain_output()
status = server.status()
return {"output": output, "status": status}
except Exception as exc:
logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_poll_failed"}
@socket_server.on("disconnect", namespace=tunnel_namespace)
def _ws_tunnel_disconnect():
sid = request.sid

View File

@@ -206,17 +206,17 @@ Read `Docs/Codex/FEATURE_IMPLEMENTATION_TRACKING/Agent_Reverse_Tunneling.md` and
- [x] Implement channel framing, flow control, heartbeats, close semantics.
- [x] Logging: `Engine/Logs/reverse_tunnel.log`; audit into Device Activity (session start/stop, operator id, agent id, tunnel_id, port).
- [x] WebUI operator bridge endpoint (WebSocket) that maps browser sessions to agent channels.
- [x] Idle/grace sweeper + heartbeat wiring for tunnel sockets.
- [x] TLS-aware per-port listener and agent CONNECT_ACK handling.
- [ ] Agent tunnel role
- [ ] Add `Data/Agent/Roles/role_ReverseTunnel.py` (manages tunnel socket, reconnect, heartbeats, channel dispatch).
- [ ] Per-protocol submodules under `Data/Agent/Roles/ReverseTunnel/` (first: `tunnel_Powershell.py`).
- [ ] Enforce per-domain concurrency (one PowerShell; prevent multiple RDP/VNC/WebRTC; allow extensible policies).
- [ ] Logging: `Agent/Logs/reverse_tunnel.log`; include tunnel_id/channel_id.
- [ ] Integrate token validation, TLS reuse, idle teardown, and graceful stop_all.
- [x] Idle/grace sweeper + heartbeat wiring for tunnel sockets.
- [x] TLS-aware per-port listener and agent CONNECT_ACK handling.
- [x] Agent tunnel role
- [x] Add `Data/Agent/Roles/role_ReverseTunnel.py` (manages tunnel socket, reconnect, heartbeats, channel dispatch).
- [x] Per-protocol submodules under `Data/Agent/Roles/ReverseTunnel/` (first: `tunnel_Powershell.py`).
- [x] Enforce per-domain concurrency (one PowerShell; prevent multiple RDP/VNC/WebRTC; allow extensible policies).
- [x] Logging: `Agent/Logs/reverse_tunnel.log`; include tunnel_id/channel_id.
- [x] Integrate token validation, TLS reuse, idle teardown, and graceful stop_all.
- [ ] PowerShell v1 (feature target)
- [ ] Engine side `Data/Engine/services/WebSocket/Agent/ReverseTunnel/Powershell.py` (channel server, resize handling, translate browser events).
- [ ] Agent side `Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py` using ConPTY/pywinpty; map stdin/stdout to frames; handle resize and exit codes.
- [x] Engine side `Data/Engine/services/WebSocket/Agent/ReverseTunnel/Powershell.py` (channel server, resize handling, translate browser events).
- [x] Agent side `Data/Agent/Roles/ReverseTunnel/tunnel_Powershell.py` using ConPTY/pywinpty; map stdin/stdout to frames; handle resize and exit codes.
- [ ] WebUI: `Data/Engine/web-interface/src/ReverseTunnel/Powershell.jsx` with terminal UI, syntax highlighting matching `Assemblies/Assembly_Editor.jsx`, copy support, status toasts.
- [ ] Device Activity entries and UI surface in `Devices/Device_List.jsx` Device Activity tab.
- [ ] Credits & attribution
@@ -244,6 +244,8 @@ Read `Docs/Codex/FEATURE_IMPLEMENTATION_TRACKING/Agent_Reverse_Tunneling.md` and
- 2025-11-30: Added WebUI-facing Socket.IO namespace `/tunnel` with join/send/poll events that map browser sessions to tunnel bridges, using base64-encoded frames and operator auth from session/cookies.
- 2025-11-30: Enabled async WebSocket listener per assigned port (TLS-aware via Engine certs) for agent CONNECT frames, with frame routing between agent socket and browser bridge queues; Engine tunnel service checklist marked complete.
- 2025-11-30: Added idle/grace sweeper, CONNECT_ACK to agents, heartbeat loop, and token-touched operator sends; per-port listener now runs on dedicated loop/thread. (Original instructions didnt call out sweeper/heartbeat wiring explicitly.)
- 2025-12-01: Added Agent reverse tunnel role (`Data/Agent/Roles/role_ReverseTunnel.py`) with TLS-aware WebSocket dialer, token validation against signed leases, domain-limit guard, heartbeat/idle watchdogs, and reverse_tunnel.log status emits; protocol handlers remain stubbed until PowerShell module lands.
- 2025-12-01: Implemented Agent PowerShell channel (pywinpty ConPTY stdin/stdout piping, resize, exit-close) and Engine PowerShell handler with Socket.IO helpers (`ps_open`/`ps_send`/`ps_resize`/`ps_poll`); added ps channel logging and domain-aware attach. WebUI remains pending.
## Engine Tunnel Service Architecture
@@ -288,7 +290,7 @@ sequenceDiagram
```
## Future Changes in Generation 2
These items are out of scope for the current milestone but should be considered for a production-ready generation after minimum functionality is achieved in the early stages of development.
These items are out of scope for the current milestone but should be considered for a production-ready generation after minimum functionality is achieved in the early stages of development. This section is a place to note things that were not implemented in Generation 1, but should be added in future iterations of the Reverse Tunneling system.
- Harden operator auth/authorization: enforce per-operator session binding, ownership checks, audited attach/detach, and offer a pure WebSocket `/ws/tunnel/<tunnel_id>` bridge.
- Replace Socket.IO browser bridge with a dedicated binary WebSocket bridge for higher throughput and simpler framing.