Files
Borealis-Github-Replica/Data/Engine/services/RemoteDesktop/vnc_proxy.py

290 lines
9.8 KiB
Python

# ======================================================
# Data\Engine\services\RemoteDesktop\vnc_proxy.py
# Description: VNC tunnel proxy (WebSocket -> TCP) for noVNC sessions.
#
# API Endpoints (if applicable): None
# ======================================================
"""VNC WebSocket proxy that bridges browser sessions to agent VNC servers."""
from __future__ import annotations
import asyncio
import logging
import ssl
import threading
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple
from urllib.parse import parse_qs, urlsplit
import websockets
# URL path a WebSocket client must request; the connection handler closes
# any connection whose parsed request path differs from this value.
VNC_WS_PATH = "/vnc"
# Passed as ``max_size`` to ``websockets.serve`` when the proxy starts.
_MAX_MESSAGE_SIZE = 100_000_000
@dataclass
class VncSession:
    """A pending VNC tunnel grant: who it is for and where to connect."""

    # Lookup key under which the registry stores this session.
    token: str
    # Agent the session belongs to; used for revocation and the stop event.
    agent_id: str
    # TCP endpoint the proxy connects to once the token is presented.
    host: str
    port: int
    # Timestamps on the ``time.time()`` clock (seconds).
    created_at: float
    expires_at: float
    # Optional operator identifier; not read anywhere in this module.
    operator_id: Optional[str] = None
class VncSessionRegistry:
    """In-memory store of single-use VNC session tokens with a fixed TTL.

    ``create``, ``consume`` and ``revoke_agent`` take an internal lock and
    purge expired entries before doing their own work.
    """

    def __init__(self, ttl_seconds: int, logger: logging.Logger) -> None:
        """Store the logger and the TTL (coerced to int, minimum 30 seconds)."""
        self.logger = logger
        self.ttl_seconds = max(30, int(ttl_seconds))
        self._sessions: Dict[str, "VncSession"] = {}
        self._lock = threading.Lock()

    def _cleanup(self, now: Optional[float] = None) -> None:
        """Drop every session whose ``expires_at`` is at or before *now*.

        Does not take the lock itself; the callers in this class hold it.
        """
        cutoff = time.time() if now is None else now
        stale_keys = [
            key for key, entry in self._sessions.items() if entry.expires_at <= cutoff
        ]
        for key in stale_keys:
            del self._sessions[key]

    def create(
        self,
        *,
        agent_id: str,
        host: str,
        port: int,
        operator_id: Optional[str] = None,
    ) -> "VncSession":
        """Register and return a new session with a fresh random hex token."""
        issued_at = time.time()
        session = VncSession(
            token=uuid.uuid4().hex,
            agent_id=agent_id,
            host=host,
            port=port,
            created_at=issued_at,
            expires_at=issued_at + self.ttl_seconds,
            operator_id=operator_id,
        )
        with self._lock:
            self._cleanup(issued_at)
            self._sessions[session.token] = session
        return session

    def consume(self, token: str) -> Optional["VncSession"]:
        """Remove and return the live session for *token*, or ``None``.

        The entry is popped, so a token can be redeemed at most once.
        """
        if token:
            with self._lock:
                self._cleanup()
                return self._sessions.pop(token, None)
        return None

    def revoke_agent(self, agent_id: str) -> int:
        """Discard all live sessions of *agent_id*; return how many were removed."""
        if not agent_id:
            return 0
        with self._lock:
            self._cleanup()
            doomed = [
                key for key, entry in self._sessions.items() if entry.agent_id == agent_id
            ]
            return sum(1 for key in doomed if self._sessions.pop(key, None))
