Major Progress Towards Interactive Remote Powershell

This commit is contained in:
2025-12-06 00:27:57 -07:00
parent 52e40c3753
commit 68dd46347b
9 changed files with 1247 additions and 53 deletions

View File

@@ -391,11 +391,13 @@ class TunnelLeaseManager:
def expire_idle(self, *, now_ts: Optional[float] = None) -> List[TunnelLease]:
now = now_ts or _utc_ts()
expired: List[TunnelLease] = []
pending_timeout = min(self.grace_timeout_seconds, 300) # avoid long-lived pending locks
for lease in list(self._leases.values()):
if lease.state == "expired":
continue
idle_age = now - lease.last_activity_ts
pending_age = now - lease.created_at
if lease.state == "active" and idle_age >= lease.idle_timeout_seconds:
lease.mark_expired()
expired.append(lease)
@@ -409,6 +411,14 @@ class TunnelLeaseManager:
expired.append(lease)
self.release(lease.tunnel_id, reason="grace_expired")
continue
if lease.state == "pending":
hard_expiry = lease.expires_at or (lease.created_at + lease.grace_timeout_seconds)
if pending_age >= pending_timeout or (hard_expiry and now >= hard_expiry):
lease.mark_expired()
expired.append(lease)
self.release(lease.tunnel_id, reason="pending_timeout")
continue
return expired
def all_leases(self) -> Iterable[TunnelLease]:
@@ -591,6 +601,7 @@ class ReverseTunnelService:
)
lease.token = self.issue_token(lease)
self._spawn_port_listener(lease.assigned_port)
self._push_start_to_agent(lease)
self.audit_logger.info(
"lease_created tunnel_id=%s agent_id=%s domain=%s protocol=%s port=%s operator=%s",
lease.tunnel_id,
@@ -618,6 +629,35 @@ class ReverseTunnelService:
lease.expires_at = expires_at
return token
def _push_start_to_agent(self, lease: TunnelLease) -> None:
"""Notify the target agent about the new lease over Socket.IO (best-effort)."""
if not self._socketio:
return
payload = {
"tunnel_id": lease.tunnel_id,
"lease_id": lease.tunnel_id,
"agent_id": lease.agent_id,
"token": lease.token,
"port": lease.assigned_port,
"assigned_port": lease.assigned_port,
"protocol": lease.protocol,
"domain": lease.domain,
"idle_seconds": lease.idle_timeout_seconds,
"grace_seconds": lease.grace_timeout_seconds,
"heartbeat_seconds": self.heartbeat_seconds,
}
try:
self._socketio.emit("reverse_tunnel_start", payload, namespace="/")
self.audit_logger.info(
"lease_push_start tunnel_id=%s agent_id=%s port=%s",
lease.tunnel_id,
lease.agent_id,
lease.assigned_port,
)
except Exception:
self.logger.debug("Failed to emit reverse_tunnel_start for tunnel_id=%s", lease.tunnel_id, exc_info=True)
def lease_summary(self, lease: TunnelLease) -> Dict[str, object]:
return {
"tunnel_id": lease.tunnel_id,
@@ -874,6 +914,7 @@ class ReverseTunnelService:
cert = self.context.tls_cert_path or self.context.tls_bundle_path
key = self.context.tls_key_path
if not cert or not key:
self.audit_logger.info("tunnel_listener_ssl_missing cert=%s key=%s", cert, key)
return None
try:
ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
@@ -890,6 +931,7 @@ class ReverseTunnelService:
if port in self._port_servers:
return
ssl_ctx = self._build_ssl_context()
self.audit_logger.info("tunnel_listener_start port=%s ssl=%s", port, bool(ssl_ctx))
async def _handler(websocket, path):
await self._handle_agent_socket(websocket, path, port=port)
@@ -897,6 +939,7 @@ class ReverseTunnelService:
async def _start():
server = await ws_serve(_handler, host="0.0.0.0", port=port, ssl=ssl_ctx, max_size=None, ping_interval=None)
self._port_servers[port] = server
self.audit_logger.info("tunnel_listener_bound port=%s", port)
asyncio.run_coroutine_threadsafe(_start(), self._loop)
@@ -904,15 +947,24 @@ class ReverseTunnelService:
"""Handle agent tunnel socket on assigned port."""
tunnel_id = None
sock_log = self.audit_logger.getChild("agent_socket")
try:
peer = None
try:
peer = getattr(websocket, "remote_address", None)
except Exception:
peer = None
sock_log.info("agent_socket_open port=%s path=%s peer=%s", port, path, peer)
raw = await asyncio.wait_for(websocket.recv(), timeout=10)
frame = decode_frame(raw)
if frame.msg_type != MSG_CONNECT:
sock_log.info("agent_socket_first_frame_not_connect port=%s msg_type=%s", port, frame.msg_type)
await websocket.close()
return
try:
payload = json.loads(frame.payload.decode("utf-8"))
except Exception:
sock_log.info("agent_socket_connect_payload_decode_failed port=%s", port, exc_info=True)
await websocket.close()
return
tunnel_id = str(payload.get("tunnel_id") or "").strip()
@@ -920,23 +972,56 @@ class ReverseTunnelService:
token = payload.get("token") or ""
lease = self.lease_manager.get(tunnel_id)
if lease is None or lease.assigned_port != port:
sock_log.info(
"agent_socket_unknown_lease port=%s tunnel_id=%s assigned=%s expected=%s",
port,
tunnel_id,
lease.assigned_port if lease else None,
port,
)
await websocket.close()
return
# Token validation
self.validate_token(
token,
agent_id=agent_id,
tunnel_id=tunnel_id,
domain=lease.domain,
protocol=lease.protocol,
)
try:
self.validate_token(
token,
agent_id=agent_id,
tunnel_id=tunnel_id,
domain=lease.domain,
protocol=lease.protocol,
)
sock_log.info(
"agent_socket_token_valid port=%s tunnel_id=%s agent_id=%s domain=%s protocol=%s",
port,
tunnel_id,
agent_id,
lease.domain,
lease.protocol,
)
except Exception as exc:
sock_log.info(
"agent_socket_token_invalid port=%s tunnel_id=%s agent_id=%s error=%s",
port,
tunnel_id,
agent_id,
exc,
)
await websocket.close()
return
bridge = self.ensure_bridge(lease)
bridge.attach_agent(token)
self._agent_sockets[tunnel_id] = websocket
await websocket.send(heartbeat_frame(channel_id=0, is_ack=True).encode())
await websocket.send(TunnelFrame(msg_type=MSG_CONNECT_ACK, channel_id=0, payload=b"").encode())
sock_log.info(
"agent_socket_connected port=%s tunnel_id=%s agent_id=%s",
port,
tunnel_id,
agent_id,
)
async def _pump_to_operator():
sock_log_local = sock_log.getChild("recv")
while not websocket.closed:
try:
raw_msg = await websocket.recv()
@@ -945,14 +1030,23 @@ class ReverseTunnelService:
try:
recv_frame = decode_frame(raw_msg)
except Exception:
sock_log_local.info("agent_socket_frame_decode_failed tunnel_id=%s", tunnel_id, exc_info=True)
continue
self.lease_manager.touch(tunnel_id)
sock_log_local.info(
"agent_to_operator tunnel_id=%s msg_type=%s channel=%s payload_len=%s",
tunnel_id,
recv_frame.msg_type,
recv_frame.channel_id,
len(recv_frame.payload or b""),
)
try:
self._dispatch_agent_frame(tunnel_id, recv_frame)
except Exception:
pass
bridge.agent_to_operator(recv_frame)
async def _pump_to_agent():
sock_log_local = sock_log.getChild("send")
while not websocket.closed:
frame = bridge.next_for_agent()
if frame is None:
@@ -960,13 +1054,22 @@ class ReverseTunnelService:
continue
try:
await websocket.send(frame.encode())
sock_log_local.info(
"operator_to_agent tunnel_id=%s msg_type=%s channel=%s payload_len=%s",
tunnel_id,
frame.msg_type,
frame.channel_id,
len(frame.payload or b""),
)
except Exception:
break
async def _heartbeat():
sock_log_local = sock_log.getChild("heartbeat")
while not websocket.closed:
try:
await websocket.send(heartbeat_frame(channel_id=0).encode())
except Exception:
sock_log_local.info("heartbeat_send_failed tunnel_id=%s", tunnel_id, exc_info=True)
break
await asyncio.sleep(self.heartbeat_seconds)
@@ -975,8 +1078,18 @@ class ReverseTunnelService:
heart = asyncio.create_task(_heartbeat())
await asyncio.wait([consumer, producer, heart], return_when=asyncio.FIRST_COMPLETED)
except Exception:
self.logger.debug("Agent socket handler failed on port %s", port, exc_info=True)
sock_log.info("agent_socket_handler_failed port=%s tunnel_id=%s", port, tunnel_id, exc_info=True)
finally:
try:
sock_log.info(
"agent_socket_closed port=%s tunnel_id=%s code=%s reason=%s",
port,
tunnel_id,
getattr(websocket, "close_code", None),
getattr(websocket, "close_reason", None),
)
except Exception:
pass
if tunnel_id and tunnel_id in self._agent_sockets:
self._agent_sockets.pop(tunnel_id, None)
if tunnel_id:
@@ -997,7 +1110,7 @@ class ReverseTunnelService:
if server:
return server
lease = self.lease_manager.get(tunnel_id)
if lease is None or (lease.domain or "").lower() != "ps":
if lease is None:
return None
bridge = self.ensure_bridge(lease)
server = PowershellChannelServer(