mirror of
https://github.com/bunny-lab-io/Borealis.git
synced 2025-12-18 06:35:48 -07:00
474 lines
18 KiB
Python
474 lines
18 KiB
Python
# ======================================================
# Data\Engine\services\VPN\vpn_tunnel_service.py
# Description: WireGuard tunnel orchestration (single tunnel per agent, token issuance, idle handling).
#
# API Endpoints (if applicable): None
# ======================================================

"""WireGuard tunnel orchestration helpers for the Engine runtime."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import ipaddress
|
|
import json
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
|
|
|
|
from .wireguard_server import WireGuardServerManager
|
|
|
|
|
|
@dataclass
class VpnSession:
    """Mutable record describing one agent's active WireGuard tunnel.

    Timestamp fields hold epoch seconds (the tunnel service fills them from
    ``time.time()``). Every field after ``last_activity`` has a default and is
    populated as the tunnel progresses.
    """

    # Identity of the tunnel and of the agent that owns it.
    tunnel_id: str
    agent_id: str
    # Address assigned to the agent inside the tunnel, in CIDR notation.
    virtual_ip: str
    # Orchestration token handed to the agent (payload plus optional signature fields).
    token: Dict[str, Any]
    # Base64-encoded client key pair generated for this tunnel.
    client_public_key: str
    client_private_key: str
    # Ports the tunnel ACL permits for this agent.
    allowed_ports: Tuple[int, ...]
    # Lifecycle timestamps.
    created_at: float
    expires_at: float
    last_activity: float
    # Operators currently attached to this tunnel.
    operator_ids: set[str] = field(default_factory=set)
    # Names of firewall rules applied for this tunnel (removed on teardown).
    firewall_rules: List[str] = field(default_factory=list)
    # activity_history row id once the tunnel start has been recorded.
    activity_id: Optional[int] = None
    # Device hostname, cached after the first successful lookup.
    hostname: Optional[str] = None
|
|
|
|
|
|
class VpnTunnelService:
    """Orchestrate WireGuard reverse tunnels, at most one per agent.

    For each agent the service:

    * allocates a /32 virtual IP from the configured peer network;
    * generates a fresh X25519 client key pair;
    * issues an orchestration token, signed when a ``signer`` is supplied;
    * pushes the current peer set to the WireGuard listener;
    * applies per-peer firewall rules;
    * records start/stop events in ``activity_history`` and over Socket.IO.

    A daemon thread tears down tunnels that have had no activity for
    ``idle_seconds``.

    Locking:
        ``_lock`` guards both session maps and the session fields mutated by
        ``connect`` / ``bump_activity``. ``_listener_lock`` serializes listener
        refreshes so two concurrent refreshes cannot apply peer snapshots out
        of order.

        Lock order is always ``_listener_lock`` -> ``_lock``. Therefore
        ``_refresh_listener`` must never be called while ``_lock`` is held.
    """

    # Lifetime of the orchestration token issued on connect, in seconds.
    _TOKEN_TTL_SECONDS = 300
    # How often the idle reaper wakes up, in seconds.
    _IDLE_POLL_SECONDS = 10
    # Ports kept from the Windows allowlist when a device reports a non-Windows OS.
    _NON_WINDOWS_BASELINE_PORTS = frozenset({5900, 3478})
    # script_name recorded in activity_history rows.
    _ACTIVITY_SCRIPT_NAME = "Reverse VPN Tunnel (WireGuard)"

    def __init__(
        self,
        *,
        context: Any,
        wireguard_manager: WireGuardServerManager,
        db_conn_factory,
        socketio,
        service_log,
        signer: Optional[Any] = None,
        idle_seconds: int = 900,
    ) -> None:
        """Wire up dependencies and start the idle-expiry thread.

        Args:
            context: Engine context. It must expose ``logger``,
                ``wireguard_engine_virtual_ip``, ``wireguard_peer_network``,
                ``wireguard_port`` and ``wireguard_acl_allowlist_windows``.
            wireguard_manager: Listener, peer-profile and firewall backend.
            db_conn_factory: Zero-argument callable returning a DB-API
                connection, or ``None`` to disable DB access. The SQL uses
                ``?`` placeholders (sqlite-style, presumably; confirm).
            socketio: Object with an ``emit`` method, or a falsy value to
                disable events.
            service_log: Stored as-is; not used by this class directly.
            signer: Optional object with ``sign(bytes) -> bytes`` (and
                optionally ``public_base64_spki()``) used to sign tokens.
            idle_seconds: Inactivity timeout. Values below 60 are clamped
                to 60.
        """
        self.context = context
        self.wg = wireguard_manager
        self.db_conn_factory = db_conn_factory
        self.socketio = socketio
        self.service_log = service_log
        self.signer = signer
        self.logger = context.logger.getChild("vpn_tunnel")
        self.activity_logger = self.wg.logger.getChild("device_activity")
        self.idle_seconds = max(60, int(idle_seconds))
        self._lock = threading.Lock()
        self._listener_lock = threading.Lock()
        self._sessions_by_agent: Dict[str, VpnSession] = {}
        self._sessions_by_tunnel: Dict[str, VpnSession] = {}
        self._engine_ip = ipaddress.ip_interface(context.wireguard_engine_virtual_ip)
        self._peer_network = ipaddress.ip_network(context.wireguard_peer_network, strict=False)
        self._stop_event = threading.Event()
        self._idle_thread = threading.Thread(
            target=self._idle_loop,
            name="vpn-tunnel-idle-reaper",
            daemon=True,
        )
        self._idle_thread.start()

    def shutdown(self, timeout: Optional[float] = None) -> None:
        """Stop the idle-expiry thread.

        Active tunnels are left in place. Calling this more than once is safe.
        """
        self._stop_event.set()
        if self._idle_thread is not threading.current_thread():
            self._idle_thread.join(timeout)

    # ------------------------------------------------------------------ #
    # Idle handling
    # ------------------------------------------------------------------ #

    def _idle_loop(self) -> None:
        """Periodically expire idle tunnels until ``shutdown`` is called."""
        # Event.wait returns False on timeout and True once shutdown() sets it.
        while not self._stop_event.wait(self._IDLE_POLL_SECONDS):
            try:
                self._expire_idle_sessions(time.time())
            except Exception:
                # Keep the reaper alive: one failed sweep must not stop all future expiry.
                self.logger.warning("VPN idle sweep failed.", exc_info=True)

    def _expire_idle_sessions(self, now: float) -> List[str]:
        """Tear down every session idle since ``now - idle_seconds``.

        Sessions are detached under the lock in the same step that checks
        their activity. A ``bump_activity`` racing with the sweep therefore
        either wins (and the session survives) or loses cleanly.

        Returns:
            Agent ids whose tunnels were expired.
        """
        with self._lock:
            expired = [
                session
                for session in self._sessions_by_agent.values()
                if session.last_activity + self.idle_seconds <= now
            ]
            for session in expired:
                self._detach_locked(session)
        for session in expired:
            try:
                self._teardown_session(session, "idle_timeout")
            except Exception:
                self.logger.warning(
                    "Idle teardown failed for tunnel_id=%s", session.tunnel_id, exc_info=True
                )
        return [session.agent_id for session in expired]

    # ------------------------------------------------------------------ #
    # Allocation helpers
    # ------------------------------------------------------------------ #

    def _allocate_virtual_ip(self, agent_id: str) -> str:
        """Return a free ``"<ip>/32"`` from the peer network, skipping the engine's IP.

        The caller must hold ``_lock``. An agent that already has a session
        keeps its current address.

        Raises:
            RuntimeError: ``"vpn_ip_pool_exhausted"`` when no address is free.
        """
        existing = self._sessions_by_agent.get(agent_id)
        if existing:
            return existing.virtual_ip

        used = {s.virtual_ip for s in self._sessions_by_agent.values()}
        for host in self._peer_network.hosts():
            if host == self._engine_ip.ip:
                continue
            candidate = f"{host}/32"
            if candidate not in used:
                return candidate
        raise RuntimeError("vpn_ip_pool_exhausted")

    @staticmethod
    def _parse_port_list(raw: Any) -> Tuple[int, ...]:
        """Parse a JSON-encoded port list.

        Anything that is not a JSON list yields ``()``. Entries that are not
        int/float/str (or are booleans), that cannot be converted with
        ``int()``, or that fall outside 1-65535 are skipped. Order is kept
        and duplicates are dropped.

        Raises:
            ValueError: when ``raw`` is not valid JSON (``json.JSONDecodeError``).
        """
        if not raw:
            return ()
        decoded = json.loads(raw)
        # A JSON string such as "5900" must not be iterated char by char.
        if not isinstance(decoded, list):
            return ()
        ports: List[int] = []
        for item in decoded:
            if isinstance(item, bool) or not isinstance(item, (int, float, str)):
                continue
            try:
                port = int(item)
            except (TypeError, ValueError, OverflowError):
                continue
            if 1 <= port <= 65535:
                ports.append(port)
        return tuple(dict.fromkeys(ports))

    def _load_allowed_ports(self, agent_id: str) -> Tuple[int, ...]:
        """Return the ports the tunnel ACL allows for ``agent_id``.

        Resolution order:

        1. Start from ``context.wireguard_acl_allowlist_windows``.
        2. If the device's most recent ``devices`` row reports a non-Windows
           OS, narrow that default to the non-Windows baseline ports (only if
           they overlap).
        3. A per-device JSON list in ``device_vpn_config.allowed_ports``
           overrides the default; invalid entries in it are skipped.

        Any database error falls back to the (possibly narrowed) default.
        """
        default = tuple(self.context.wireguard_acl_allowlist_windows or ())
        conn = None
        try:
            conn = self.db_conn_factory()
            cur = conn.cursor()
            try:
                cur.execute(
                    "SELECT operating_system FROM devices WHERE agent_id=? ORDER BY last_seen DESC LIMIT 1",
                    (agent_id,),
                )
                row = cur.fetchone()
                os_name = str(row[0]).lower() if row and row[0] else ""
            except Exception:
                os_name = ""
            if os_name and "windows" not in os_name:
                narrowed = tuple(p for p in default if p in self._NON_WINDOWS_BASELINE_PORTS)
                if narrowed:
                    default = narrowed
            cur.execute(
                "SELECT allowed_ports FROM device_vpn_config WHERE agent_id=?",
                (agent_id,),
            )
            row = cur.fetchone()
            if not row:
                return default
            return self._parse_port_list(row[0]) or default
        except Exception:
            return default
        finally:
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

    def _generate_client_keys(self) -> Tuple[str, str]:
        """Generate an X25519 key pair; return ``(private_b64, public_b64)`` of the raw bytes."""
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric import x25519

        key = x25519.X25519PrivateKey.generate()
        priv = base64.b64encode(
            key.private_bytes(
                encoding=serialization.Encoding.Raw,
                format=serialization.PrivateFormat.Raw,
                encryption_algorithm=serialization.NoEncryption(),
            )
        ).decode("ascii").strip()
        pub = base64.b64encode(
            key.public_key().public_bytes(
                encoding=serialization.Encoding.Raw,
                format=serialization.PublicFormat.Raw,
            )
        ).decode("ascii").strip()
        return priv, pub

    def _issue_token(self, agent_id: str, tunnel_id: str, expires_at: float) -> Dict[str, Any]:
        """Build the orchestration token, signing it when a signer is configured.

        The signature covers the canonical JSON of the payload (sorted keys,
        compact separators). If signing fails, the unsigned token is returned
        and the failure is logged at debug level.
        """
        payload = {
            "agent_id": agent_id,
            "tunnel_id": tunnel_id,
            "port": self.context.wireguard_port,
            "expires_at": expires_at,
            "issued_at": time.time(),
        }
        if not self.signer:
            return dict(payload)

        token = dict(payload)
        try:
            payload_bytes = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
            signature = self.signer.sign(payload_bytes)
            token["signature"] = base64.b64encode(signature).decode("ascii")
            if hasattr(self.signer, "public_base64_spki"):
                token["signing_key"] = self.signer.public_base64_spki()
            token["sig_alg"] = "ed25519"
        except Exception:
            self.logger.debug("Failed to sign VPN orchestration token; sending unsigned.", exc_info=True)
        return token

    def _refresh_listener(self) -> None:
        """Push the current peer set to WireGuard, or stop the listener when it is empty.

        Must not be called while holding ``_lock``.
        """
        with self._listener_lock:
            # Snapshot under the lock: other threads mutate the map concurrently.
            with self._lock:
                sessions = list(self._sessions_by_agent.values())
            peers: List[Mapping[str, object]] = []
            for session in sessions:
                peer = dict(
                    self.wg.build_peer_profile(
                        session.agent_id,
                        session.virtual_ip,
                        allowed_ports=session.allowed_ports,
                    )
                )
                peer["public_key"] = session.client_public_key
                peers.append(peer)
            if not peers:
                self.wg.stop_listener()
                return
            self.wg.start_listener(peers)

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def connect(self, *, agent_id: str, operator_id: Optional[str]) -> Mapping[str, Any]:
        """Return tunnel parameters for ``agent_id``, creating the tunnel if needed.

        If a tunnel already exists, ``operator_id`` (when given) is attached,
        its activity is bumped, and its existing payload is returned.
        Otherwise a new session is registered, the listener and firewall are
        configured, and a ``vpn_tunnel_start`` event is emitted.

        Raises:
            RuntimeError: ``"vpn_ip_pool_exhausted"`` when no address is free.
            Exception: any error from the WireGuard manager. If listener or
                firewall setup fails, the new session is rolled back before
                re-raising.
        """
        now = time.time()
        with self._lock:
            existing = self._sessions_by_agent.get(agent_id)
            if existing is not None:
                if operator_id:
                    existing.operator_ids.add(operator_id)
                existing.last_activity = now
                return self._session_payload(existing)

            tunnel_id = uuid.uuid4().hex
            virtual_ip = self._allocate_virtual_ip(agent_id)
            allowed_ports = self._load_allowed_ports(agent_id)
            client_private, client_public = self._generate_client_keys()
            expires_at = now + self._TOKEN_TTL_SECONDS
            token = self._issue_token(agent_id, tunnel_id, expires_at)
            self.wg.require_orchestration_token(token)

            session = VpnSession(
                tunnel_id=tunnel_id,
                agent_id=agent_id,
                virtual_ip=virtual_ip,
                token=token,
                client_public_key=client_public,
                client_private_key=client_private,
                allowed_ports=allowed_ports,
                created_at=now,
                expires_at=expires_at,
                last_activity=now,
            )
            if operator_id:
                session.operator_ids.add(operator_id)
            self._sessions_by_agent[agent_id] = session
            self._sessions_by_tunnel[tunnel_id] = session

        try:
            self._refresh_listener()
            peer = self.wg.build_peer_profile(
                agent_id,
                virtual_ip,
                allowed_ports=allowed_ports,
            )
            session.firewall_rules = self.wg.apply_firewall_rules(peer)
        except Exception:
            with self._lock:
                # Only remove the session this call created; never a newer one.
                if self._sessions_by_agent.get(agent_id) is session:
                    self._sessions_by_agent.pop(agent_id, None)
                self._sessions_by_tunnel.pop(tunnel_id, None)
            try:
                self._refresh_listener()
            except Exception:
                self.logger.debug("Failed to refresh WireGuard listener after connect rollback.", exc_info=True)
            raise

        payload = self._session_payload(session)
        self._emit_start(payload)
        self._log_device_activity(session, event="start")
        return payload

    def status(self, agent_id: str) -> Optional[Mapping[str, Any]]:
        """Return the tunnel payload (without token) for ``agent_id``, or ``None`` if no tunnel exists."""
        with self._lock:
            session = self._sessions_by_agent.get(agent_id)
            if not session:
                return None
            return self._session_payload(session, include_token=False)

    def bump_activity(self, agent_id: str) -> None:
        """Mark the agent's tunnel as active now and emit ``vpn_tunnel_activity``.

        Does nothing if the agent has no tunnel.
        """
        with self._lock:
            session = self._sessions_by_agent.get(agent_id)
            if not session:
                return
            session.last_activity = time.time()
        try:
            if self.socketio:
                self.socketio.emit("vpn_tunnel_activity", {"agent_id": agent_id}, namespace="/")
        except Exception:
            self.logger.debug("vpn_tunnel_activity emit failed for agent_id=%s", agent_id, exc_info=True)

    def disconnect(self, agent_id: str, reason: str = "operator_stop") -> bool:
        """Tear down the agent's tunnel.

        Returns:
            True if a tunnel was removed, False if the agent had none.
        """
        with self._lock:
            session = self._sessions_by_agent.get(agent_id)
            if session is None:
                return False
            self._detach_locked(session)
        self._teardown_session(session, reason)
        return True

    def disconnect_by_tunnel(self, tunnel_id: str, reason: str = "operator_stop") -> bool:
        """Tear down the tunnel identified by ``tunnel_id``; return False if unknown."""
        with self._lock:
            session = self._sessions_by_tunnel.get(tunnel_id)
            if not session:
                return False
        return self.disconnect(session.agent_id, reason=reason)

    # ------------------------------------------------------------------ #
    # Teardown helpers
    # ------------------------------------------------------------------ #

    def _detach_locked(self, session: VpnSession) -> None:
        """Remove ``session`` from both lookup maps. The caller must hold ``_lock``."""
        self._sessions_by_agent.pop(session.agent_id, None)
        self._sessions_by_tunnel.pop(session.tunnel_id, None)

    def _teardown_session(self, session: VpnSession, reason: str) -> None:
        """Best-effort cleanup for an already-detached session.

        Firewall and listener failures are logged, and the stop event and
        activity record are still emitted.
        """
        try:
            self.wg.remove_firewall_rules(session.firewall_rules)
        except Exception:
            self.logger.debug("Failed to remove firewall rules for agent=%s", session.agent_id, exc_info=True)
        try:
            self._refresh_listener()
        except Exception:
            self.logger.warning(
                "Failed to refresh WireGuard listener after stopping tunnel_id=%s",
                session.tunnel_id,
                exc_info=True,
            )
        self._emit_stop(session, reason)
        self._log_device_activity(session, event="stop", reason=reason)

    # ------------------------------------------------------------------ #
    # Events and activity history
    # ------------------------------------------------------------------ #

    def _emit_start(self, payload: Mapping[str, Any]) -> None:
        """Emit ``vpn_tunnel_start`` with the session payload (best effort)."""
        if not self.socketio:
            return
        try:
            self.socketio.emit("vpn_tunnel_start", payload, namespace="/")
        except Exception:
            self.logger.debug("vpn_tunnel_start emit failed", exc_info=True)

    def _emit_stop(self, session: VpnSession, reason: str) -> None:
        """Emit ``vpn_tunnel_stop`` for ``session`` (best effort)."""
        if not self.socketio:
            return
        try:
            self.socketio.emit(
                "vpn_tunnel_stop",
                {"agent_id": session.agent_id, "tunnel_id": session.tunnel_id, "reason": reason},
                namespace="/",
            )
        except Exception:
            self.logger.debug("vpn_tunnel_stop emit failed", exc_info=True)

    @staticmethod
    def _operator_label(session: VpnSession) -> str:
        """Return a sorted, comma-joined list of non-empty operator ids, or ``"-"``."""
        return ",".join(sorted(filter(None, session.operator_ids))) or "-"

    def _emit_activity_changed(self, hostname: str, activity_id: Optional[int], change: str) -> None:
        """Emit ``device_activity_changed`` (best effort; failures are logged at debug)."""
        if not self.socketio:
            return
        try:
            self.socketio.emit(
                "device_activity_changed",
                {
                    "hostname": hostname,
                    "activity_id": activity_id,
                    "change": change,
                    "source": "vpn_tunnel",
                },
            )
        except Exception:
            self.activity_logger.debug(
                "device_activity_changed emit failed for activity_id=%s", activity_id, exc_info=True
            )

    def _resolve_hostname(self, cur: Any, session: VpnSession) -> Optional[str]:
        """Return the session's cached hostname, or look it up in ``devices``.

        A successful lookup is cached on the session. Lookup errors yield
        ``None``.
        """
        if session.hostname:
            return session.hostname
        try:
            cur.execute(
                "SELECT hostname FROM devices WHERE agent_id = ? ORDER BY last_seen DESC LIMIT 1",
                (session.agent_id,),
            )
            row = cur.fetchone()
            if row and row[0]:
                hostname = str(row[0]).strip()
                session.hostname = hostname
                return hostname
        except Exception:
            return None
        return session.hostname

    def _record_activity_start(self, conn: Any, cur: Any, session: VpnSession, hostname: str) -> None:
        """Insert a "Running" activity_history row and remember its id on the session."""
        cur.execute(
            """
            INSERT INTO activity_history(hostname, script_path, script_name, script_type, ran_at, status, stdout, stderr)
            VALUES(?,?,?,?,?,?,?,?)
            """,
            (
                hostname,
                session.tunnel_id,
                self._ACTIVITY_SCRIPT_NAME,
                "vpn_tunnel",
                int(time.time()),
                "Running",
                "",
                "",
            ),
        )
        session.activity_id = cur.lastrowid
        conn.commit()
        self._emit_activity_changed(hostname, session.activity_id, "created")
        self.activity_logger.info(
            "device_activity_start hostname=%s agent_id=%s tunnel_id=%s operator=%s activity_id=%s",
            hostname,
            session.agent_id,
            session.tunnel_id,
            self._operator_label(session),
            session.activity_id or "-",
        )

    def _record_activity_end(
        self,
        conn: Any,
        cur: Any,
        session: VpnSession,
        hostname: str,
        event: str,
        reason: Optional[str],
    ) -> None:
        """Close the session's activity row (if recorded), then log the event.

        The status becomes "Completed" for ``"stop"`` and "Closed" for any
        other event. The reason is appended to the row's stderr.
        """
        if session.activity_id:
            status = "Completed" if event == "stop" else "Closed"
            cur.execute(
                """
                UPDATE activity_history
                SET status=?,
                    stderr=COALESCE(stderr, '') || ?
                WHERE id=?
                """,
                (
                    status,
                    f"\nreason: {reason}" if reason else "",
                    session.activity_id,
                ),
            )
            conn.commit()
            self._emit_activity_changed(hostname, session.activity_id, "updated")

        self.activity_logger.info(
            "device_activity event=%s hostname=%s agent_id=%s tunnel_id=%s operator=%s reason=%s activity_id=%s",
            event,
            hostname,
            session.agent_id,
            session.tunnel_id,
            self._operator_label(session),
            reason or "-",
            session.activity_id or "-",
        )

    def _log_device_activity(self, session: VpnSession, *, event: str, reason: Optional[str] = None) -> None:
        """Record a tunnel start/stop in ``activity_history`` (best effort, never raises).

        Without a DB factory, or when the device hostname cannot be resolved,
        only a log line is written. Errors are logged at debug level.
        """
        if self.db_conn_factory is None:
            self.activity_logger.info(
                "device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s",
                event,
                session.agent_id,
                session.tunnel_id,
                self._operator_label(session),
                reason or "-",
            )
            return

        conn = None
        try:
            conn = self.db_conn_factory()
            cur = conn.cursor()
            hostname = self._resolve_hostname(cur, session)
            if not hostname:
                self.activity_logger.info(
                    "device_activity event=%s agent_id=%s tunnel_id=%s operator=%s reason=%s hostname=unknown",
                    event,
                    session.agent_id,
                    session.tunnel_id,
                    self._operator_label(session),
                    reason or "-",
                )
                return
            if event == "start":
                self._record_activity_start(conn, cur, session, hostname)
            else:
                self._record_activity_end(conn, cur, session, hostname, event, reason)
        except Exception:
            self.activity_logger.debug(
                "device_activity logging failed for tunnel_id=%s",
                session.tunnel_id,
                exc_info=True,
            )
        finally:
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

    def _session_payload(self, session: VpnSession, *, include_token: bool = True) -> Mapping[str, Any]:
        """Build the client-facing tunnel description; the token is included unless ``include_token`` is False."""
        payload: Dict[str, Any] = {
            "tunnel_id": session.tunnel_id,
            "agent_id": session.agent_id,
            "virtual_ip": session.virtual_ip,
            "engine_virtual_ip": str(self._engine_ip.ip),
            "allowed_ips": f"{self._engine_ip.ip}/32",
            "endpoint": f"{self._engine_ip.ip}:{self.context.wireguard_port}",
            "server_public_key": self.wg.server_public_key,
            "client_public_key": session.client_public_key,
            "client_private_key": session.client_private_key,
            "idle_seconds": self.idle_seconds,
            "allowed_ports": list(session.allowed_ports),
            "connected_operators": len([o for o in session.operator_ids if o]),
        }
        if include_token:
            payload["token"] = session.token
        return payload
|