class VncProxyServer:
    """WebSocket server that bridges each accepted client to a VNC TCP endpoint.

    The server runs on a daemon thread with its own asyncio event loop.  A
    client must request ``VNC_WS_PATH`` with a ``token`` query parameter that
    the registry can redeem; the redeemed session supplies the TCP host/port.
    """

    def __init__(
        self,
        *,
        host: str,
        port: int,
        registry: "VncSessionRegistry",
        logger: logging.Logger,
        emit_agent_event: Optional[Callable[[str, str, Any], bool]] = None,
        ssl_context: Optional[ssl.SSLContext] = None,
    ) -> None:
        """Store configuration; nothing is started until ``ensure_started``.

        Args:
            host: Interface the WebSocket listener binds to.
            port: Port the WebSocket listener binds to.
            registry: Source of sessions; ``consume(token)`` is called per client.
            logger: Logger for server and per-session messages.
            emit_agent_event: Optional callback invoked as
                ``(agent_id, "vnc_stop", payload)`` when a session ends.
            ssl_context: Optional TLS context handed to ``websockets.serve``.
        """
        self.host = host
        self.port = port
        self.registry = registry
        self.logger = logger
        self._emit_agent_event = emit_agent_event
        self.ssl_context = ssl_context
        self._thread: Optional[threading.Thread] = None
        self._ready = threading.Event()
        self._failed = threading.Event()

    def ensure_started(self, timeout: float = 3.0) -> bool:
        """Start the server thread if it is not running and wait for readiness.

        Args:
            timeout: Maximum seconds to wait for the listener to signal that
                it is up (or has failed).

        Returns:
            ``True`` only when readiness was signalled within *timeout* and no
            startup failure was recorded.
        """
        if not (self._thread and self._thread.is_alive()):
            self._failed.clear()
            self._ready.clear()
            self._thread = threading.Thread(target=self._run, daemon=True)
            self._thread.start()
        # Event.wait() returns False on timeout; treating that as success would
        # report a listener that never came up as available.
        became_ready = self._ready.wait(timeout)
        return became_ready and not self._failed.is_set()

    def _run(self) -> None:
        """Thread target: run the server loop and record any failure."""
        try:
            asyncio.run(self._serve())
        except Exception as exc:
            self._failed.set()
            self.logger.error("VNC proxy server failed: %s", exc)
        # Always release waiters in ensure_started, success or not.
        self._ready.set()

    async def _serve(self) -> None:
        """Bind the WebSocket listener and block until it is closed."""
        self.logger.info("Starting VNC proxy on %s:%s", self.host, self.port)
        try:
            server = await websockets.serve(
                self._handle_client,
                self.host,
                self.port,
                ssl=self.ssl_context,
                max_size=_MAX_MESSAGE_SIZE,
                ping_interval=20,
                ping_timeout=20,
            )
        except Exception:
            self._failed.set()
            self._ready.set()
            raise
        self._ready.set()
        await server.wait_closed()

    @staticmethod
    def _request_path(websocket: Any, path: Optional[str] = None) -> str:
        """Return the request target of a connection across websockets APIs.

        Checks, in order: the handler's ``path`` argument, ``websocket.path``,
        and ``websocket.request.path`` (used by the newer asyncio server,
        which no longer passes ``path``).  Returns ``""`` when none is set.
        """
        if path:
            return path
        legacy_path = getattr(websocket, "path", None)
        if legacy_path:
            return legacy_path
        request = getattr(websocket, "request", None)
        return getattr(request, "path", None) or ""

    async def _handle_client(self, websocket, path: Optional[str] = None) -> None:
        """Validate one WebSocket client and bridge it to its VNC endpoint.

        Closes with code 1008 on a wrong path or unredeemable token, and with
        1011 when the TCP connection cannot be established.  Once a session
        has been redeemed, the agent is always notified when handling ends.
        """
        raw_path = self._request_path(websocket, path)
        parsed = urlsplit(raw_path)
        if parsed.path != VNC_WS_PATH:
            self.logger.warning("VNC proxy rejected request with invalid path: %s", raw_path)
            await websocket.close(code=1008, reason="invalid_path")
            return
        query = parse_qs(parsed.query or "")
        token = (query.get("token") or [""])[0]
        session = self.registry.consume(token)
        if not session:
            # Log only a prefix so a full token never lands in the logs.
            token_hint = token[:8] if token else "-"
            self.logger.warning("VNC proxy rejected session (token=%s)", token_hint)
            await websocket.close(code=1008, reason="invalid_session")
            return
        logger = self.logger.getChild("session")
        logger.info("VNC session start agent_id=%s", session.agent_id)
        try:
            try:
                reader, writer = await self._connect_vnc(session.host, session.port)
            except Exception as exc:
                logger.warning("VNC connect failed: %s", exc)
                await websocket.close(code=1011, reason="vnc_unavailable")
                return
            await self._bridge(websocket, reader, writer, logger)
        finally:
            logger.info("VNC session ended agent_id=%s", session.agent_id)
            self._notify_agent_session_end(session, reason="vnc_session_end")

    async def _bridge(self, websocket: Any, reader: Any, writer: Any, logger: logging.Logger) -> None:
        """Pump bytes both ways until either side stops, then tear both down."""

        async def _ws_to_tcp() -> None:
            async for message in websocket:
                if message is None:
                    break
                # Text frames are forwarded as their UTF-8 bytes.
                data = message.encode("utf-8") if isinstance(message, str) else bytes(message)
                writer.write(data)
                await writer.drain()

        async def _tcp_to_ws() -> None:
            while True:
                data = await reader.read(8192)
                if not data:
                    break
                await websocket.send(data)

        pumps = [asyncio.create_task(_ws_to_tcp()), asyncio.create_task(_tcp_to_ws())]
        try:
            await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Stop the direction that is still running (cancel() is a no-op on
            # finished tasks) and await both, so no task is orphaned and every
            # task exception is retrieved.
            for task in pumps:
                task.cancel()
            outcomes = await asyncio.gather(*pumps, return_exceptions=True)
            for outcome in outcomes:
                if isinstance(outcome, Exception) and not isinstance(outcome, asyncio.CancelledError):
                    logger.debug("VNC pump ended with error: %r", outcome)
            # Close both endpoints here rather than inside the pumps so the
            # cleanup runs even if a pump never got to execute.
            try:
                writer.close()
            except Exception:
                logger.debug("Failed to close VNC TCP writer", exc_info=True)
            try:
                await websocket.close()
            except Exception:
                logger.debug("Failed to close VNC websocket", exc_info=True)

    async def _connect_vnc(self, host: str, port: int) -> Tuple[Any, Any]:
        """Open a TCP connection to ``host:port``, retrying on failure.

        Makes up to five attempts, sleeping 0.5 s between them.

        Returns:
            The ``(reader, writer)`` pair from ``asyncio.open_connection``.

        Raises:
            Exception: The error from the final failed attempt.
        """
        attempts = 5
        delay = 0.5
        last_exc: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                return await asyncio.open_connection(host, port)
            except Exception as exc:
                last_exc = exc
                if attempt < attempts - 1:
                    await asyncio.sleep(delay)
        if last_exc:
            raise last_exc
        raise RuntimeError("vnc_connect_failed")

    def _notify_agent_session_end(self, session: "VncSession", reason: str) -> None:
        """Best-effort ``vnc_stop`` event to the session's agent; never raises."""
        if not self._emit_agent_event:
            return
        payload = {"agent_id": session.agent_id, "reason": reason}
        try:
            self._emit_agent_event(session.agent_id, "vnc_stop", payload)
        except Exception:
            self.logger.debug("Failed to emit vnc_stop for agent_id=%s", session.agent_id, exc_info=True)
def _build_ssl_context(cert_path: Optional[str], key_path: Optional[str]) -> Optional[ssl.SSLContext]:
    """Build a server-side TLS context from a certificate chain and key file.

    Args:
        cert_path: Path to the certificate (chain) file.
        key_path: Path to the private key file.

    Returns:
        A configured ``ssl.SSLContext``, or ``None`` when either path is
        missing or the files cannot be loaded.
    """
    if not cert_path or not key_path:
        return None
    try:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile=cert_path, keyfile=key_path)
    except Exception:
        # Returning None stays best-effort, but a load failure means no TLS
        # context is produced, so it must not pass unnoticed.
        logging.getLogger("borealis.engine.vnc").warning(
            "Failed to load TLS certificate/key for VNC proxy (cert=%s); no TLS context created.",
            cert_path,
            exc_info=True,
        )
        return None
    return context
