Files
Borealis-Github-Replica/tests/test_enrollment_install_codes.py

214 lines
7.4 KiB
Python

import base64
import os
import pathlib
import sqlite3
import sys
import uuid
import pytest
# Optional third-party dependencies. A failed import is recorded rather than
# raised so the tests can skip with an informative reason instead of the whole
# module erroring at collection time.
try:  # pragma: no cover - optional dependency
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import ed25519
except Exception as exc:  # pragma: no cover - dependency unavailable
    serialization = None  # type: ignore
    ed25519 = None  # type: ignore
    _CRYPTO_IMPORT_ERROR = exc
else:
    _CRYPTO_IMPORT_ERROR: Exception | None = None

try:  # pragma: no cover - optional dependency
    from flask import Flask
except Exception as exc:  # pragma: no cover - dependency unavailable
    Flask = None  # type: ignore
    _FLASK_IMPORT_ERROR = exc
else:
    _FLASK_IMPORT_ERROR: Exception | None = None

# Put the repository root (the parent of this tests/ directory) on sys.path so
# the ``Data.Server`` package imports below resolve regardless of where pytest
# was launched from.
ROOT = pathlib.Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from Data.Server.Modules import db_migrations
from Data.Server.Modules.auth.rate_limit import SlidingWindowRateLimiter
from Data.Server.Modules.enrollment.nonce_store import NonceCache

# The enrollment routes are only importable when Flask is; otherwise leave a
# ``None`` placeholder that ``_make_app`` checks before building an app.
enrollment_routes = None  # type: ignore
if Flask is not None:  # pragma: no cover - depends on flask availability
    from Data.Server.Modules.enrollment import routes as enrollment_routes
class _DummyJWTService:
    """JWT-service test double that hands out a predictable access token."""

    def issue_access_token(self, guid: str, fingerprint: str, token_version: int, expires_in: int = 900, extra_claims=None):
        # Only the GUID shapes the token; the remaining parameters exist so the
        # call signature matches what the enrollment routes pass in.
        return "token-{}".format(guid)
class _DummySigner:
    """Script-signer test double that advertises an empty public key."""

    _EMPTY_SPKI = ""

    def public_base64_spki(self) -> str:
        return self._EMPTY_SPKI
def _make_app(db_path: str, tls_path: str):
    """Build a Flask app with the enrollment routes wired to test doubles.

    Skips the calling test when Flask (or the enrollment routes module) could
    not be imported.

    Args:
        db_path: Filesystem path of the SQLite database the routes should use.
        tls_path: Path passed to the routes as ``tls_bundle_path``.

    Returns:
        ``(app, factory)`` where ``factory()`` opens a new SQLite connection to
        ``db_path`` whose rows are ``sqlite3.Row`` objects.
    """
    if Flask is None or enrollment_routes is None:  # pragma: no cover - dependency unavailable
        pytest.skip(f"flask unavailable: {_FLASK_IMPORT_ERROR}")

    flask_app = Flask(__name__)

    def open_connection():
        # Every call yields a fresh connection; nothing here closes it.
        connection = sqlite3.connect(db_path)
        connection.row_factory = sqlite3.Row
        return connection

    services = {
        "db_conn_factory": open_connection,
        "log": lambda channel, message: None,  # discard route log output
        "jwt_service": _DummyJWTService(),
        "tls_bundle_path": tls_path,
        "ip_rate_limiter": SlidingWindowRateLimiter(),
        "fp_rate_limiter": SlidingWindowRateLimiter(),
        "nonce_cache": NonceCache(ttl_seconds=30.0),
        "script_signer": _DummySigner(),
    }
    enrollment_routes.register(flask_app, **services)
    return flask_app, open_connection
def _create_install_code(conn: sqlite3.Connection, code: str, *, max_uses: int = 2):
    """Insert a fresh install code and commit it.

    The new row expires six hours after SQLite's current time, is attributed
    to ``'test-user'``, allows ``max_uses`` enrollments and starts with a use
    count of zero.

    Returns:
        The generated UUID4 string used as the row's ``id``.
    """
    new_id = f"{uuid.uuid4()}"
    conn.execute(
        """
        INSERT INTO enrollment_install_codes (
            id, code, expires_at, created_by_user_id, max_uses, use_count
        )
        VALUES (?, ?, datetime('now', '+6 hours'), 'test-user', ?, 0)
        """,
        (new_id, code, max_uses),
    )
    conn.commit()
    return new_id
