mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-16 03:25:48 -07:00
Agent Reverse Tunneling - Agent Role Implementation
This commit is contained in:
@@ -29,6 +29,8 @@ from typing import Callable, Deque, Dict, Iterable, List, Optional, Tuple
|
||||
from collections import deque
|
||||
from threading import Thread
|
||||
|
||||
from .ReverseTunnel.Powershell import PowershellChannelServer
|
||||
|
||||
try: # websockets is added to engine requirements
|
||||
import websockets
|
||||
from websockets.server import serve as ws_serve
|
||||
@@ -447,6 +449,7 @@ class ReverseTunnelService:
|
||||
self._bridges: Dict[str, "TunnelBridge"] = {}
|
||||
self._port_servers: Dict[int, asyncio.AbstractServer] = {}
|
||||
self._agent_sockets: Dict[str, "websockets.WebSocketServerProtocol"] = {}
|
||||
self._ps_servers: Dict[str, PowershellChannelServer] = {}
|
||||
|
||||
def _ensure_loop(self) -> None:
|
||||
if self._running and self._loop:
|
||||
@@ -548,6 +551,13 @@ class ReverseTunnelService:
|
||||
raise ValueError("unknown_tunnel")
|
||||
bridge = self.ensure_bridge(lease)
|
||||
bridge.attach_operator(operator_id)
|
||||
if lease.domain.lower() == "ps":
|
||||
try:
|
||||
server = self.ensure_ps_server(tunnel_id)
|
||||
if server:
|
||||
server.open_channel()
|
||||
except Exception:
|
||||
self.logger.debug("ps server open failed tunnel_id=%s", tunnel_id, exc_info=True)
|
||||
return bridge
|
||||
|
||||
def _encode_token(self, payload: Dict[str, object]) -> str:
|
||||
@@ -841,6 +851,15 @@ class ReverseTunnelService:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _dispatch_agent_frame(self, tunnel_id: str, frame: TunnelFrame) -> None:
|
||||
server = self._ps_servers.get(tunnel_id)
|
||||
if not server:
|
||||
return
|
||||
try:
|
||||
server.handle_agent_frame(frame)
|
||||
except Exception:
|
||||
self.logger.debug("ps handler error for tunnel_id=%s", tunnel_id, exc_info=True)
|
||||
|
||||
def _start_lease_sweeper(self) -> None:
|
||||
async def _sweeper():
|
||||
while self._running and self._loop and not self._loop.is_closed():
|
||||
@@ -928,6 +947,10 @@ class ReverseTunnelService:
|
||||
except Exception:
|
||||
continue
|
||||
self.lease_manager.touch(tunnel_id)
|
||||
try:
|
||||
self._dispatch_agent_frame(tunnel_id, recv_frame)
|
||||
except Exception:
|
||||
pass
|
||||
bridge.agent_to_operator(recv_frame)
|
||||
async def _pump_to_agent():
|
||||
while not websocket.closed:
|
||||
@@ -969,10 +992,30 @@ class ReverseTunnelService:
|
||||
self._bridges[lease.tunnel_id] = bridge
|
||||
return bridge
|
||||
|
||||
def ensure_ps_server(self, tunnel_id: str) -> Optional[PowershellChannelServer]:
|
||||
server = self._ps_servers.get(tunnel_id)
|
||||
if server:
|
||||
return server
|
||||
lease = self.lease_manager.get(tunnel_id)
|
||||
if lease is None or (lease.domain or "").lower() != "ps":
|
||||
return None
|
||||
bridge = self.ensure_bridge(lease)
|
||||
server = PowershellChannelServer(bridge=bridge, service=self)
|
||||
self._ps_servers[tunnel_id] = server
|
||||
return server
|
||||
|
||||
def get_ps_server(self, tunnel_id: str) -> Optional[PowershellChannelServer]:
|
||||
return self._ps_servers.get(tunnel_id)
|
||||
|
||||
def release_bridge(self, tunnel_id: str, *, reason: str = "bridge_released") -> None:
|
||||
bridge = self._bridges.pop(tunnel_id, None)
|
||||
if bridge:
|
||||
bridge.stop(reason=reason)
|
||||
if tunnel_id in self._ps_servers:
|
||||
try:
|
||||
self._ps_servers.pop(tunnel_id, None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class TunnelBridge:
|
||||
|
||||
130
Data/Engine/services/WebSocket/Agent/ReverseTunnel/Powershell.py
Normal file
130
Data/Engine/services/WebSocket/Agent/ReverseTunnel/Powershell.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Engine-side PowerShell tunnel channel helper."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional
|
||||
|
||||
from ..ReverseTunnel import (
|
||||
CLOSE_AGENT_SHUTDOWN,
|
||||
CLOSE_OK,
|
||||
CLOSE_PROTOCOL_ERROR,
|
||||
MSG_CHANNEL_ACK,
|
||||
MSG_CHANNEL_OPEN,
|
||||
MSG_CLOSE,
|
||||
MSG_CONTROL,
|
||||
MSG_DATA,
|
||||
TunnelFrame,
|
||||
close_frame,
|
||||
)
|
||||
|
||||
|
||||
class PowershellChannelServer:
|
||||
"""Coordinate PowerShell channel frames over a TunnelBridge."""
|
||||
|
||||
def __init__(self, bridge, service, *, channel_id: int = 1):
|
||||
self.bridge = bridge
|
||||
self.service = service
|
||||
self.channel_id = channel_id
|
||||
self.logger = service.logger.getChild(f"ps.{bridge.lease.tunnel_id}")
|
||||
self._open_sent = False
|
||||
self._ack_received = False
|
||||
self._closed = False
|
||||
self._output: Deque[str] = deque()
|
||||
self._close_reason: Optional[str] = None
|
||||
self._close_code: Optional[int] = None
|
||||
|
||||
# ------------------------------------------------------------------ Agent frame handling
|
||||
def handle_agent_frame(self, frame: TunnelFrame) -> None:
|
||||
if frame.channel_id != self.channel_id:
|
||||
return
|
||||
if frame.msg_type == MSG_CHANNEL_ACK:
|
||||
self._ack_received = True
|
||||
self.logger.info("ps channel acked tunnel_id=%s", self.bridge.lease.tunnel_id)
|
||||
return
|
||||
if frame.msg_type == MSG_DATA:
|
||||
try:
|
||||
text = frame.payload.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
text = ""
|
||||
if text:
|
||||
self._append_output(text)
|
||||
return
|
||||
if frame.msg_type == MSG_CLOSE:
|
||||
try:
|
||||
payload = json.loads(frame.payload.decode("utf-8"))
|
||||
except Exception:
|
||||
payload = {}
|
||||
self._closed = True
|
||||
self._close_code = payload.get("code") if isinstance(payload, dict) else None
|
||||
self._close_reason = payload.get("reason") if isinstance(payload, dict) else None
|
||||
self.logger.info(
|
||||
"ps channel closed tunnel_id=%s code=%s reason=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self._close_code,
|
||||
self._close_reason or "-",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ Operator actions
|
||||
def open_channel(self, *, cols: int = 120, rows: int = 32) -> None:
|
||||
if self._open_sent:
|
||||
return
|
||||
payload = json.dumps(
|
||||
{"protocol": "ps", "metadata": {"cols": cols, "rows": rows}},
|
||||
separators=(",", ":"),
|
||||
).encode("utf-8")
|
||||
frame = TunnelFrame(msg_type=MSG_CHANNEL_OPEN, channel_id=self.channel_id, payload=payload)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
self._open_sent = True
|
||||
self.logger.info(
|
||||
"ps channel open sent tunnel_id=%s channel_id=%s cols=%s rows=%s",
|
||||
self.bridge.lease.tunnel_id,
|
||||
self.channel_id,
|
||||
cols,
|
||||
rows,
|
||||
)
|
||||
|
||||
def send_input(self, data: str) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
payload = data.encode("utf-8", errors="replace")
|
||||
frame = TunnelFrame(msg_type=MSG_DATA, channel_id=self.channel_id, payload=payload)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
|
||||
def send_resize(self, cols: int, rows: int) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
payload = json.dumps({"cols": cols, "rows": rows}, separators=(",", ":")).encode("utf-8")
|
||||
frame = TunnelFrame(msg_type=MSG_CONTROL, channel_id=self.channel_id, payload=payload)
|
||||
self.bridge.operator_to_agent(frame)
|
||||
|
||||
def close(self, code: int = CLOSE_AGENT_SHUTDOWN, reason: str = "operator_close") -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
self.bridge.operator_to_agent(close_frame(self.channel_id, code, reason))
|
||||
|
||||
# ------------------------------------------------------------------ Output polling
|
||||
def drain_output(self) -> List[str]:
|
||||
items: List[str] = []
|
||||
while self._output:
|
||||
items.append(self._output.popleft())
|
||||
return items
|
||||
|
||||
def _append_output(self, text: str) -> None:
|
||||
self._output.append(text)
|
||||
# Cap buffer to avoid unbounded memory growth.
|
||||
while len(self._output) > 500:
|
||||
self._output.popleft()
|
||||
|
||||
# ------------------------------------------------------------------ Status helpers
|
||||
def status(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"channel_id": self.channel_id,
|
||||
"open_sent": self._open_sent,
|
||||
"ack": self._ack_received,
|
||||
"closed": self._closed,
|
||||
"close_reason": self._close_reason,
|
||||
"close_code": self._close_code,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
"""Protocol-specific helpers for Reverse Tunnel (Engine side)."""
|
||||
|
||||
@@ -401,6 +401,91 @@ def register_realtime(socket_server: SocketIO, context: EngineContext) -> None:
|
||||
frames.append(_encode_frame(frame))
|
||||
return {"frames": frames}
|
||||
|
||||
def _require_ps_server():
|
||||
sid = request.sid
|
||||
tunnel_id = _operator_sessions.get(sid)
|
||||
if not tunnel_id:
|
||||
return None, None, {"error": "not_joined"}
|
||||
server = tunnel_service.ensure_ps_server(tunnel_id)
|
||||
if server is None:
|
||||
return None, tunnel_id, {"error": "ps_unsupported"}
|
||||
return server, tunnel_id, None
|
||||
|
||||
@socket_server.on("ps_open", namespace=tunnel_namespace)
|
||||
def _ws_ps_open(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
cols = 120
|
||||
rows = 32
|
||||
if isinstance(data, dict):
|
||||
try:
|
||||
cols = int(data.get("cols", cols))
|
||||
rows = int(data.get("rows", rows))
|
||||
except Exception:
|
||||
pass
|
||||
cols = max(20, min(cols, 300))
|
||||
rows = max(10, min(rows, 200))
|
||||
try:
|
||||
server.open_channel(cols=cols, rows=rows)
|
||||
except Exception as exc:
|
||||
logger.debug("ps_open failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_open_failed"}
|
||||
return {"status": "ok", "tunnel_id": tunnel_id, "cols": cols, "rows": rows}
|
||||
|
||||
@socket_server.on("ps_send", namespace=tunnel_namespace)
|
||||
def _ws_ps_send(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
if data is None:
|
||||
return {"error": "payload_required"}
|
||||
text = data
|
||||
if isinstance(data, dict):
|
||||
text = data.get("data")
|
||||
if text is None:
|
||||
return {"error": "payload_required"}
|
||||
try:
|
||||
server.send_input(str(text))
|
||||
except Exception as exc:
|
||||
logger.debug("ps_send failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_send_failed"}
|
||||
return {"status": "ok"}
|
||||
|
||||
@socket_server.on("ps_resize", namespace=tunnel_namespace)
|
||||
def _ws_ps_resize(data: Any) -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
cols = None
|
||||
rows = None
|
||||
if isinstance(data, dict):
|
||||
cols = data.get("cols")
|
||||
rows = data.get("rows")
|
||||
try:
|
||||
cols_int = int(cols) if cols is not None else 120
|
||||
rows_int = int(rows) if rows is not None else 32
|
||||
cols_int = max(20, min(cols_int, 300))
|
||||
rows_int = max(10, min(rows_int, 200))
|
||||
server.send_resize(cols_int, rows_int)
|
||||
return {"status": "ok", "cols": cols_int, "rows": rows_int}
|
||||
except Exception as exc:
|
||||
logger.debug("ps_resize failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_resize_failed"}
|
||||
|
||||
@socket_server.on("ps_poll", namespace=tunnel_namespace)
|
||||
def _ws_ps_poll() -> Any:
|
||||
server, tunnel_id, error = _require_ps_server()
|
||||
if server is None:
|
||||
return error
|
||||
try:
|
||||
output = server.drain_output()
|
||||
status = server.status()
|
||||
return {"output": output, "status": status}
|
||||
except Exception as exc:
|
||||
logger.debug("ps_poll failed tunnel_id=%s: %s", tunnel_id, exc, exc_info=True)
|
||||
return {"error": "ps_poll_failed"}
|
||||
|
||||
@socket_server.on("disconnect", namespace=tunnel_namespace)
|
||||
def _ws_tunnel_disconnect():
|
||||
sid = request.sid
|
||||
|
||||
Reference in New Issue
Block a user