def ensure_vnc_proxy(context: Any, *, logger: Optional[logging.Logger] = None) -> Optional["VncSessionRegistry"]:
    """Attach a session registry and proxy server to *context*, then start the proxy.

    Existing ``context.vnc_registry`` / ``context.vnc_proxy`` objects are
    reused; missing ones are created from the context's ``vnc_*`` and
    ``tls_*`` attributes and stored back on the context.

    Args:
        context: Object carrying configuration attributes and the cached
            registry/proxy.
        logger: Logger to use; defaults to ``context.logger`` when that
            attribute exists, else the ``borealis.engine.vnc`` logger.

    Returns:
        The registry when the proxy reports it is started, otherwise ``None``.
    """
    if logger is None:
        if hasattr(context, "logger"):
            logger = context.logger
        else:
            logger = logging.getLogger("borealis.engine.vnc")

    registry = getattr(context, "vnc_registry", None)
    if registry is None:
        registry = VncSessionRegistry(
            ttl_seconds=int(getattr(context, "vnc_session_ttl_seconds", 120)),
            logger=logger,
        )
        context.vnc_registry = registry

    proxy = getattr(context, "vnc_proxy", None)
    if proxy is None:
        # The bundle path takes precedence over the plain certificate path.
        certificate = getattr(context, "tls_bundle_path", None) or getattr(context, "tls_cert_path", None)
        private_key = getattr(context, "tls_key_path", None)
        tls = _build_ssl_context(certificate, private_key)
        proxy = VncProxyServer(
            host=str(getattr(context, "vnc_ws_host", "0.0.0.0")),
            port=int(getattr(context, "vnc_ws_port", 4823)),
            registry=registry,
            logger=logger.getChild("vnc_proxy"),
            emit_agent_event=getattr(context, "emit_agent_event", None),
            ssl_context=tls,
        )
        context.vnc_proxy = proxy

    if proxy.ensure_started():
        return registry
    logger.error("VNC proxy failed to start; VNC sessions unavailable.")
    return None
# Names exported by ``from <module> import *``.
__all__ = ["VNC_WS_PATH", "VncSessionRegistry", "VncProxyServer", "ensure_vnc_proxy"]