Agent Reverse Tunneling - Agent Role Implementation

This commit is contained in:
2025-12-01 02:40:23 -07:00
parent 478237f487
commit fe4511ecaf
8 changed files with 1155 additions and 11 deletions

View File

@@ -29,6 +29,8 @@ from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple
from collections import deque
from threading import Thread
from .ReverseTunnel.Powershell import PowershellChannelServer
try: # websockets is added to engine requirements
import websockets
from websockets.server import serve as ws_serve
@@ -447,6 +449,7 @@ class ReverseTunnelService:
self._bridges: Dict[str, "TunnelBridge"] = {}
self._port_servers: Dict[int, asyncio.AbstractServer] = {}
self._agent_sockets: Dict[str, "websockets.WebSocketServerProtocol"] = {}
self._ps_servers: Dict[str, PowershellChannelServer] = {}
def _ensure_loop(self) -> None:
if self._running and self._loop:
@@ -548,6 +551,13 @@ class ReverseTunnelService:
raise ValueError("unknown_tunnel")
bridge = self.ensure_bridge(lease)
bridge.attach_operator(operator_id)
if lease.domain.lower() == "ps":
try:
server = self.ensure_ps_server(tunnel_id)
if server:
server.open_channel()
except Exception:
self.logger.debug("ps server open failed tunnel_id=%s", tunnel_id, exc_info=True)
return bridge
def _encode_token(self, payload: Dict[str, object]) -> str:
@@ -841,6 +851,15 @@ class ReverseTunnelService:
except Exception:
pass
def _dispatch_agent_frame(self, tunnel_id: str, frame: TunnelFrame) -> None:
server = self._ps_servers.get(tunnel_id)
if not server:
return
try:
server.handle_agent_frame(frame)
except Exception:
self.logger.debug("ps handler error for tunnel_id=%s", tunnel_id, exc_info=True)
def _start_lease_sweeper(self) -> None:
async def _sweeper():
while self._running and self._loop and not self._loop.is_closed():
@@ -928,6 +947,10 @@ class ReverseTunnelService:
except Exception:
continue
self.lease_manager.touch(tunnel_id)
try:
self._dispatch_agent_frame(tunnel_id, recv_frame)
except Exception:
pass
bridge.agent_to_operator(recv_frame)
async def _pump_to_agent():
while not websocket.closed:
@@ -969,10 +992,30 @@ class ReverseTunnelService:
self._bridges[lease.tunnel_id] = bridge
return bridge
def ensure_ps_server(self, tunnel_id: str) -> Optional[PowershellChannelServer]:
server = self._ps_servers.get(tunnel_id)
if server:
return server
lease = self.lease_manager.get(tunnel_id)
if lease is None or (lease.domain or "").lower() != "ps":
return None
bridge = self.ensure_bridge(lease)
server = PowershellChannelServer(bridge=bridge, service=self)
self._ps_servers[tunnel_id] = server
return server
def get_ps_server(self, tunnel_id: str) -> Optional[PowershellChannelServer]:
return self._ps_servers.get(tunnel_id)
def release_bridge(self, tunnel_id: str, *, reason: str = "bridge_released") -> None:
bridge = self._bridges.pop(tunnel_id, None)
if bridge:
bridge.stop(reason=reason)
if tunnel_id in self._ps_servers:
try:
self._ps_servers.pop(tunnel_id, None)
except Exception:
pass
class TunnelBridge:

View File

