mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2026-02-06 16:10:30 -07:00
Removed RDP in favor of VNC / Made WireGuard Tunnel Persistent
This commit is contained in:
285
Data/Engine/services/RemoteDesktop/vnc_proxy.py
Normal file
285
Data/Engine/services/RemoteDesktop/vnc_proxy.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# ======================================================
|
||||
# Data\Engine\services\RemoteDesktop\vnc_proxy.py
|
||||
# Description: VNC tunnel proxy (WebSocket -> TCP) for noVNC sessions.
|
||||
#
|
||||
# API Endpoints (if applicable): None
|
||||
# ======================================================
|
||||
|
||||
"""VNC WebSocket proxy that bridges browser sessions to agent VNC servers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import ssl
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
from urllib.parse import parse_qs, urlsplit
|
||||
|
||||
import websockets
|
||||
|
||||
VNC_WS_PATH = "/vnc"
|
||||
_MAX_MESSAGE_SIZE = 100_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class VncSession:
|
||||
token: str
|
||||
agent_id: str
|
||||
host: str
|
||||
port: int
|
||||
created_at: float
|
||||
expires_at: float
|
||||
operator_id: Optional[str] = None
|
||||
|
||||
|
||||
class VncSessionRegistry:
|
||||
def __init__(self, ttl_seconds: int, logger: logging.Logger) -> None:
|
||||
self.ttl_seconds = max(30, int(ttl_seconds))
|
||||
self.logger = logger
|
||||
self._lock = threading.Lock()
|
||||
self._sessions: Dict[str, VncSession] = {}
|
||||
|
||||
def _cleanup(self, now: Optional[float] = None) -> None:
|
||||
current = now if now is not None else time.time()
|
||||
expired = [token for token, session in self._sessions.items() if session.expires_at <= current]
|
||||
for token in expired:
|
||||
self._sessions.pop(token, None)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
agent_id: str,
|
||||
host: str,
|
||||
port: int,
|
||||
operator_id: Optional[str] = None,
|
||||
) -> VncSession:
|
||||
token = uuid.uuid4().hex
|
||||
now = time.time()
|
||||
expires_at = now + self.ttl_seconds
|
||||
session = VncSession(
|
||||
token=token,
|
||||
agent_id=agent_id,
|
||||
host=host,
|
||||
port=port,
|
||||
created_at=now,
|
||||
expires_at=expires_at,
|
||||
operator_id=operator_id,
|
||||
)
|
||||
with self._lock:
|
||||
self._cleanup(now)
|
||||
self._sessions[token] = session
|
||||
return session
|
||||
|
||||
def consume(self, token: str) -> Optional[VncSession]:
|
||||
if not token:
|
||||
return None
|
||||
with self._lock:
|
||||
self._cleanup()
|
||||
session = self._sessions.pop(token, None)
|
||||
return session
|
||||
|
||||
def revoke_agent(self, agent_id: str) -> int:
|
||||
if not agent_id:
|
||||
return 0
|
||||
removed = 0
|
||||
with self._lock:
|
||||
self._cleanup()
|
||||
tokens = [token for token, session in self._sessions.items() if session.agent_id == agent_id]
|
||||
for token in tokens:
|
||||
if self._sessions.pop(token, None):
|
||||
removed += 1
|
||||
return removed
|
||||
|
||||
|
||||
class VncProxyServer:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
host: str,
|
||||
port: int,
|
||||
registry: VncSessionRegistry,
|
||||
logger: logging.Logger,
|
||||
emit_agent_event: Optional[Callable[[str, str, Any], bool]] = None,
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
) -> None:
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.registry = registry
|
||||
self.logger = logger
|
||||
self._emit_agent_event = emit_agent_event
|
||||
self.ssl_context = ssl_context
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready = threading.Event()
|
||||
self._failed = threading.Event()
|
||||
|
||||
def ensure_started(self, timeout: float = 3.0) -> bool:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return not self._failed.is_set()
|
||||
self._failed.clear()
|
||||
self._ready.clear()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
self._ready.wait(timeout)
|
||||
return not self._failed.is_set()
|
||||
|
||||
def _run(self) -> None:
|
||||
try:
|
||||
asyncio.run(self._serve())
|
||||
except Exception as exc:
|
||||
self._failed.set()
|
||||
self.logger.error("VNC proxy server failed: %s", exc)
|
||||
self._ready.set()
|
||||
|
||||
async def _serve(self) -> None:
|
||||
self.logger.info("Starting VNC proxy on %s:%s", self.host, self.port)
|
||||
try:
|
||||
server = await websockets.serve(
|
||||
self._handle_client,
|
||||
self.host,
|
||||
self.port,
|
||||
ssl=self.ssl_context,
|
||||
max_size=_MAX_MESSAGE_SIZE,
|
||||
ping_interval=20,
|
||||
ping_timeout=20,
|
||||
)
|
||||
except Exception:
|
||||
self._failed.set()
|
||||
self._ready.set()
|
||||
raise
|
||||
self._ready.set()
|
||||
await server.wait_closed()
|
||||
|
||||
async def _handle_client(self, websocket, path: str) -> None:
|
||||
parsed = urlsplit(path)
|
||||
if parsed.path != VNC_WS_PATH:
|
||||
await websocket.close(code=1008, reason="invalid_path")
|
||||
return
|
||||
query = parse_qs(parsed.query or "")
|
||||
token = (query.get("token") or [""])[0]
|
||||
session = self.registry.consume(token)
|
||||
if not session:
|
||||
await websocket.close(code=1008, reason="invalid_session")
|
||||
return
|
||||
|
||||
logger = self.logger.getChild("session")
|
||||
logger.info("VNC session start agent_id=%s", session.agent_id)
|
||||
|
||||
try:
|
||||
try:
|
||||
reader, writer = await self._connect_vnc(session.host, session.port)
|
||||
except Exception as exc:
|
||||
logger.warning("VNC connect failed: %s", exc)
|
||||
await websocket.close(code=1011, reason="vnc_unavailable")
|
||||
return
|
||||
|
||||
async def _ws_to_tcp() -> None:
|
||||
try:
|
||||
async for message in websocket:
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, str):
|
||||
data = message.encode("utf-8")
|
||||
else:
|
||||
data = bytes(message)
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
finally:
|
||||
try:
|
||||
writer.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _tcp_to_ws() -> None:
|
||||
try:
|
||||
while True:
|
||||
data = await reader.read(8192)
|
||||
if not data:
|
||||
break
|
||||
await websocket.send(data)
|
||||
finally:
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.wait(
|
||||
[asyncio.create_task(_ws_to_tcp()), asyncio.create_task(_tcp_to_ws())],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
finally:
|
||||
logger.info("VNC session ended agent_id=%s", session.agent_id)
|
||||
self._notify_agent_session_end(session, reason="vnc_session_end")
|
||||
|
||||
async def _connect_vnc(self, host: str, port: int) -> Tuple[Any, Any]:
|
||||
attempts = 5
|
||||
delay = 0.5
|
||||
last_exc: Optional[Exception] = None
|
||||
for attempt in range(attempts):
|
||||
try:
|
||||
return await asyncio.open_connection(host, port)
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if attempt < attempts - 1:
|
||||
await asyncio.sleep(delay)
|
||||
if last_exc:
|
||||
raise last_exc
|
||||
raise RuntimeError("vnc_connect_failed")
|
||||
|
||||
def _notify_agent_session_end(self, session: VncSession, reason: str) -> None:
|
||||
if not self._emit_agent_event:
|
||||
return
|
||||
payload = {"agent_id": session.agent_id, "reason": reason}
|
||||
try:
|
||||
self._emit_agent_event(session.agent_id, "vnc_stop", payload)
|
||||
except Exception:
|
||||
self.logger.debug("Failed to emit vnc_stop for agent_id=%s", session.agent_id, exc_info=True)
|
||||
|
||||
|
||||
def _build_ssl_context(cert_path: Optional[str], key_path: Optional[str]) -> Optional[ssl.SSLContext]:
|
||||
if not cert_path or not key_path:
|
||||
return None
|
||||
try:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(certfile=cert_path, keyfile=key_path)
|
||||
return context
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def ensure_vnc_proxy(context: Any, *, logger: Optional[logging.Logger] = None) -> Optional[VncSessionRegistry]:
|
||||
if logger is None:
|
||||
logger = context.logger if hasattr(context, "logger") else logging.getLogger("borealis.engine.vnc")
|
||||
|
||||
registry = getattr(context, "vnc_registry", None)
|
||||
if registry is None:
|
||||
ttl = int(getattr(context, "vnc_session_ttl_seconds", 120))
|
||||
registry = VncSessionRegistry(ttl_seconds=ttl, logger=logger)
|
||||
setattr(context, "vnc_registry", registry)
|
||||
|
||||
proxy = getattr(context, "vnc_proxy", None)
|
||||
if proxy is None:
|
||||
cert_path = getattr(context, "tls_bundle_path", None) or getattr(context, "tls_cert_path", None)
|
||||
ssl_context = _build_ssl_context(
|
||||
cert_path,
|
||||
getattr(context, "tls_key_path", None),
|
||||
)
|
||||
proxy = VncProxyServer(
|
||||
host=str(getattr(context, "vnc_ws_host", "0.0.0.0")),
|
||||
port=int(getattr(context, "vnc_ws_port", 4823)),
|
||||
registry=registry,
|
||||
logger=logger.getChild("vnc_proxy"),
|
||||
emit_agent_event=getattr(context, "emit_agent_event", None),
|
||||
ssl_context=ssl_context,
|
||||
)
|
||||
setattr(context, "vnc_proxy", proxy)
|
||||
|
||||
if not proxy.ensure_started():
|
||||
logger.error("VNC proxy failed to start; VNC sessions unavailable.")
|
||||
return None
|
||||
return registry
|
||||
|
||||
|
||||
__all__ = ["VNC_WS_PATH", "VncSessionRegistry", "VncProxyServer", "ensure_vnc_proxy"]
|
||||
Reference in New Issue
Block a user