def _perform_enrollment_cycle(app, factory, code: str, private_key):
    """Drive one complete enrollment round-trip against the test app.

    1. POST ``/api/agent/enroll/request`` with ``code`` and the public half of
       ``private_key``.
    2. Mark the resulting ``device_approvals`` row as approved directly in the
       database and read back the server nonce.
    3. POST ``/api/agent/enroll/poll`` with a signature over
       ``server_nonce + approval_reference + client_nonce``.

    Args:
        app: Flask app produced by ``_make_app``.
        factory: Connection factory produced by ``_make_app``.
        code: Install code to enroll with.
        private_key: Ed25519 private key acting as the agent identity.

    Returns:
        The JSON body of the (HTTP 200) poll response.
    """
    client = app.test_client()
    public_der = private_key.public_key().public_bytes(
        serialization.Encoding.DER,
        serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    client_nonce = os.urandom(32)
    client_nonce_b64 = base64.b64encode(client_nonce).decode("ascii")
    payload = {
        "hostname": "unit-test-host",
        "enrollment_code": code,
        "agent_pubkey": base64.b64encode(public_der).decode("ascii"),
        "client_nonce": client_nonce_b64,
    }
    request_resp = client.post("/api/agent/enroll/request", json=payload)
    assert request_resp.status_code == 200, request_resp.get_data(as_text=True)
    approval_reference = request_resp.get_json()["approval_reference"]

    # ``with conn`` only commits/rolls back the transaction; it never closes
    # the connection, so close it explicitly to avoid leaking handles.
    conn = factory()
    try:
        with conn:  # commit the approval before the poll below observes it
            cur = conn.cursor()
            cur.execute(
                """
                UPDATE device_approvals
                SET status = 'approved',
                    approved_by_user_id = 'tester'
                WHERE approval_reference = ?
                """,
                (approval_reference,),
            )
            # Fail here, with context, rather than at the poll step.
            assert cur.rowcount == 1, f"no device_approvals row for {approval_reference!r}"
            cur.execute(
                """
                SELECT server_nonce
                FROM device_approvals
                WHERE approval_reference = ?
                """,
                (approval_reference,),
            )
            row = cur.fetchone()
            assert row is not None
            server_nonce_b64 = row["server_nonce"]
    finally:
        conn.close()

    server_nonce = base64.b64decode(server_nonce_b64)
    # Proof of possession: sign server nonce || approval reference || client nonce.
    proof_message = server_nonce + approval_reference.encode("utf-8") + client_nonce
    proof_sig = private_key.sign(proof_message)
    poll_payload = {
        "approval_reference": approval_reference,
        "client_nonce": client_nonce_b64,
        "proof_sig": base64.b64encode(proof_sig).decode("ascii"),
    }
    poll_resp = client.post("/api/agent/enroll/poll", json=poll_payload)
    assert poll_resp.status_code == 200, poll_resp.get_data(as_text=True)
    return poll_resp.get_json()
def _install_code_row(factory, code: str):
    """Fetch the ``enrollment_install_codes`` row for ``code`` and close the connection.

    ``with sqlite3.Connection`` only manages the transaction, so the connection
    is closed explicitly here instead.
    """
    conn = factory()
    try:
        row = conn.execute(
            "SELECT use_count, max_uses, used_at, last_used_at, used_by_guid "
            "FROM enrollment_install_codes WHERE code = ?",
            (code,),
        ).fetchone()
    finally:
        conn.close()
    assert row is not None, f"install code {code!r} not found"
    return row


@pytest.mark.parametrize("max_uses", [2])
@pytest.mark.skipif(ed25519 is None, reason=f"cryptography unavailable: {_CRYPTO_IMPORT_ERROR}")
@pytest.mark.skipif(Flask is None, reason=f"flask unavailable: {_FLASK_IMPORT_ERROR}")
def test_install_code_allows_multiple_and_reuse(tmp_path, max_uses):
    """An install code permits ``max_uses`` enrollments, then only same-identity reuse.

    Each enrollment up to ``max_uses`` increments ``use_count``; the final one
    marks the code consumed (``used_at``/``used_by_guid``). A further enrollment
    from the same key still succeeds, keeps the stored GUID and does not create
    a second device row.
    """
    assert max_uses >= 1, "scenario needs at least one permitted use"
    code = "TEST-CODE-1234"

    db_path = tmp_path / "test.db"
    conn = sqlite3.connect(db_path)
    try:
        db_migrations.apply_all(conn)
        _create_install_code(conn, code, max_uses=max_uses)
    finally:
        conn.close()

    tls_path = tmp_path / "tls.pem"
    tls_path.write_text("TEST CERT")
    app, factory = _make_app(str(db_path), str(tls_path))
    private_key = ed25519.Ed25519PrivateKey.generate()

    # Every permitted enrollment consumes one use; the code stays active
    # (used_at unset) until the last permitted use marks it consumed.
    for use in range(1, max_uses + 1):
        result = _perform_enrollment_cycle(app, factory, code, private_key)
        assert result["status"] == "approved"
        row = _install_code_row(factory, code)
        assert row["use_count"] == use
        assert row["max_uses"] == max_uses
        assert row["last_used_at"] is not None
        if use < max_uses:
            assert row["used_at"] is None
        else:
            assert row["used_at"] is not None
    consumed_guid = row["used_by_guid"]
    assert consumed_guid

    # Additional enrollments from the same identity reuse the stored GUID even after consumption.
    extra = _perform_enrollment_cycle(app, factory, code, private_key)
    assert extra["status"] == "approved"
    row = _install_code_row(factory, code)
    assert row["use_count"] == max_uses + 1
    assert row["used_by_guid"] == consumed_guid
    assert row["used_at"] is not None
    assert row["last_used_at"] is not None

    conn = factory()
    try:
        (device_count,) = conn.execute(
            "SELECT COUNT(*) FROM devices WHERE guid = ?", (consumed_guid,)
        ).fetchone()
    finally:
        conn.close()
    assert device_count == 1