@@ -0,0 +1,130 @@
"""Engine-side PowerShell tunnel channel helper."""
from __future__ import annotations
import json
from collections import deque
from typing import Any, Deque, Dict, List, Optional
from ..ReverseTunnel import (
CLOSE_AGENT_SHUTDOWN,
CLOSE_OK,
CLOSE_PROTOCOL_ERROR,
MSG_CHANNEL_ACK,
MSG_CHANNEL_OPEN,
MSG_CLOSE,
MSG_CONTROL,
MSG_DATA,
TunnelFrame,
close_frame,
)
class PowershellChannelServer:
"""Coordinate PowerShell channel frames over a TunnelBridge."""
def __init__(self, bridge, service, *, channel_id: int = 1):
self.bridge = bridge
self.service = service
self.channel_id = channel_id
self.logger = service.logger.getChild(f"ps.{bridge.lease.tunnel_id}")
self._open_sent = False
self._ack_received = False
self._closed = False
self._output: Deque[str] = deque()
self._close_reason: Optional[str] = None
self._close_code: Optional[int] = None
# ------------------------------------------------------------------ Agent frame handling
def handle_agent_frame(self, frame: TunnelFrame) -> None:
if frame.channel_id != self.channel_id:
return
if frame.msg_type == MSG_CHANNEL_ACK:
self._ack_received = True
self.logger.info("ps channel acked tunnel_id=%s", self.bridge.lease.tunnel_id)
return
if frame.msg_type == MSG_DATA:
try:
text = frame.payload.decode("utf-8", errors="replace")
except Exception:
text = ""
if text:
self._append_output(text)
return
if frame.msg_type == MSG_CLOSE:
try:
payload = json.loads(frame.payload.decode("utf-8"))
except Exception:
payload = {}
self._closed = True
self._close_code = payload.get("code") if isinstance(payload, dict) else None
self._close_reason = payload.get("reason") if isinstance(payload, dict) else None
self.logger.info(
"ps channel closed tunnel_id=%s code=%s reason=%s",
self.bridge.lease.tunnel_id,
self._close_code,
self._close_reason or "-",
)
# ------------------------------------------------------------------ Operator actions
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
if self._open_sent:
return
payload = json.dumps(
{"protocol": "ps", "metadata": {"cols": cols, "rows": rows}},
separators=(",", ":"),
).encode("utf-8")
frame = TunnelFrame(msg_type=MSG_CHANNEL_OPEN, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
self._open_sent = True
self.logger.info(
"ps channel open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
self.bridge.lease.tunnel_id,
self.channel_id,
cols,
rows,
)
def send_input(self, data: str) -> None:
if self._closed:
return
payload = data.encode("utf-8", errors="replace")
frame = TunnelFrame(msg_type=MSG_DATA, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
def send_resize(self, cols: int, rows: int) -> None:
if self._closed:
return
payload = json.dumps({"cols": cols, "rows": rows}, separators=(",", ":")).encode("utf-8")
frame = TunnelFrame(msg_type=MSG_CONTROL, channel_id=self.channel_id, payload=payload)
self.bridge.operator_to_agent(frame)
def close(self, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "operator_close") -> None:
if self._closed:
return
self._closed = True
self.bridge.operator_to_agent(close_frame(self.channel_id, code, reason))
# ------------------------------------------------------------------ Output polling
def drain_output(self) -> List[str]:
items: List[str] = []
while self._output:
items.append(self._output.popleft())
return items
def _append_output(self, text: str) -> None:
self._output.append(text)
# Cap buffer to avoid unbounded memory growth.
while len(self._output) > 500:
self._output.popleft()
# ------------------------------------------------------------------ Status helpers
def status(self) -> Dict[str, Any]:
return {
"channel_id": self.channel_id,
"open_sent": self._open_sent,
"ack": self._ack_received,
"closed": self._closed,
"close_reason": self._close_reason,
"close_code": self._close_code,
}

View File

@@ -0,0 +1,2 @@
"""Protocol-specific helpers for Reverse Tunnel (Engine side)."""

View File

@@ -401,6 +401,91 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
frames.append(_encode_frame(frame))
return {"frames": frames}
def _require_ps_server():
sid = request.sid
tunnel_id = _operator_sessions.get(sid)
if not tunnel_id:
return None, None, {"error": "not_joined"}
server = tunnel_service.ensure_ps_server(tunnel_id)
if server is None:
return None, tunnel_id, {"error": "ps_unsupported"}
return server, tunnel_id, None
@socket_server.on("ps_open", namespace=tunnel_namespace)
def _ws_ps_open(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = 120
rows = 32
if isinstance(data, dict):
try:
cols = int(data.get("cols", cols))
rows = int(data.get("rows", rows))
except Exception:
pass
cols = max(20, min(cols, 300))
rows = max(10, min(rows, 200))
try:
server.open_channel(cols=cols, rows=rows)
except Exception as exc:
logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_open_failed"}
return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}
@socket_server.on("ps_send", namespace=tunnel_namespace)
def _ws_ps_send(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
if data is None:
return {"error": "payload_required"}
text = data
if isinstance(data, dict):
text = data.get("data")
if text is None:
return {"error": "payload_required"}
try:
server.send_input(str(text))
except Exception as exc:
logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_send_failed"}
return {"status": "ok"}
@socket_server.on("ps_resize", namespace=tunnel_namespace)
def _ws_ps_resize(data: Any) -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
cols = None
rows = None
if isinstance(data, dict):
cols = data.get("cols")
rows = data.get("rows")
try:
cols_int = int(cols) if cols is not None else 120
rows_int = int(rows) if rows is not None else 32
cols_int = max(20, min(cols_int, 300))
rows_int = max(10, min(rows_int, 200))
server.send_resize(cols_int, rows_int)
return {"status": "ok", "cols": cols_int, "rows": rows_int}
except Exception as exc:
logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_resize_failed"}
@socket_server.on("ps_poll", namespace=tunnel_namespace)
def _ws_ps_poll() -> Any:
server, tunnel_id, error = _require_ps_server()
if server is None:
return error
try:
output = server.drain_output()
status = server.status()
return {"output": output, "status": status}
except Exception as exc:
logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
return {"error": "ps_poll_failed"}
@socket_server.on("disconnect", namespace=tunnel_namespace)
def _ws_tunnel_disconnect():
sid = request.sid