Major Progress Towards Interactive Remote Powershell

This commit is contained in:
2025-12-06 00:27:57 -07:00
parent 52e40c3753
commit 68dd46347b
9 changed files with 1247 additions and 53 deletions

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import os
import sys
import subprocess
from typing import Any, Dict, Optional
# Message types mirrored from the tunnel framing (kept local to avoid import cycles).
@@ -30,6 +31,7 @@ class PowershellChannel:
self._writer_task = None
self._stdin_queue: asyncio.Queue = asyncio.Queue()
self._pty = None
self._proc: Optional[asyncio.subprocess.Process] = None
self._exit_code: Optional[int] = None
self._frame_cls = getattr(role, "_frame_cls", None)
@@ -62,12 +64,11 @@ class PowershellChannel:
)
await self._send_frame(frame)
def _powershell_path(self) -> str:
def _powershell_argv(self) -> list:
preferred = self.metadata.get("shell") if isinstance(self.metadata, dict) else None
if isinstance(preferred, str) and preferred.strip():
return preferred.strip()
# Default to Windows PowerShell; fallback to pwsh if provided later.
return "powershell.exe"
shell = preferred.strip() if isinstance(preferred, str) and preferred.strip() else "powershell.exe"
# Keep the process alive and read commands from stdin; -Command - tells PS to consume stdin.
return [shell, "-NoLogo", "-NoProfile", "-NoExit", "-Command", "-"]
def _initial_size(self) -> tuple:
cols = int(self.metadata.get("cols") or self.metadata.get("columns") or 120) if isinstance(self.metadata, dict) else 120
@@ -79,30 +80,46 @@ class PowershellChannel:
# ------------------------------------------------------------------ Lifecycle
async def start(self) -> None:
if sys.platform.lower().startswith("win") is False:
self.role._log("reverse_tunnel ps start aborted: non-windows platform", error=True)
await self._send_close(CLOSE_PROTOCOL_ERROR, "windows_only")
return
argv = self._powershell_argv()
cols, rows = self._initial_size()
self.role._log(f"reverse_tunnel ps start channel={self.channel_id} argv={' '.join(argv)} cols={cols} rows={rows}")
# Preferred: ConPTY via pywinpty.
try:
import pywinpty # type: ignore
except Exception as exc: # pragma: no cover - dependency guard
self.role._log(f"reverse_tunnel ps channel missing pywinpty: {exc}", error=True)
await self._send_close(CLOSE_PROTOCOL_ERROR, "pywinpty_missing")
return
shell = self._powershell_path()
cols, rows = self._initial_size()
try:
self._pty = pywinpty.Process(
spawn_cmd=shell,
spawn_cmd=" ".join(argv[:-1]) if argv[-1] == "-" else " ".join(argv),
dimensions=(cols, rows),
)
self._reader_task = self.loop.create_task(self._pump_pty_stdout())
self._writer_task = self.loop.create_task(self._pump_pty_stdin())
self.role._log(f"reverse_tunnel ps channel started (pty) argv={' '.join(argv)} cols={cols} rows={rows}")
return
except Exception as exc:
self.role._log(f"reverse_tunnel ps channel failed to spawn {shell}: {exc}", error=True)
self.role._log(f"reverse_tunnel ps channel pywinpty unavailable, falling back to pipes: {exc}", error=True)
# Fallback: subprocess pipes (no PTY).
try:
self._proc = await asyncio.create_subprocess_exec(
*argv,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
)
except Exception as exc:
self.role._log(f"reverse_tunnel ps channel spawn failed argv={' '.join(argv)}: {exc}", error=True)
await self._send_close(CLOSE_PROTOCOL_ERROR, "spawn_failed")
return
self._reader_task = self.loop.create_task(self._pump_stdout())
self._writer_task = self.loop.create_task(self._pump_stdin())
self.role._log(f"reverse_tunnel ps channel started shell={shell} cols={cols} rows={rows}")
self._reader_task = self.loop.create_task(self._pump_proc_stdout())
self._writer_task = self.loop.create_task(self._pump_proc_stdin())
self.role._log(f"reverse_tunnel ps channel started (pipes) argv={' '.join(argv)} cols={cols} rows={rows}")
async def on_frame(self, frame) -> None:
if self._closed:
@@ -140,6 +157,7 @@ class PowershellChannel:
await self._resize(cols_int, rows_int)
async def _resize(self, cols: Optional[int], rows: Optional[int]) -> None:
# Resize only applies to PTY sessions; pipe mode ignores.
if self._pty is None:
return
try:
@@ -155,17 +173,14 @@ class PowershellChannel:
except Exception:
self.role._log("reverse_tunnel ps channel resize failed", error=True)
async def _pump_stdout(self) -> None:
async def _pump_pty_stdout(self) -> None:
loop = asyncio.get_event_loop()
try:
while not self._closed and self._pty:
chunk = await loop.run_in_executor(None, self._pty.read, 4096)
if chunk is None:
break
if isinstance(chunk, str):
data = chunk.encode("utf-8", errors="replace")
else:
data = bytes(chunk)
data = chunk.encode("utf-8", errors="replace") if isinstance(chunk, str) else bytes(chunk)
if not data:
break
frame = self._make_frame(MSG_DATA, payload=data)
@@ -173,11 +188,11 @@ class PowershellChannel:
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps stdout pump error", error=True)
self.role._log("reverse_tunnel ps pty stdout pump error", error=True)
finally:
await self.stop(reason="stdout_closed")
async def _pump_stdin(self) -> None:
async def _pump_pty_stdin(self) -> None:
loop = asyncio.get_event_loop()
try:
while not self._closed and self._pty:
@@ -187,10 +202,7 @@ class PowershellChannel:
break
if data is None:
break
if isinstance(data, (bytes, bytearray)):
text = data.decode("utf-8", errors="replace")
else:
text = str(data)
text = data.decode("utf-8", errors="replace") if isinstance(data, (bytes, bytearray)) else str(data)
try:
await loop.run_in_executor(None, self._pty.write, text)
except Exception:
@@ -198,7 +210,47 @@ class PowershellChannel:
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps stdin pump error", error=True)
self.role._log("reverse_tunnel ps pty stdin pump error", error=True)
finally:
await self.stop(reason="stdin_closed")
# -------------------- Pipe fallback pumps --------------------
async def _pump_proc_stdout(self) -> None:
try:
while self._proc and not self._closed:
chunk = await self._proc.stdout.read(4096)
if not chunk:
break
frame = self._make_frame(MSG_DATA, payload=bytes(chunk))
await self._send_frame(frame)
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps pipe stdout pump error", error=True)
finally:
if self._proc and not self._closed:
try:
self._exit_code = await self._proc.wait()
except Exception:
pass
await self.stop(reason="stdout_closed")
async def _pump_proc_stdin(self) -> None:
try:
while self._proc and not self._closed:
data = await self._stdin_queue.get()
if self._closed or not self._proc or not self._proc.stdin:
break
try:
self._proc.stdin.write(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
await self._proc.stdin.drain()
except Exception:
self.role._log("reverse_tunnel ps pipe stdin pump error", error=True)
break
except asyncio.CancelledError:
pass
except Exception:
self.role._log("reverse_tunnel ps pipe stdin pump error", error=True)
finally:
await self.stop(reason="stdin_closed")
@@ -211,6 +263,11 @@ class PowershellChannel:
self._pty.terminate()
except Exception:
pass
if self._proc is not None:
try:
self._proc.terminate()
except Exception:
pass
current = asyncio.current_task()
if self._reader_task and self._reader_task is not current:
try:
@@ -222,5 +279,7 @@ class PowershellChannel:
self._writer_task.cancel()
except Exception:
pass
await self._send_close(code, reason or "powershell_exit")
# Include exit code in the close reason for debugging.
exit_suffix = f" (exit={self._exit_code})" if self._exit_code is not None else ""
await self._send_close(code, (reason or "powershell_exit") + exit_suffix)
self.role._log(f"reverse_tunnel ps channel stopped channel={self.channel_id} reason={reason or 'exit'}")

View File

@@ -1,10 +1,12 @@
import asyncio
import base64
import importlib.util
import json
import os
import struct
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import urlparse
@@ -12,10 +14,25 @@ import aiohttp
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ed25519
# Capture import errors for the PowerShell handler so we can report why it is missing.
PS_IMPORT_ERROR: Optional[str] = None
tunnel_Powershell = None
try:
from .ReverseTunnel import tunnel_Powershell
except Exception:
tunnel_Powershell = None
from .ReverseTunnel import tunnel_Powershell # type: ignore
except Exception as exc: # pragma: no cover - best-effort logging only
PS_IMPORT_ERROR = repr(exc)
# Try manual import from file to survive non-package execution.
try:
_ps_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
if _ps_path.exists():
spec = importlib.util.spec_from_file_location("tunnel_Powershell", _ps_path)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
tunnel_Powershell = module
PS_IMPORT_ERROR = None
except Exception as exc2: # pragma: no cover - diagnostic only
PS_IMPORT_ERROR = f"{PS_IMPORT_ERROR} | fallback_load_failed={exc2!r}"
ROLE_NAME = "reverse_tunnel"
ROLE_CONTEXTS = ["interactive", "system"]
@@ -178,8 +195,21 @@ class Role:
self._protocol_handlers: Dict[str, Any] = {}
self._frame_cls = TunnelFrame
self.close_frame = close_frame
if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"):
self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel
try:
if tunnel_Powershell and hasattr(tunnel_Powershell, "PowershellChannel"):
self._protocol_handlers["ps"] = tunnel_Powershell.PowershellChannel
module_path = getattr(tunnel_Powershell, "__file__", None)
self._log(f"reverse_tunnel ps handler registered (PowershellChannel) module={module_path}")
else:
hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
exists_hint = f" exists={module_path.exists()}"
self._log(
f"reverse_tunnel ps handler NOT registered (missing module/class){hint}{exists_hint}",
error=True,
)
except Exception as exc:
self._log(f"reverse_tunnel ps handler registration failed: {exc}", error=True)
# ------------------------------------------------------------------ Logging
def _log(self, message: str, *, error: bool = False) -> None:
@@ -359,6 +389,11 @@ class Role:
idle_seconds = int(payload.get("idle_seconds") or 3600)
grace_seconds = int(payload.get("grace_seconds") or 3600)
signing_key_hint = _norm_text(payload.get("signing_key"))
agent_hint = _norm_text(payload.get("agent_id"))
# Ignore broadcasts targeting other agents (Socket.IO fanout sends to both contexts).
if agent_hint and agent_hint.lower() != _norm_text(self.ctx.agent_id).lower():
return
if not token:
self._log("reverse_tunnel_start rejected: missing token", error=True)
@@ -378,6 +413,9 @@ class Role:
signing_key_hint=signing_key_hint,
)
except Exception as exc:
if str(exc) == "token_agent_mismatch":
# Broadcast hit the wrong agent context; ignore quietly.
return
self._log(f"reverse_tunnel_start rejected: token validation failed ({exc})", error=True)
await self._emit_status({"tunnel_id": tunnel_id, "agent_id": self.ctx.agent_id, "status": "error", "reason": "token_invalid"})
return
@@ -439,6 +477,10 @@ class Role:
async def _run_tunnel(self, tunnel: ActiveTunnel, *, host: str) -> None:
ssl_ctx = self._ssl_context(host)
timeout = aiohttp.ClientTimeout(total=None, sock_connect=10, sock_read=None)
self._log(
f"reverse_tunnel dialing ws url={tunnel.url} tunnel_id={tunnel.tunnel_id} "
f"agent_id={self.ctx.agent_id} ssl={'yes' if ssl_ctx else 'no'}"
)
try:
tunnel.session = aiohttp.ClientSession(timeout=timeout)
tunnel.websocket = await tunnel.session.ws_connect(
@@ -449,6 +491,7 @@ class Role:
timeout=timeout,
)
self._mark_activity(tunnel)
self._log(f"reverse_tunnel connected ws tunnel_id={tunnel.tunnel_id} peer={getattr(tunnel.websocket, 'remote', None)}")
await tunnel.websocket.send_bytes(
TunnelFrame(
msg_type=MSG_CONNECT,
@@ -466,6 +509,7 @@ class Role:
).encode("utf-8"),
).encode()
)
self._log(f"reverse_tunnel CONNECT sent tunnel_id={tunnel.tunnel_id}")
sender = self.loop.create_task(self._pump_sender(tunnel))
receiver = self.loop.create_task(self._pump_receiver(tunnel))
@@ -486,12 +530,18 @@ class Role:
try:
await tunnel.websocket.send_bytes(frame.encode())
self._mark_activity(tunnel)
self._log(
f"reverse_tunnel send frame tunnel_id={tunnel.tunnel_id} "
f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
)
except Exception:
break
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel sender failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel sender stopped tunnel_id={tunnel.tunnel_id}")
async def _pump_receiver(self, tunnel: ActiveTunnel) -> None:
ws = tunnel.websocket
@@ -508,21 +558,30 @@ class Role:
self._mark_activity(tunnel)
await self._handle_frame(tunnel, frame)
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSE):
self._log(
f"reverse_tunnel websocket closed tunnel_id={tunnel.tunnel_id} "
f"code={ws.close_code} reason={ws.close_reason}"
)
break
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel receiver failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel receiver stopped tunnel_id={tunnel.tunnel_id}")
async def _heartbeat_loop(self, tunnel: ActiveTunnel) -> None:
try:
while tunnel.websocket and not tunnel.websocket.closed:
await asyncio.sleep(tunnel.heartbeat_seconds)
await self._send_frame(tunnel, heartbeat_frame())
self._log(f"reverse_tunnel heartbeat sent tunnel_id={tunnel.tunnel_id}")
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel heartbeat failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel heartbeat loop stopped tunnel_id={tunnel.tunnel_id}")
async def _watchdog(self, tunnel: ActiveTunnel) -> None:
try:
@@ -531,24 +590,34 @@ class Role:
now = time.time()
if tunnel.idle_seconds and (now - tunnel.last_activity) >= tunnel.idle_seconds:
await self._send_frame(tunnel, close_frame(0, CLOSE_IDLE_TIMEOUT, "idle_timeout"))
self._log(f"reverse_tunnel watchdog idle_timeout tunnel_id={tunnel.tunnel_id}")
break
if tunnel.expires_at and (now - tunnel.expires_at) >= tunnel.grace_seconds:
await self._send_frame(tunnel, close_frame(0, CLOSE_GRACE_EXPIRED, "grace_expired"))
self._log(f"reverse_tunnel watchdog grace_expired tunnel_id={tunnel.tunnel_id}")
break
except asyncio.CancelledError:
pass
except Exception:
self._log(f"reverse_tunnel watchdog failed tunnel_id={tunnel.tunnel_id}", error=True)
finally:
self._log(f"reverse_tunnel watchdog stopped tunnel_id={tunnel.tunnel_id}")
async def _handle_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
self._log(
f"reverse_tunnel recv frame tunnel_id={tunnel.tunnel_id} "
f"msg_type={frame.msg_type} channel={frame.channel_id} len={len(frame.payload or b'')}"
)
if frame.msg_type == MSG_HEARTBEAT:
if frame.flags & 0x1:
self._log(f"reverse_tunnel heartbeat ack tunnel_id={tunnel.tunnel_id}")
return
await self._send_frame(tunnel, heartbeat_frame(channel_id=frame.channel_id, is_ack=True))
return
if frame.msg_type == MSG_CONNECT_ACK:
tunnel.connected = True
await self._emit_status({"tunnel_id": tunnel.tunnel_id, "agent_id": self.ctx.agent_id, "status": "connected"})
self._log(f"reverse_tunnel CONNECT_ACK tunnel_id={tunnel.tunnel_id}")
return
if frame.msg_type == MSG_CHANNEL_OPEN:
await self._handle_channel_open(tunnel, frame)
@@ -584,6 +653,15 @@ class Role:
await self._send_frame(tunnel, close_frame(frame.channel_id, CLOSE_PROTOCOL_ERROR, "channel_exists"))
return
if protocol.lower() == "ps" and "ps" not in self._protocol_handlers:
hint = f" import_error={PS_IMPORT_ERROR}" if PS_IMPORT_ERROR else ""
module_path = Path(__file__).parent / "ReverseTunnel" / "tunnel_Powershell.py"
exists_hint = f" exists={module_path.exists()}"
self._log(
f"reverse_tunnel ps handler missing; falling back to BaseChannel{hint}{exists_hint}",
error=True,
)
handler_cls = self._protocol_handlers.get(protocol.lower()) or BaseChannel
try:
handler = handler_cls(self, tunnel, frame.channel_id, metadata)
@@ -599,6 +677,10 @@ class Role:
payload=json.dumps({"status": "ok", "protocol": protocol}, separators=(",", ":")).encode("utf-8"),
),
)
self._log(
f"reverse_tunnel channel_opened tunnel_id={tunnel.tunnel_id} channel={frame.channel_id} "
f"protocol={protocol} handler={handler.__class__.__name__} metadata={metadata}"
)
async def _send_frame(self, tunnel: ActiveTunnel, frame: TunnelFrame) -> None:
if tunnel.stopping: