mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2026-02-07 01:20:31 -07:00
Removed RDP in favor of VNC / Made WireGuard Tunnel Persistent
This commit is contained in:
@@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
import base64
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
@@ -22,6 +23,13 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
||||
from .wireguard_server import WireGuardServerManager
|
||||
|
||||
|
||||
def _env_flag(name: str, *, default: bool) -> bool:
|
||||
value = os.environ.get(name)
|
||||
if value is None:
|
||||
return default
|
||||
return str(value).strip().lower() not in ("0", "false", "no", "off")
|
||||
|
||||
|
||||
@dataclass
|
||||
class VpnSession:
|
||||
tunnel_id: str
|
||||
@@ -62,14 +70,17 @@ class VpnTunnelService:
|
||||
self.logger = context.logger.getChild("vpn_tunnel")
|
||||
self.activity_logger = self.wg.logger.getChild("device_activity")
|
||||
self.idle_seconds = max(60, int(idle_seconds))
|
||||
self.persistent = _env_flag("BOREALIS_WIREGUARD_PERSISTENT", default=True)
|
||||
self._lock = threading.Lock()
|
||||
self._sessions_by_agent: Dict[str, VpnSession] = {}
|
||||
self._sessions_by_tunnel: Dict[str, VpnSession] = {}
|
||||
self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
|
||||
self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
|
||||
self._cleanup_listener()
|
||||
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
|
||||
self._idle_thread.start()
|
||||
self._idle_thread: Optional[threading.Thread] = None
|
||||
if not self.persistent:
|
||||
self._idle_thread = threading.Thread(target=self._idle_loop, daemon=True)
|
||||
self._idle_thread.start()
|
||||
|
||||
def _idle_loop(self) -> None:
|
||||
while True:
|
||||
@@ -90,7 +101,7 @@ class VpnTunnelService:
|
||||
self.idle_seconds,
|
||||
)
|
||||
)
|
||||
self.disconnect(session.agent_id, reason="idle_timeout")
|
||||
self.disconnect(session.agent_id, reason="idle_timeout", force=True)
|
||||
|
||||
def _allocate_virtual_ip(self, agent_id: str) -> str:
|
||||
existing = self._sessions_by_agent.get(agent_id)
|
||||
@@ -226,12 +237,12 @@ class VpnTunnelService:
|
||||
self.logger.debug("Failed to write vpn_tunnel service log entry", exc_info=True)
|
||||
|
||||
def _cleanup_listener(self) -> None:
|
||||
try:
|
||||
self.wg.stop_listener(ignore_missing=True)
|
||||
self._service_log_event("vpn_listener_cleanup reason=startup")
|
||||
except Exception:
|
||||
self.logger.debug("Failed to clean up WireGuard listener on startup.", exc_info=True)
|
||||
self._service_log_event("vpn_listener_cleanup_failed reason=startup", level="WARNING")
|
||||
self._service_log_event("vpn_listener_cleanup_skipped reason=startup")
|
||||
|
||||
def _is_soft_disconnect(self, reason: Optional[str]) -> bool:
|
||||
if not reason:
|
||||
return False
|
||||
return str(reason).lower() in ("operator_disconnect", "component_unmount")
|
||||
|
||||
def _refresh_listener(self) -> None:
|
||||
peers: List[Mapping[str, object]] = []
|
||||
@@ -432,15 +443,60 @@ class VpnTunnelService:
|
||||
except Exception:
|
||||
self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True)
|
||||
|
||||
def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool:
|
||||
def disconnect(
|
||||
self,
|
||||
agent_id: str,
|
||||
reason: str = "operator_stop",
|
||||
*,
|
||||
operator_id: Optional[str] = None,
|
||||
force: bool = False,
|
||||
) -> bool:
|
||||
with self._lock:
|
||||
session = self._sessions_by_agent.pop(agent_id, None)
|
||||
session = self._sessions_by_agent.get(agent_id)
|
||||
if not session:
|
||||
self._service_log_event(
|
||||
"vpn_tunnel_disconnect_missing agent_id={0} reason={1}".format(agent_id or "-", reason or "-")
|
||||
)
|
||||
return False
|
||||
self._sessions_by_tunnel.pop(session.tunnel_id, None)
|
||||
if self.persistent and not force:
|
||||
if operator_id:
|
||||
try:
|
||||
session.operator_ids.discard(operator_id)
|
||||
except Exception:
|
||||
pass
|
||||
session.last_activity = time.time()
|
||||
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
|
||||
self._service_log_event(
|
||||
"vpn_tunnel_keepalive agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
reason or "-",
|
||||
operator_text,
|
||||
)
|
||||
)
|
||||
return True
|
||||
if not force and self._is_soft_disconnect(reason):
|
||||
if operator_id:
|
||||
try:
|
||||
session.operator_ids.discard(operator_id)
|
||||
except Exception:
|
||||
pass
|
||||
session.last_activity = time.time()
|
||||
operator_text = ",".join(sorted(filter(None, session.operator_ids))) or "-"
|
||||
self._service_log_event(
|
||||
"vpn_tunnel_keepalive agent_id={0} tunnel_id={1} reason={2} operators={3}".format(
|
||||
session.agent_id,
|
||||
session.tunnel_id,
|
||||
reason or "-",
|
||||
operator_text,
|
||||
)
|
||||
)
|
||||
return True
|
||||
session = self._sessions_by_agent.pop(agent_id, None)
|
||||
if session:
|
||||
self._sessions_by_tunnel.pop(session.tunnel_id, None)
|
||||
else:
|
||||
return False
|
||||
|
||||
try:
|
||||
self.wg.remove_firewall_rules(session.firewall_rules)
|
||||
@@ -461,7 +517,14 @@ class VpnTunnelService:
|
||||
self._log_device_activity(session, event="stop", reason=reason)
|
||||
return True
|
||||
|
||||
def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool:
|
||||
def disconnect_by_tunnel(
|
||||
self,
|
||||
tunnel_id: str,
|
||||
reason: str = "operator_stop",
|
||||
*,
|
||||
operator_id: Optional[str] = None,
|
||||
force: bool = False,
|
||||
) -> bool:
|
||||
with self._lock:
|
||||
session = self._sessions_by_tunnel.get(tunnel_id)
|
||||
if not session:
|
||||
@@ -469,7 +532,7 @@ class VpnTunnelService:
|
||||
"vpn_tunnel_disconnect_missing tunnel_id={0} reason={1}".format(tunnel_id or "-", reason or "-")
|
||||
)
|
||||
return False
|
||||
return self.disconnect(session.agent_id, reason=reason)
|
||||
return self.disconnect(session.agent_id, reason=reason, operator_id=operator_id, force=force)
|
||||
|
||||
def _emit_start(self, payload: Mapping[str, Any]) -> None:
|
||||
if not self.socketio:
|
||||
@@ -704,7 +767,7 @@ class VpnTunnelService:
|
||||
"server_public_key": self.wg.server_public_key,
|
||||
"client_public_key": session.client_public_key,
|
||||
"client_private_key": session.client_private_key,
|
||||
"idle_seconds": self.idle_seconds,
|
||||
"idle_seconds": 0 if self.persistent else self.idle_seconds,
|
||||
"allowed_ports": list(session.allowed_ports),
|
||||
"connected_operators": len([o for o in session.operator_ids if o]),
|
||||
}
|
||||
@@ -729,6 +792,6 @@ class VpnTunnelService:
|
||||
"last_activity_iso": self._ts_to_iso(session.last_activity),
|
||||
"expires_at": int(session.expires_at),
|
||||
"expires_at_iso": self._ts_to_iso(session.expires_at),
|
||||
"idle_seconds": self.idle_seconds,
|
||||
"idle_seconds": 0 if self.persistent else self.idle_seconds,
|
||||
"status": "up",
|
||||
}
|
||||
|
||||
@@ -164,6 +164,25 @@ class WireGuardServerManager:
|
||||
return match.group(1).upper()
|
||||
return None
|
||||
|
||||
def _service_exists(self) -> bool:
|
||||
code, _, _ = self._run_command(["sc.exe", "query", self._service_id()])
|
||||
return code == 0
|
||||
|
||||
def _stop_service(self, *, timeout: int = 20) -> bool:
|
||||
service_id = self._service_id()
|
||||
state = self._query_service_state()
|
||||
if not state:
|
||||
return False
|
||||
if state == "STOPPED":
|
||||
return True
|
||||
self._run_command(["sc.exe", "stop", service_id])
|
||||
for _ in range(max(1, timeout)):
|
||||
time.sleep(1)
|
||||
state = self._query_service_state()
|
||||
if state == "STOPPED":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _ensure_service_display_name(self) -> None:
|
||||
if not self._service_display_name:
|
||||
return
|
||||
@@ -172,9 +191,9 @@ class WireGuardServerManager:
|
||||
if code != 0 and err:
|
||||
self.logger.warning("Failed to set WireGuard service display name: %s", err)
|
||||
|
||||
def _ensure_service_running(self) -> None:
|
||||
def _ensure_service_running(self, *, timeout: int = 20) -> None:
|
||||
service_id = self._service_id()
|
||||
for _ in range(6):
|
||||
for _ in range(max(1, timeout)):
|
||||
state = self._query_service_state()
|
||||
if state == "RUNNING":
|
||||
return
|
||||
@@ -183,8 +202,20 @@ class WireGuardServerManager:
|
||||
if code != 0:
|
||||
self.logger.error("Failed to start WireGuard tunnel service %s err=%s", service_id, err)
|
||||
break
|
||||
if state in ("START_PENDING", "STOP_PENDING"):
|
||||
time.sleep(1)
|
||||
continue
|
||||
time.sleep(1)
|
||||
state = self._query_service_state()
|
||||
if state == "START_PENDING":
|
||||
self.logger.warning("WireGuard tunnel service still START_PENDING; attempting restart.")
|
||||
self._stop_service(timeout=10)
|
||||
self._run_command(["sc.exe", "start", service_id])
|
||||
for _ in range(10):
|
||||
time.sleep(1)
|
||||
if self._query_service_state() == "RUNNING":
|
||||
return
|
||||
state = self._query_service_state()
|
||||
raise RuntimeError(f"WireGuard tunnel service {service_id} failed to start (state={state})")
|
||||
|
||||
def _normalise_allowed_ports(
|
||||
@@ -329,6 +360,7 @@ class WireGuardServerManager:
|
||||
for idx, rule in enumerate(rules):
|
||||
name = f"Borealis-WG-Agent-{peer.get('agent_id','')}-{idx}"
|
||||
protocol = str(rule.get("protocol") or "TCP").upper()
|
||||
self._run_command(["netsh", "advfirewall", "firewall", "delete", "rule", f"name={name}"])
|
||||
args = [
|
||||
"netsh",
|
||||
"advfirewall",
|
||||
@@ -374,8 +406,12 @@ class WireGuardServerManager:
|
||||
config_path.write_text(rendered, encoding="utf-8")
|
||||
self.logger.info("Rendered WireGuard config to %s", config_path)
|
||||
|
||||
# Ensure old service is removed before re-installing.
|
||||
self.stop_listener()
|
||||
if self._service_exists():
|
||||
if not self._stop_service(timeout=20):
|
||||
self.logger.warning("WireGuard tunnel service did not stop cleanly before restart.")
|
||||
self._ensure_service_display_name()
|
||||
self._ensure_service_running(timeout=25)
|
||||
return
|
||||
|
||||
args = [self._wireguard_exe, "/installtunnelservice", str(config_path)]
|
||||
code, out, err = self._run_command(args)
|
||||
@@ -384,21 +420,22 @@ class WireGuardServerManager:
|
||||
raise RuntimeError(f"WireGuard installtunnelservice failed: {err}")
|
||||
self.logger.info("WireGuard listener installed (service=%s)", config_path.stem)
|
||||
self._ensure_service_display_name()
|
||||
self._ensure_service_running()
|
||||
self._ensure_service_running(timeout=25)
|
||||
|
||||
def stop_listener(self, *, ignore_missing: bool = False) -> None:
|
||||
"""Stop and remove the WireGuard tunnel service."""
|
||||
"""Stop the WireGuard tunnel service (leave installed for reuse)."""
|
||||
|
||||
args = [self._wireguard_exe, "/uninstalltunnelservice", self._service_name]
|
||||
code, out, err = self._run_command(args)
|
||||
if code != 0:
|
||||
err_text = " ".join([out or "", err or ""]).strip().lower()
|
||||
if ignore_missing and ("does not exist" in err_text or "not exist" in err_text):
|
||||
if not self._service_exists():
|
||||
if ignore_missing:
|
||||
self.logger.info("WireGuard tunnel service already absent")
|
||||
return
|
||||
self.logger.warning("Failed to uninstall WireGuard tunnel service code=%s err=%s", code, err)
|
||||
else:
|
||||
self.logger.info("WireGuard tunnel service removed")
|
||||
self.logger.warning("WireGuard tunnel service not found during stop.")
|
||||
return
|
||||
|
||||
if not self._stop_service(timeout=20):
|
||||
self.logger.warning("WireGuard tunnel service did not stop cleanly.")
|
||||
return
|
||||
self.logger.info("WireGuard tunnel service stopped")
|
||||
|
||||
def build_firewall_rules(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user