diff --git a/app/__init__.py b/app/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/auth.py b/app/auth.py
new file mode 100644
index 0000000..ff4eac3
--- /dev/null
+++ b/app/auth.py
@@ -0,0 +1,120 @@
+"""Authentication helpers: password hashing, signed auth cookie, CSRF cookie and access checks."""
+import os
+import secrets
+from typing import Optional
+
+from fastapi import Depends, HTTPException, Request, status
+from fastapi.responses import RedirectResponse
+from itsdangerous import BadSignature, URLSafeTimedSerializer
+from passlib.context import CryptContext
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from config import COOKIE_MAX_AGE, COOKIE_NAME, CSRF_COOKIE
+from database import get_db
+from models import User, UserServiceAccess
+from utils import now_utc
+
+# NOTE(review): when SIGNING_KEY is unset a fresh random key is generated at
+# import time, so tokens signed by one process will not verify in another
+# process or after a restart -- confirm SIGNING_KEY is always set in deployment.
+_SIGNING_KEY = os.getenv("SIGNING_KEY", secrets.token_urlsafe(32))
+serializer = URLSafeTimedSerializer(_SIGNING_KEY, salt="portal-auth")
+pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
+
+
+def hash_password(password: str) -> str:
+    """Return the hash of *password* produced by ``pwd_context``."""
+    return pwd_context.hash(password)
+
+
+def verify_password(password: str, password_hash: str) -> bool:
+    """Return whether *password* matches *password_hash* per ``pwd_context``."""
+    return pwd_context.verify(password, password_hash)
+
+
+def user_is_valid(user: User) -> bool:
+    """Return True when the user is active and ``expires_at`` is still in the future."""
+    return bool(user.active and user.expires_at > now_utc())
+
+
+def issue_auth_cookie(response: RedirectResponse, user: User) -> None:
+    """Set the signed auth cookie carrying ``user.id`` on *response*."""
+    token = serializer.dumps({"user_id": user.id})
+    response.set_cookie(
+        key=COOKIE_NAME,
+        value=token,
+        httponly=True,
+        secure=True,
+        samesite="strict",
+        max_age=COOKIE_MAX_AGE,
+        path="/",
+    )
+
+
+def issue_csrf_cookie(response: RedirectResponse) -> str:
+    """Set a fresh random CSRF cookie on *response* and return the token."""
+    token = secrets.token_urlsafe(24)
+    response.set_cookie(
+        key=CSRF_COOKIE,
+        value=token,
+        httponly=False,
+        secure=True,
+        samesite="lax",
+        max_age=COOKIE_MAX_AGE,
+        path="/",
+    )
+    return token
+
+
+def get_current_user(request: Request, db: Session = Depends(get_db)) -> Optional[User]:
+    """Resolve the user from the auth cookie; return None if absent, invalid or expired."""
+    raw = request.cookies.get(COOKIE_NAME)
+    if not raw:
+        return None
+    try:
+        payload = serializer.loads(raw, max_age=COOKIE_MAX_AGE)
+        user_id = int(payload["user_id"])
+    except (BadSignature, KeyError, TypeError, ValueError):
+        # A validly signed token with an unexpected payload shape is treated
+        # as "not logged in" rather than surfacing as a server error.
+        return None
+    user = db.get(User, user_id)
+    if not user or not user_is_valid(user):
+        return None
+    return user
+
+
+def require_user(user: Optional[User] = Depends(get_current_user)) -> User:
+    """Dependency that raises 401 when no valid user is resolved."""
+    if not user:
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
+    return user
+
+
+def require_admin(user: User = Depends(require_user)) -> User:
+    """Dependency that raises 403 unless ``user.is_admin`` is truthy."""
+    if not user.is_admin:
+        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin only")
+    return user
+
+
+def validate_csrf(request: Request) -> None:
+    """Raise 403 unless the X-CSRF-Token header equals the CSRF cookie."""
+    cookie = request.cookies.get(CSRF_COOKIE)
+    form_val = request.headers.get("X-CSRF-Token")
+    # NOTE(review): urlencoded form posts skip the check entirely here --
+    # confirm the form token is validated elsewhere, otherwise this is a bypass.
+    if request.headers.get("content-type", "").startswith("application/x-www-form-urlencoded"):
+        return
+    if not cookie or not form_val or cookie != form_val:
+        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="CSRF failed")
+
+
+def has_access(db: Session, user_id: int, service_id: int) -> bool:
+    """Return True when a UserServiceAccess row links *user_id* to *service_id*."""
+    q = select(UserServiceAccess).where(
+        UserServiceAccess.user_id == user_id,
+        UserServiceAccess.service_id == service_id,
+    )
+    return db.scalar(q) is not None
diff --git a/app/config.py b/app/config.py
new file mode 100644
index 0000000..1e63b66
--- /dev/null
+++ b/app/config.py
@@ -0,0 +1,36 @@
+"""Configuration constants; most values can be overridden through environment variables."""
+import os
+from pathlib import Path
+
+DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+psycopg2://portal:portal@db:5432/portal")
+COOKIE_NAME = "portal_auth"
+CSRF_COOKIE = "csrf_token"
+COOKIE_MAX_AGE = 8 * 60 * 60
+SESSION_IDLE_SECONDS = int(os.getenv("SESSION_IDLE_SECONDS", "7200"))
+PUBLIC_HOST = os.getenv("PUBLIC_HOST", "stend.4mont.ru")
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
+LOG_SLOW_REQUEST_MS = int(os.getenv("LOG_SLOW_REQUEST_MS", "2000"))
+GO_USER_LOCK_TIMEOUT_SECONDS = float(os.getenv("GO_USER_LOCK_TIMEOUT_SECONDS", "8.0"))
+GO_POOL_LOCK_TIMEOUT_SECONDS = float(os.getenv("GO_POOL_LOCK_TIMEOUT_SECONDS", "20.0"))
+POOL_DISPATCH_RETRIES = int(os.getenv("POOL_DISPATCH_RETRIES", "6"))
+POOL_DISPATCH_REQUEST_TIMEOUT_SECONDS = float(os.getenv("POOL_DISPATCH_REQUEST_TIMEOUT_SECONDS", "2.0"))
+POOL_DISPATCH_SLEEP_SECONDS = float(os.getenv("POOL_DISPATCH_SLEEP_SECONDS", "0.3"))
+TRAEFIK_INTERNAL_URL = os.getenv("TRAEFIK_INTERNAL_URL", "http://traefik")
+PREWARM_POOL_SIZE = int(os.getenv("PREWARM_POOL_SIZE", "2"))
+UNIVERSAL_POOL_SIZE = int(os.getenv("UNIVERSAL_POOL_SIZE", "0"))
+WEB_POOL_SIZE = int(os.getenv("WEB_POOL_SIZE", "20"))
+WEB_POOL_BUFFER = int(os.getenv("WEB_POOL_BUFFER", "2"))
+X11VNC_FLAGS = os.getenv("X11VNC_FLAGS", "-wait 5 -defer 5 -threads")
+MAX_ACTIVE_SERVICES_PER_USER = int(os.getenv("MAX_ACTIVE_SERVICES_PER_USER", "4"))
+WEB_RESOLUTION_MIN_WIDTH = int(os.getenv("WEB_RESOLUTION_MIN_WIDTH", "1024"))
+WEB_RESOLUTION_MIN_HEIGHT = int(os.getenv("WEB_RESOLUTION_MIN_HEIGHT", "720"))
+WEB_RESOLUTION_MAX_WIDTH = int(os.getenv("WEB_RESOLUTION_MAX_WIDTH", "3840"))
+WEB_RESOLUTION_MAX_HEIGHT = int(os.getenv("WEB_RESOLUTION_MAX_HEIGHT", "2160"))
+ENABLE_STARTUP_MAINTENANCE = os.getenv("ENABLE_STARTUP_MAINTENANCE", "1") == "1"
+ICON_UPLOAD_MAX_BYTES = 2 * 1024 * 1024
+ICON_UPLOAD_TYPES = {
+    "image/png": "png",
+    "image/jpeg": "jpg",
+    "image/webp": "webp",
+}
+SERVICE_ICONS_DIR = Path("static/service-icons")
diff --git a/app/database.py b/app/database.py
new file mode 100644
index 0000000..50b34be
--- /dev/null
+++ b/app/database.py
@@ -0,0 +1,22 @@
+"""SQLAlchemy engine, session factory, declarative base and the session dependency."""
+from sqlalchemy import create_engine
+from sqlalchemy.orm import DeclarativeBase, sessionmaker
+
+from config import DATABASE_URL
+
+
+engine = create_engine(DATABASE_URL, pool_pre_ping=True)
+SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+def get_db():
+    """Yield a new session and always close it once the consumer is done."""
+    db = SessionLocal()
+    try:
+        yield db
+    finally:
+        db.close()
diff --git a/app/main.py b/app/main.py
index dd73004..07af960 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,107 +1,60 @@
 import datetime as dt
-import enum
-import fcntl -import json -import re import logging -import os -from pathlib import Path -import secrets -import threading -import time +import re import uuid +import time import contextvars -from urllib.parse import parse_qs, unquote, urlparse from typing import Optional -import docker -import requests -import mistune from fastapi import Depends, FastAPI, File, Form, HTTPException, Query, Request, UploadFile, status from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates -from itsdangerous import BadSignature, URLSafeTimedSerializer from markupsafe import Markup, escape -from passlib.context import CryptContext -from sqlalchemy import ( - Boolean, - DateTime, - Enum, - ForeignKey, - Integer, - String, - Text, - UniqueConstraint, - create_engine, - select, - text, +from sqlalchemy import select +from sqlalchemy.orm import Session +from starlette.responses import HTMLResponse as _HR + +from config import ( + COOKIE_NAME, CSRF_COOKIE, LOG_LEVEL, LOG_SLOW_REQUEST_MS, + MAX_ACTIVE_SERVICES_PER_USER, PUBLIC_HOST, SESSION_IDLE_SECONDS, + WEB_POOL_SIZE, ) -from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker - - -DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+psycopg2://portal:portal@db:5432/portal") -COOKIE_NAME = "portal_auth" -CSRF_COOKIE = "csrf_token" -COOKIE_MAX_AGE = 8 * 60 * 60 -SESSION_IDLE_SECONDS = int(os.getenv("SESSION_IDLE_SECONDS", "7200")) -PUBLIC_HOST = os.getenv("PUBLIC_HOST", "stend.4mont.ru") -LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() -LOG_SLOW_REQUEST_MS = int(os.getenv("LOG_SLOW_REQUEST_MS", "2000")) -GO_USER_LOCK_TIMEOUT_SECONDS = float(os.getenv("GO_USER_LOCK_TIMEOUT_SECONDS", "8.0")) -GO_POOL_LOCK_TIMEOUT_SECONDS = float(os.getenv("GO_POOL_LOCK_TIMEOUT_SECONDS", "20.0")) -POOL_DISPATCH_RETRIES = int(os.getenv("POOL_DISPATCH_RETRIES", "6")) -POOL_DISPATCH_REQUEST_TIMEOUT_SECONDS = 
float(os.getenv("POOL_DISPATCH_REQUEST_TIMEOUT_SECONDS", "2.0")) -POOL_DISPATCH_SLEEP_SECONDS = float(os.getenv("POOL_DISPATCH_SLEEP_SECONDS", "0.3")) -TRAEFIK_INTERNAL_URL = os.getenv("TRAEFIK_INTERNAL_URL", "http://traefik") -PREWARM_POOL_SIZE = int(os.getenv("PREWARM_POOL_SIZE", "2")) -UNIVERSAL_POOL_SIZE = int(os.getenv("UNIVERSAL_POOL_SIZE", "0")) -WEB_POOL_SIZE = int(os.getenv("WEB_POOL_SIZE", "20")) -WEB_POOL_BUFFER = int(os.getenv("WEB_POOL_BUFFER", "2")) -X11VNC_FLAGS = os.getenv("X11VNC_FLAGS", "-wait 5 -defer 5 -threads") -MAX_ACTIVE_SERVICES_PER_USER = int(os.getenv("MAX_ACTIVE_SERVICES_PER_USER", "4")) -WEB_RESOLUTION_MIN_WIDTH = int(os.getenv("WEB_RESOLUTION_MIN_WIDTH", "1024")) -WEB_RESOLUTION_MIN_HEIGHT = int(os.getenv("WEB_RESOLUTION_MIN_HEIGHT", "720")) -WEB_RESOLUTION_MAX_WIDTH = int(os.getenv("WEB_RESOLUTION_MAX_WIDTH", "3840")) -WEB_RESOLUTION_MAX_HEIGHT = int(os.getenv("WEB_RESOLUTION_MAX_HEIGHT", "2160")) -ENABLE_STARTUP_MAINTENANCE = os.getenv("ENABLE_STARTUP_MAINTENANCE", "1") == "1" -ICON_UPLOAD_MAX_BYTES = 2 * 1024 * 1024 -ICON_UPLOAD_TYPES = { - "image/png": "png", - "image/jpeg": "jpg", - "image/webp": "webp", -} -SERVICE_ICONS_DIR = Path("static/service-icons") +from database import get_db +from models import ( + AuditLog, Category, RdpSlot, Service, ServiceCategory, ServiceType, + SessionModel, SessionStatus, User, UserServiceAccess, +) +from utils import ( + audit, ensure_icons_dir, format_service_comment, log_event, normalize_web_target, + now_utc, parse_rdp_target, remove_icon_file, request_id_ctx, set_service_categories, + session_closed_reason, store_service_icon, +) +from auth import ( + get_current_user, has_access, issue_auth_cookie, issue_csrf_cookie, + require_admin, require_user, validate_csrf, verify_password, +) +from runtime import ( + acquire_universal_slot, acquire_web_pool_slot, allocator_lock, + container_running, create_runtime_container, desired_pool_size, + dispatch_universal_target, dispatch_web_pool_target, + 
ensure_warm_pool, ensure_web_pool, find_active_session_for_service, + find_active_session_for_user_service, get_active_sessions_count, + get_pool_detailed_status, get_pool_status_for_service, + get_web_pool_status, LockTimeoutError, open_warm_web_url, + _rdp_slot_container_name, sanitize_client_resolution, + service_uses_universal_pool, session_redirect_url, + start_rdp_slot_container, stop_rdp_slot_container, + stop_runtime_container, terminate_active_slot_sessions, + terminate_session_record, wait_for_session_route, +) +from maintenance import on_startup logging.basicConfig( level=LOG_LEVEL, format="%(asctime)s %(levelname)s %(name)s %(message)s", ) logger = logging.getLogger("portal") -request_id_ctx = contextvars.ContextVar("request_id", default="-") -maintenance_lock_file = None - - -def _normalize_log_value(value): - if isinstance(value, (str, int, float, bool)) or value is None: - return value - if isinstance(value, dt.datetime): - return value.isoformat() - return str(value) - - -def log_event(event: str, level: int = logging.INFO, **fields) -> None: - payload = {"event": event, "req_id": request_id_ctx.get()} - for key, value in fields.items(): - payload[key] = _normalize_log_value(value) - logger.log(level, json.dumps(payload, ensure_ascii=False, separators=(",", ":"))) - -SIGNING_KEY = os.getenv("SIGNING_KEY", secrets.token_urlsafe(32)) -serializer = URLSafeTimedSerializer(SIGNING_KEY, salt="portal-auth") -pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") - -engine = create_engine(DATABASE_URL, pool_pre_ping=True) -SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False) templates = Jinja2Templates(directory="templates") app = FastAPI(title="МОНТ - инфрастуктурный полигон") @@ -159,12 +112,9 @@ async def request_logging_middleware(request: Request, call_next): return response - -import re as _re_mob - -_MOBILE_UA_RE = _re_mob.compile( +_MOBILE_UA_RE = re.compile( 
r"(Mobile|Android|iPhone|iPad|iPod|BlackBerry|IEMobile|Opera Mini|webOS)", - _re_mob.IGNORECASE, + re.IGNORECASE, ) _MOBILE_PAGE = ( "" @@ -201,1621 +151,13 @@ async def mobile_block_middleware(request: Request, call_next): return await call_next(request) ua = request.headers.get("user-agent", "") if _MOBILE_UA_RE.search(ua): - from starlette.responses import HTMLResponse as _HR return _HR(content=_MOBILE_PAGE, status_code=200) return await call_next(request) -class Base(DeclarativeBase): - pass - - -class ServiceType(str, enum.Enum): - WEB = "WEB" - VNC = "VNC" - RDP = "RDP" - - -class SessionStatus(str, enum.Enum): - ACTIVE = "ACTIVE" - EXPIRED = "EXPIRED" - TERMINATED = "TERMINATED" - ROTATED = "ROTATED" - - -class User(Base): - __tablename__ = "users" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - username: Mapped[str] = mapped_column(String(64), unique=True, index=True) - password_hash: Mapped[str] = mapped_column(String(255)) - expires_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), index=True) - active: Mapped[bool] = mapped_column(Boolean, default=True, index=True) - is_admin: Mapped[bool] = mapped_column(Boolean, default=False) - created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc)) - - -class Service(Base): - __tablename__ = "services" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - name: Mapped[str] = mapped_column(String(128)) - slug: Mapped[str] = mapped_column(String(64), unique=True, index=True) - type: Mapped[ServiceType] = mapped_column(Enum(ServiceType), index=True) - target: Mapped[str] = mapped_column(Text) - comment: Mapped[str] = mapped_column(Text, default="") - svc_login: Mapped[str] = mapped_column(String(256), default="") - svc_password: Mapped[str] = mapped_column(String(256), default="") - svc_cred_hint: Mapped[str] = mapped_column(Text, default="") - icon_path: Mapped[str] = mapped_column(Text, default="") - active: 
Mapped[bool] = mapped_column(Boolean, default=True) - warm_pool_size: Mapped[int] = mapped_column(Integer, default=0) - created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc)) - - -class Category(Base): - __tablename__ = "categories" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - name: Mapped[str] = mapped_column(String(128), unique=True, index=True) - slug: Mapped[str] = mapped_column(String(64), unique=True, index=True) - created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc)) - - -class ServiceCategory(Base): - __tablename__ = "service_categories" - __table_args__ = (UniqueConstraint("service_id", "category_id", name="uq_service_category"),) - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - service_id: Mapped[int] = mapped_column(ForeignKey("services.id", ondelete="CASCADE"), index=True) - category_id: Mapped[int] = mapped_column(ForeignKey("categories.id", ondelete="CASCADE"), index=True) - created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc)) - - -class UserServiceAccess(Base): - __tablename__ = "user_service_access" - __table_args__ = (UniqueConstraint("user_id", "service_id", name="uq_user_service"),) - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True) - service_id: Mapped[int] = mapped_column(ForeignKey("services.id", ondelete="CASCADE"), index=True) - granted_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc)) - - -class RdpSlot(Base): - __tablename__ = "rdp_slots" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - service_id: Mapped[int] = mapped_column(ForeignKey("services.id", ondelete="CASCADE"), index=True) - rdp_username: Mapped[str] = 
mapped_column(String(128)) - rdp_password: Mapped[str] = mapped_column(String(256), default="") - container_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) - created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc)) - - -class SessionModel(Base): - __tablename__ = "sessions" - - id: Mapped[str] = mapped_column(String(36), primary_key=True) - user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True) - service_id: Mapped[int] = mapped_column(ForeignKey("services.id", ondelete="CASCADE"), index=True) - status: Mapped[SessionStatus] = mapped_column(Enum(SessionStatus), default=SessionStatus.ACTIVE, index=True) - created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc), index=True) - last_access_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc), index=True) - container_id: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) - - -class AuditLog(Base): - __tablename__ = "audit_logs" - - id: Mapped[int] = mapped_column(Integer, primary_key=True) - user_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, index=True) - action: Mapped[str] = mapped_column(String(128), index=True) - details: Mapped[str] = mapped_column(Text) - created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc), index=True) - - -def now_utc() -> dt.datetime: - return dt.datetime.now(dt.timezone.utc) - - -def session_closed_reason(sess: SessionModel, db: Session) -> str: - if not sess: - return "idle" - if sess.status == SessionStatus.EXPIRED: - return "idle" - if sess.status == SessionStatus.ROTATED: - return "limit" - if sess.status == SessionStatus.TERMINATED: - cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS) - active_rows = db.scalars( - 
select(SessionModel).where( - SessionModel.user_id == sess.user_id, - SessionModel.status == SessionStatus.ACTIVE, - SessionModel.last_access_at >= cutoff, - ) - ).all() - active_service_ids = {row.service_id for row in active_rows} - if len(active_service_ids) >= MAX_ACTIVE_SERVICES_PER_USER and sess.service_id not in active_service_ids: - return "limit" - return "idle" - - -def normalize_web_target(url: str) -> str: - raw = (url or "").strip() - if not raw: - return raw - if raw.startswith(("http://", "https://")): - return raw - return f"http://{raw}" - - -_md = mistune.create_markdown( - escape=True, - plugins=["strikethrough", "table", "task_lists"], -) - -def format_service_comment(raw_text: str) -> Markup: - raw = (raw_text or "").replace("\r\n", "\n").replace("\r", "\n").strip() - if not raw: - return Markup("") - return Markup(_md(raw)) - -def parse_rdp_target(target: str) -> dict: - raw = (target or "").strip() - if not raw: - raise HTTPException(status_code=400, detail="Empty RDP target") - - parsed = urlparse(raw if "://" in raw else f"//{raw}") - host = parsed.hostname - if not host: - raise HTTPException(status_code=400, detail="Invalid RDP target. Use host:port or rdp://user:pass@host:port") - port = parsed.port or 3389 - - username = unquote(parsed.username) if parsed.username else "" - password = unquote(parsed.password) if parsed.password else "" - - query = parse_qs(parsed.query or "") - if not username: - username = (query.get("u", [""])[0] or query.get("user", [""])[0] or "").strip() - if not password: - password = (query.get("p", [""])[0] or query.get("password", [""])[0] or "").strip() - - domain = (query.get("domain", [""])[0] or query.get("d", [""])[0] or "").strip() - security = (query.get("sec", [""])[0] or query.get("security", [""])[0] or "").strip().lower() - if security and security not in {"nla", "tls", "rdp"}: - raise HTTPException(status_code=400, detail="Invalid RDP security. 
Use one of: nla, tls, rdp") - - return { - "host": host, - "port": str(port), - "user": username, - "password": password, - "domain": domain, - "security": security, - } - - -def set_service_categories(db: Session, service_id: int, category_ids: list[int]) -> None: - normalized = sorted({int(x) for x in (category_ids or [])}) - if normalized: - existing_ids = set(db.scalars(select(Category.id).where(Category.id.in_(normalized))).all()) - missing = sorted(set(normalized) - existing_ids) - if missing: - raise HTTPException(status_code=400, detail=f"Unknown category ids: {missing}") - - existing_links = db.scalars(select(ServiceCategory).where(ServiceCategory.service_id == service_id)).all() - current = {row.category_id: row for row in existing_links} - wanted = set(normalized) - - for cat_id in wanted: - if cat_id not in current: - db.add(ServiceCategory(service_id=service_id, category_id=cat_id)) - for cat_id, row in current.items(): - if cat_id not in wanted: - db.delete(row) - - -def service_uses_universal_pool(service: Service) -> bool: - return UNIVERSAL_POOL_SIZE > 0 and service.type == ServiceType.RDP - - -def universal_container_name(slot: int) -> str: - return f"portal-universal-{slot}" - - -def web_pool_container_name(slot: int) -> str: - return f"portal-webpool-{slot}" - - -def ensure_icons_dir() -> None: - SERVICE_ICONS_DIR.mkdir(parents=True, exist_ok=True) - - -def remove_icon_file(icon_path: str) -> None: - if not icon_path or not icon_path.startswith("/static/service-icons/"): - return - filename = icon_path.rsplit("/", 1)[-1] - candidate = SERVICE_ICONS_DIR / filename - try: - candidate.unlink(missing_ok=True) - except Exception: - logger.exception("icon_delete_failed path=%s", candidate) - - -async def store_service_icon(service: Service, upload: UploadFile) -> str: - ensure_icons_dir() - content_type = (upload.content_type or "").lower().strip() - ext = ICON_UPLOAD_TYPES.get(content_type) - if not ext: - raise HTTPException(status_code=400, 
detail="Unsupported file type. Use PNG/JPG/WEBP") - - payload = await upload.read(ICON_UPLOAD_MAX_BYTES + 1) - if len(payload) > ICON_UPLOAD_MAX_BYTES: - raise HTTPException(status_code=400, detail="File too large. Max 2MB") - if not payload: - raise HTTPException(status_code=400, detail="Empty file") - - stamp = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d_%H%M%S") - filename = f"svc_{service.id}_{stamp}.{ext}" - target = SERVICE_ICONS_DIR / filename - target.write_bytes(payload) - return f"/static/service-icons/{filename}" - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() - - -def audit(db: Session, action: str, details: str, user_id: Optional[int] = None) -> None: - db.add(AuditLog(user_id=user_id, action=action, details=details)) - db.commit() - - -def hash_password(password: str) -> str: - return pwd_context.hash(password) - - -def verify_password(password: str, password_hash: str) -> bool: - return pwd_context.verify(password, password_hash) - - -def user_is_valid(user: User) -> bool: - return bool(user.active and user.expires_at > now_utc()) - - -def issue_auth_cookie(response: RedirectResponse, user: User) -> None: - token = serializer.dumps({"user_id": user.id}) - response.set_cookie( - key=COOKIE_NAME, - value=token, - httponly=True, - secure=True, - samesite="strict", - max_age=COOKIE_MAX_AGE, - path="/", - ) - - -def issue_csrf_cookie(response: RedirectResponse) -> str: - token = secrets.token_urlsafe(24) - response.set_cookie( - key=CSRF_COOKIE, - value=token, - httponly=False, - secure=True, - samesite="lax", - max_age=COOKIE_MAX_AGE, - path="/", - ) - return token - - -def get_current_user(request: Request, db: Session = Depends(get_db)) -> Optional[User]: - raw = request.cookies.get(COOKIE_NAME) - if not raw: - return None - try: - payload = serializer.loads(raw, max_age=COOKIE_MAX_AGE) - except BadSignature: - return None - user = db.get(User, int(payload["user_id"])) - if not user or not user_is_valid(user): - 
return None - return user - - -def require_user(user: Optional[User] = Depends(get_current_user)) -> User: - if not user: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") - return user - - -def require_admin(user: User = Depends(require_user)) -> User: - if not user.is_admin: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin only") - return user - - -def validate_csrf(request: Request) -> None: - cookie = request.cookies.get(CSRF_COOKIE) - form_val = request.headers.get("X-CSRF-Token") - if request.headers.get("content-type", "").startswith("application/x-www-form-urlencoded"): - return - if not cookie or not form_val or cookie != form_val: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="CSRF failed") - - -def has_access(db: Session, user_id: int, service_id: int) -> bool: - q = select(UserServiceAccess).where( - UserServiceAccess.user_id == user_id, - UserServiceAccess.service_id == service_id, - ) - return db.scalar(q) is not None - - -def docker_client(): - return docker.from_env() - - -def session_router_name(session_id: str) -> str: - return f"sess-{session_id.replace('-', '')[:16]}" - - -def _is_pool_name_conflict(exc: Exception) -> bool: - msg = str(exc).lower() - return ("already in use" in msg) or ("marked for removal" in msg) - - -def _remove_container_by_name(d, name: str) -> None: - try: - old = d.containers.get(name) - old.remove(force=True) - except docker.errors.NotFound: - return - except Exception: - logger.exception("pool_container_remove_failed name=%s", name) - - -def ensure_universal_pool() -> None: - if UNIVERSAL_POOL_SIZE <= 0: - return - d = docker_client() - image = "portal-universal-runtime:latest" - - for i in range(UNIVERSAL_POOL_SIZE, 100): - name = universal_container_name(i) - try: - c = d.containers.get(name) - c.stop(timeout=5) - except docker.errors.NotFound: - break - except Exception: - logger.exception("universal_pool_scale_down_failed slot=%s", 
i) - - for i in range(UNIVERSAL_POOL_SIZE): - name = universal_container_name(i) - path = f"/u/{i}/" - router = f"upool-{i}" - labels = { - "traefik.enable": "true", - "traefik.docker.network": "portal_net", - f"traefik.http.routers.{router}.rule": f"PathPrefix(`{path}`)", - f"traefik.http.routers.{router}.entrypoints": "websecure", - f"traefik.http.routers.{router}.tls": "true", - f"traefik.http.routers.{router}.priority": "9400", - f"traefik.http.routers.{router}.middlewares": f"{router}-strip", - f"traefik.http.middlewares.{router}-strip.stripprefix.prefixes": path[:-1], - f"traefik.http.services.{router}.loadbalancer.server.port": "6080", - "portal.pool": "1", - "portal.pool.slot": str(i), - } - env = { - "IDLE_TIMEOUT": str(SESSION_IDLE_SECONDS), - "ENABLE_HEARTBEAT": "0", - "SESSION_ID": f"universal-{i}", - "X11VNC_FLAGS": X11VNC_FLAGS, - } - try: - c = d.containers.get(name) - if c.status != "running": - c.start() - continue - except docker.errors.NotFound: - pass - except Exception: - logger.exception("universal_pool_check_failed slot=%s", i) - continue - - d.containers.run( - image=image, - name=name, - detach=True, - auto_remove=True, - network="portal_net", - labels=labels, - environment=env, - ) - logger.info("universal_pool_container_started slot=%s", i) - - -def ensure_web_pool(target_size: Optional[int] = None) -> None: - desired = max(0, WEB_POOL_SIZE if target_size is None else target_size) - d = docker_client() - image = "portal-universal-runtime:latest" - - for i in range(desired, 100): - name = web_pool_container_name(i) - try: - c = d.containers.get(name) - c.stop(timeout=5) - except docker.errors.NotFound: - break - except Exception: - logger.exception("web_pool_scale_down_failed slot=%s", i) - - for i in range(desired): - name = web_pool_container_name(i) - path = f"/w/{i}/" - router = f"wpool-{i}" - labels = { - "traefik.enable": "true", - "traefik.docker.network": "portal_net", - f"traefik.http.routers.{router}.rule": 
f"PathPrefix(`{path}`)", - f"traefik.http.routers.{router}.entrypoints": "websecure", - f"traefik.http.routers.{router}.tls": "true", - f"traefik.http.routers.{router}.priority": "9450", - f"traefik.http.routers.{router}.middlewares": f"{router}-strip", - f"traefik.http.middlewares.{router}-strip.stripprefix.prefixes": path[:-1], - f"traefik.http.services.{router}.loadbalancer.server.port": "6080", - "portal.pool": "1", - "portal.pool.kind": "web", - "portal.pool.slot": str(i), - } - env = { - "IDLE_TIMEOUT": str(SESSION_IDLE_SECONDS), - "ENABLE_HEARTBEAT": "0", - "SESSION_ID": f"webpool-{i}", - "X11VNC_FLAGS": X11VNC_FLAGS, - } - should_create = False - try: - c = d.containers.get(name) - if c.status != "running": - try: - c.start() - except docker.errors.APIError as exc: - if _is_pool_name_conflict(exc): - logger.warning("web_pool_recreate_needed slot=%s reason=name-conflict", i) - _remove_container_by_name(d, name) - should_create = True - else: - raise - if not should_create: - continue - except docker.errors.NotFound: - should_create = True - except Exception: - logger.exception("web_pool_check_failed slot=%s", i) - continue - - for attempt in range(3): - try: - d.containers.run( - image=image, - name=name, - detach=True, - auto_remove=True, - network="portal_net", - labels=labels, - environment=env, - ) - logger.info("web_pool_container_started slot=%s", i) - break - except docker.errors.APIError as exc: - if _is_pool_name_conflict(exc) and attempt < 2: - logger.warning("web_pool_run_conflict_retry slot=%s attempt=%s", i, attempt + 1) - _remove_container_by_name(d, name) - time.sleep(0.25) - continue - logger.exception("web_pool_run_failed slot=%s", i) - break - - -def get_universal_pool_status() -> dict: - desired = max(0, UNIVERSAL_POOL_SIZE) - if desired <= 0: - return {"desired": 0, "running": 0, "total": 0, "health": "down", "names": []} - d = docker_client() - names = [universal_container_name(i) for i in range(desired)] - containers = [] - for name in 
names: - try: - containers.append(d.containers.get(name)) - except Exception: - continue - running = sum(1 for c in containers if c.status == "running") - health = "ok" if running >= min(desired, 1) else "down" - return { - "desired": desired, - "running": running, - "total": len(containers), - "names": sorted(c.name for c in containers), - "health": health, - } - - -def get_web_pool_status() -> dict: - desired = max(0, WEB_POOL_SIZE) - if desired <= 0: - return {"desired": 0, "running": 0, "total": 0, "health": "down", "names": []} - d = docker_client() - names = [web_pool_container_name(i) for i in range(desired)] - containers = [] - for name in names: - try: - containers.append(d.containers.get(name)) - except Exception: - continue - running = sum(1 for c in containers if c.status == "running") - health = "ok" if running >= min(desired, 1) else "down" - return { - "desired": desired, - "running": running, - "total": len(containers), - "names": sorted(c.name for c in containers), - "health": health, - } - - -def acquire_universal_slot(db: Session) -> int: - cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS) - q = select(SessionModel).where( - SessionModel.status == SessionStatus.ACTIVE, - SessionModel.container_id.like("POOLIDX:%"), - SessionModel.last_access_at >= cutoff, - ) - active = db.scalars(q).all() - busy = set() - for sess in active: - try: - busy.add(int((sess.container_id or "").split(":", 1)[1])) - except Exception: - continue - for i in range(max(0, UNIVERSAL_POOL_SIZE)): - if i not in busy: - return i - if active: - victim = min(active, key=lambda s: s.last_access_at) - victim.status = SessionStatus.TERMINATED - db.commit() - try: - return int((victim.container_id or "").split(":", 1)[1]) - except Exception: - pass - return 0 - - -def acquire_web_pool_slot(db: Session) -> int: - cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS) - q = select(SessionModel).where( - SessionModel.status == SessionStatus.ACTIVE, - 
SessionModel.container_id.like("WEBPOOLIDX:%"), - SessionModel.last_access_at >= cutoff, - ) - active = db.scalars(q).all() - busy = set() - for sess in active: - try: - busy.add(int((sess.container_id or "").split(":", 1)[1])) - except Exception: - continue - - # Keep headroom: when active sessions are close to hot pool capacity, - # proactively warm up extra slots. - auto_target = max(WEB_POOL_SIZE, len(active) + max(0, WEB_POOL_BUFFER)) - if auto_target > WEB_POOL_SIZE: - ensure_web_pool(auto_target) - - for i in range(max(0, auto_target)): - if i not in busy: - return i - return 0 - - -def sanitize_client_resolution(width: Optional[int], height: Optional[int]) -> tuple[Optional[int], Optional[int]]: - if width is None or height is None: - return None, None - clamped_width = max(WEB_RESOLUTION_MIN_WIDTH, min(int(width), WEB_RESOLUTION_MAX_WIDTH)) - clamped_height = max(WEB_RESOLUTION_MIN_HEIGHT, min(int(height), WEB_RESOLUTION_MAX_HEIGHT)) - return clamped_width, clamped_height - - -def dispatch_universal_target(slot: int, service: Service, width: Optional[int] = None, height: Optional[int] = None) -> None: - name = universal_container_name(slot) - url = "" - payload = {} - if service.type == ServiceType.WEB: - url = f"http://{name}:7000/open" - payload = {"url": normalize_web_target(service.target), "login": service.svc_login or "", "password": service.svc_password or ""} - width, height = sanitize_client_resolution(width, height) - if width and height: - payload["width"] = width - payload["height"] = height - elif service.type == ServiceType.RDP: - cfg = parse_rdp_target(service.target) - url = f"http://{name}:7000/rdp" - payload = { - "host": cfg["host"], - "port": cfg["port"], - "user": cfg["user"], - "password": cfg["password"], - "domain": cfg["domain"], - "security": cfg["security"], - } - else: - raise HTTPException(status_code=400, detail="Universal pool supports WEB/RDP only") - - last_exc = None - for _ in range(max(1, POOL_DISPATCH_RETRIES)): - try: 
- resp = requests.post(url, json=payload, timeout=POOL_DISPATCH_REQUEST_TIMEOUT_SECONDS) - resp.raise_for_status() - return - except Exception as exc: - last_exc = exc - time.sleep(max(0.0, POOL_DISPATCH_SLEEP_SECONDS)) - if last_exc: - raise last_exc - - -def dispatch_web_pool_target(slot: int, service: Service, width: Optional[int] = None, height: Optional[int] = None) -> None: - name = web_pool_container_name(slot) - target_url = normalize_web_target(service.target) - url = f"http://{name}:7000/open" - payload = {"url": target_url, "login": service.svc_login or "", "password": service.svc_password or ""} - width, height = sanitize_client_resolution(width, height) - if width and height: - payload["width"] = width - payload["height"] = height - last_exc = None - for _ in range(max(1, POOL_DISPATCH_RETRIES)): - try: - resp = requests.post(url, json=payload, timeout=POOL_DISPATCH_REQUEST_TIMEOUT_SECONDS) - resp.raise_for_status() - return - except Exception as exc: - last_exc = exc - time.sleep(max(0.0, POOL_DISPATCH_SLEEP_SECONDS)) - if last_exc: - raise last_exc - - -def create_runtime_container(service: Service, session_id: str): - d = docker_client() - router = session_router_name(session_id) - path = f"/s/{session_id}/" - labels = { - "traefik.enable": "true", - "traefik.docker.network": "portal_net", - f"traefik.http.routers.{router}.rule": f"PathPrefix(`{path}`)", - f"traefik.http.routers.{router}.entrypoints": "websecure", - f"traefik.http.routers.{router}.tls": "true", - f"traefik.http.routers.{router}.priority": "10000", - f"traefik.http.routers.{router}.middlewares": f"{router}-strip", - f"traefik.http.middlewares.{router}-strip.stripprefix.prefixes": path[:-1], - f"traefik.http.services.{router}.loadbalancer.server.port": "6080", - } - - env = { - "SESSION_ID": session_id, - "IDLE_TIMEOUT": str(SESSION_IDLE_SECONDS), - "ENABLE_HEARTBEAT": "1", - "TOUCH_PATH": f"/api/sessions/{session_id}/touch", - "X11VNC_FLAGS": X11VNC_FLAGS, - } - image = 
"portal-kiosk:latest" - - if service.type == ServiceType.WEB: - env["TARGET_URL"] = service.target - env["HOME_URL"] = f"https://{PUBLIC_HOST}/" - elif service.type == ServiceType.RDP: - image = "portal-rdp-proxy:latest" - cfg = parse_rdp_target(service.target) - env["RDP_HOST"] = cfg["host"] - env["RDP_PORT"] = cfg["port"] - if cfg["user"]: - env["RDP_USER"] = cfg["user"] - if cfg["password"]: - env["RDP_PASSWORD"] = cfg["password"] - if cfg["domain"]: - env["RDP_DOMAIN"] = cfg["domain"] - if cfg["security"]: - env["RDP_SECURITY"] = cfg["security"] - else: - raise HTTPException(status_code=400, detail="Unsupported service type") - - container = d.containers.run( - image=image, - name=f"portal-sess-{session_id[:8]}", - detach=True, - auto_remove=True, - network="portal_net", - labels=labels, - environment=env, - ) - logger.info("session_container_started session_id=%s container_id=%s service_type=%s", session_id, container.id, service.type.value) - return container.id - - -def ensure_warm_pool(service: Service, pool_size: Optional[int] = None) -> None: - if service_uses_universal_pool(service): - return - if pool_size is None: - pool_size = desired_pool_size(service) - if pool_size <= 0: - # Stop stale warm containers for this service when pool is disabled. 
- prefix = f"portal-warm-{service.slug}-" - try: - d = docker_client() - for c in d.containers.list(all=True, filters={"name": prefix}): - if c.name.startswith(prefix): - c.stop(timeout=5) - except Exception: - logger.exception("warm_pool_disable_failed service=%s", service.slug) - return - d = docker_client() - router = f"warm-{service.slug}" - svc_name = f"warmsvc-{service.slug}" - path = f"/svc/{service.slug}/" - image = "portal-kiosk:latest" - base_env = { - "IDLE_TIMEOUT": str(SESSION_IDLE_SECONDS), - "ENABLE_HEARTBEAT": "0", - "TOUCH_PATH": "", - "X11VNC_FLAGS": X11VNC_FLAGS, - } - if service.type == ServiceType.WEB: - base_env["UNIVERSAL_WEB"] = "1" - base_env["START_URL"] = normalize_web_target(service.target) - base_env["HOME_URL"] = f"https://{PUBLIC_HOST}/" - elif service.type == ServiceType.RDP: - image = "portal-rdp-proxy:latest" - cfg = parse_rdp_target(service.target) - base_env["RDP_HOST"] = cfg["host"] - base_env["RDP_PORT"] = cfg["port"] - if cfg["user"]: - base_env["RDP_USER"] = cfg["user"] - if cfg["password"]: - base_env["RDP_PASSWORD"] = cfg["password"] - if cfg["domain"]: - base_env["RDP_DOMAIN"] = cfg["domain"] - if cfg["security"]: - base_env["RDP_SECURITY"] = cfg["security"] - else: - raise HTTPException(status_code=400, detail="Unsupported service type") - - labels = { - "traefik.enable": "true", - "traefik.docker.network": "portal_net", - f"traefik.http.routers.{router}.rule": f"PathPrefix(`{path}`)", - f"traefik.http.routers.{router}.entrypoints": "websecure", - f"traefik.http.routers.{router}.tls": "true", - f"traefik.http.routers.{router}.priority": "9500", - f"traefik.http.routers.{router}.middlewares": f"{router}-strip", - f"traefik.http.middlewares.{router}-strip.stripprefix.prefixes": path[:-1], - f"traefik.http.services.{svc_name}.loadbalancer.server.port": "6080", - f"traefik.http.routers.{router}.service": svc_name, - "portal.warm": "1", - "portal.service.slug": service.slug, - "portal.service.type": service.type.value, - } - - 
# Ensure desired cardinality. - for i in range(pool_size, 50): - name = f"portal-warm-{service.slug}-{i}" - try: - c = d.containers.get(name) - c.stop(timeout=5) - except docker.errors.NotFound: - break - except Exception: - logger.exception("warm_pool_scale_down_failed service=%s idx=%s", service.slug, i) - - for i in range(pool_size): - name = f"portal-warm-{service.slug}-{i}" - try: - c = d.containers.get(name) - if c.status != "running": - c.start() - continue - except docker.errors.NotFound: - pass - except Exception: - logger.exception("warm_pool_check_failed service=%s idx=%s", service.slug, i) - continue - - env = dict(base_env) - env["SESSION_ID"] = f"warm-{service.slug}-{i}" - d.containers.run( - image=image, - name=name, - detach=True, - auto_remove=True, - network="portal_net", - labels=labels, - environment=env, - ) - logger.info("warm_pool_container_started service=%s idx=%s", service.slug, i) - - -def wait_for_session_route(session_id: str, timeout_seconds: int = 6) -> bool: - target = f"{TRAEFIK_INTERNAL_URL}/s/{session_id}/" - deadline = time.time() + timeout_seconds - while time.time() < deadline: - try: - resp = requests.get( - target, - headers={"Host": PUBLIC_HOST}, - allow_redirects=False, - timeout=1.5, - ) - if resp.status_code != 404: - return True - except Exception: - pass - time.sleep(0.3) - return False - - -def route_ready(path: str) -> bool: - bases = [TRAEFIK_INTERNAL_URL] - if TRAEFIK_INTERNAL_URL.startswith("http://"): - bases.append("https://" + TRAEFIK_INTERNAL_URL[len("http://"):]) - for base in bases: - try: - verify = not base.startswith("https://") - resp = requests.get( - f"{base}{path}", - headers={"Host": PUBLIC_HOST}, - allow_redirects=False, - timeout=1.5, - verify=verify, - ) - if resp.status_code != 404: - return True - except Exception: - continue - return False - - -def container_running(container_id: Optional[str]) -> bool: - if not container_id: - return False - if ( - container_id.startswith("POOL:") - or 
container_id.startswith("POOLIDX:") - or container_id.startswith("WEBPOOLIDX:") - ): - return True - if container_id.startswith("RDPSLOT:"): - try: - slot_id = int(container_id.split(":", 1)[1]) - db = SessionLocal() - try: - slot = db.get(RdpSlot, slot_id) - if not slot or not slot.container_name: - return False - c = docker_client().containers.get(slot.container_name) - return c.status == "running" - finally: - db.close() - except Exception: - return False - try: - c = docker_client().containers.get(container_id) - return c.status == "running" - except Exception: - return False - - -def _rdp_slot_container_name(service_slug: str, slot_id: int) -> str: - return f"portal-rdpslot-{service_slug}-{slot_id}" - - -def start_rdp_slot_container(slot: RdpSlot, service: Service) -> str: - d = docker_client() - name = _rdp_slot_container_name(service.slug, slot.id) - slot_id = slot.id - path = f"/rdp/{slot_id}/" - router = f"rdpslot-{slot_id}" - labels = { - "traefik.enable": "true", - "traefik.docker.network": "portal_net", - f"traefik.http.routers.{router}.rule": f"PathPrefix(`{path}`)", - f"traefik.http.routers.{router}.entrypoints": "websecure", - f"traefik.http.routers.{router}.tls": "true", - f"traefik.http.routers.{router}.priority": "10000", - f"traefik.http.routers.{router}.middlewares": f"{router}-strip", - f"traefik.http.middlewares.{router}-strip.stripprefix.prefixes": path[:-1], - f"traefik.http.services.{router}.loadbalancer.server.port": "6080", - "portal.rdpslot": "1", - "portal.rdpslot.id": str(slot_id), - "portal.service.slug": service.slug, - } - cfg = parse_rdp_target(service.target) - env = { - "SESSION_ID": f"rdpslot-{slot_id}", - "IDLE_TIMEOUT": "86400", - "ENABLE_HEARTBEAT": "0", - "RDP_HOST": cfg["host"], - "RDP_PORT": cfg["port"], - "X11VNC_FLAGS": X11VNC_FLAGS, - } - if slot.rdp_username: - env["RDP_USER"] = slot.rdp_username - if slot.rdp_password: - env["RDP_PASSWORD"] = slot.rdp_password - if cfg.get("domain"): - env["RDP_DOMAIN"] = 
cfg["domain"] - if cfg.get("security"): - env["RDP_SECURITY"] = cfg["security"] - - try: - existing = d.containers.get(name) - existing.stop(timeout=5) - existing.remove(force=True) - except docker.errors.NotFound: - pass - except Exception: - logger.exception("rdp_slot_container_cleanup_failed slot_id=%s", slot_id) - - container = d.containers.run( - "portal-rdp-proxy:latest", - name=name, - detach=True, - restart_policy={"Name": "unless-stopped"}, - network="portal_net", - labels=labels, - environment=env, - ) - logger.info("rdp_slot_container_started slot_id=%s name=%s", slot_id, name) - return container.name - - -def stop_rdp_slot_container(container_name: str) -> None: - if not container_name: - return - try: - d = docker_client() - c = d.containers.get(container_name) - c.stop(timeout=5) - c.remove(force=True) - logger.info("rdp_slot_container_stopped name=%s", container_name) - except docker.errors.NotFound: - pass - except Exception: - logger.exception("rdp_slot_container_stop_failed container=%s", container_name) - - -def _restart_rdp_slot_bg(slot_id: int) -> None: - db = SessionLocal() - try: - slot = db.get(RdpSlot, slot_id) - if not slot or not slot.container_name: - return - service = db.get(Service, slot.service_id) - if not service: - return - try: - d = docker_client() - c = d.containers.get(slot.container_name) - c.restart(timeout=10) - logger.info("rdp_slot_container_restarted slot_id=%s", slot_id) - except docker.errors.NotFound: - start_rdp_slot_container(slot, service) - except Exception: - logger.exception("rdp_slot_container_restart_failed slot_id=%s", slot_id) - finally: - db.close() - - -def stop_runtime_container(container_id: Optional[str]) -> None: - if not container_id: - return - try: - d = docker_client() - c = d.containers.get(container_id) - c.stop(timeout=5) - except Exception: - logger.exception("session_container_stop_failed container_id=%s", container_id) - - -def terminate_session_record( - db: Session, - sess: SessionModel, - 
new_status: SessionStatus = SessionStatus.TERMINATED, - *, - stop_container: bool = True, -) -> None: - if not sess or sess.status != SessionStatus.ACTIVE: - return - old_status = sess.status - cid = sess.container_id or "" - if stop_container and cid and not cid.startswith(("POOL:", "POOLIDX:", "WEBPOOLIDX:", "RDPSLOT:")): - stop_runtime_container(cid) - if cid.startswith("RDPSLOT:"): - try: - slot_id = int(cid.split(":", 1)[1]) - threading.Thread(target=_restart_rdp_slot_bg, args=(slot_id,), daemon=True).start() - except Exception: - logger.exception("rdp_slot_restart_schedule_failed cid=%s", cid) - sess.status = new_status - sess.last_access_at = now_utc() - log_event( - "session_closed", - level=logging.INFO, - session_id=sess.id, - user_id=sess.user_id, - service_id=sess.service_id, - container_id=cid, - old_status=old_status.value if isinstance(old_status, SessionStatus) else str(old_status), - new_status=new_status.value, - reason=session_closed_reason(sess, db), - stop_container=stop_container, - ) - - -def ensure_schema_compatibility() -> None: - # PostgreSQL requires enum value addition to be committed before usage in constraints. 
- with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: - conn.execute( - text( - """ - DO $$ - BEGIN - BEGIN - ALTER TYPE servicetype ADD VALUE IF NOT EXISTS 'RDP'; - EXCEPTION WHEN undefined_object THEN - NULL; - END; - END $$; - """ - ) - ) - conn.execute( - text( - """ - DO $$ - BEGIN - BEGIN - ALTER TYPE sessionstatus ADD VALUE IF NOT EXISTS 'ROTATED'; - EXCEPTION WHEN undefined_object THEN - NULL; - END; - END $$; - """ - ) - ) - - with engine.begin() as conn: - conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS warm_pool_size INTEGER NOT NULL DEFAULT 0")) - conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS comment TEXT NOT NULL DEFAULT ''")) - conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS svc_login VARCHAR(256) NOT NULL DEFAULT ''")) - conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS svc_password VARCHAR(256) NOT NULL DEFAULT ''")) - conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS svc_cred_hint TEXT NOT NULL DEFAULT ''")) - conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS icon_path TEXT NOT NULL DEFAULT ''")) - conn.execute( - text( - """ - CREATE TABLE IF NOT EXISTS categories ( - id SERIAL PRIMARY KEY, - name VARCHAR(128) NOT NULL UNIQUE, - slug VARCHAR(64) NOT NULL UNIQUE, - created_at TIMESTAMPTZ NOT NULL DEFAULT now() - ) - """ - ) - ) - conn.execute( - text( - """ - CREATE TABLE IF NOT EXISTS service_categories ( - id SERIAL PRIMARY KEY, - service_id INT NOT NULL REFERENCES services(id) ON DELETE CASCADE, - category_id INT NOT NULL REFERENCES categories(id) ON DELETE CASCADE, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - UNIQUE (service_id, category_id) - ) - """ - ) - ) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_service_categories_service_id ON service_categories(service_id)")) - conn.execute(text("CREATE INDEX IF NOT EXISTS idx_service_categories_category_id ON service_categories(category_id)")) - # Handle installs 
where service type is VARCHAR + CHECK. - conn.execute( - text( - """ - DO $$ - DECLARE c record; - BEGIN - FOR c IN - SELECT conname - FROM pg_constraint - WHERE conrelid = 'services'::regclass - AND contype = 'c' - AND pg_get_constraintdef(oid) ILIKE '%type%' - LOOP - EXECUTE format('ALTER TABLE services DROP CONSTRAINT %I', c.conname); - END LOOP; - ALTER TABLE services - ADD CONSTRAINT services_type_check - CHECK (type IN ('WEB','VNC','RDP')); - EXCEPTION WHEN duplicate_object THEN - NULL; - END $$; - """ - ) - ) - - -def desired_pool_size(service: Service) -> int: - if not service.active: - return 0 - if service.type == ServiceType.RDP and not service_uses_universal_pool(service): - # RDP runs on-demand per user session; no prewarmed pool. - return 0 - if service_uses_universal_pool(service): - return UNIVERSAL_POOL_SIZE - return service.warm_pool_size if service.warm_pool_size and service.warm_pool_size > 0 else PREWARM_POOL_SIZE - - -def get_warm_containers_for_service(service: Service) -> list: - prefix = f"portal-warm-{service.slug}-" - try: - d = docker_client() - containers = [] - for c in d.containers.list(all=True, filters={"name": prefix}): - if c.name.startswith(prefix): - containers.append(c) - return containers - except Exception: - logger.exception("pool_status_failed service=%s", service.slug) - return [] - - -def get_pool_status_for_service(service: Service) -> dict: - if service.type == ServiceType.WEB: - return get_web_pool_status() - if service.type == ServiceType.RDP and not service_uses_universal_pool(service): - return {"desired": 0, "running": 0, "total": 0, "names": [], "health": "n/a"} - if service_uses_universal_pool(service): - return get_universal_pool_status() - desired = desired_pool_size(service) - containers = get_warm_containers_for_service(service) - running = sum(1 for c in containers if c.status == "running") - states = [(c.attrs.get("State") or {}).get("Status", c.status) for c in containers] - has_bad = any(s in {"exited", 
"dead"} for s in states) - total = len(containers) - if running == 0: - health = "down" - elif running >= min(desired, 1) and not has_bad: - health = "ok" - else: - health = "degraded" - return { - "desired": desired, - "running": running, - "total": total, - "names": sorted(c.name for c in containers), - "health": health, - } - - -def get_pool_detailed_status(service: Service) -> dict: - if service.type == ServiceType.WEB: - d = docker_client() - pool = get_web_pool_status() - details = [] - for i in range(max(0, pool["desired"])): - name = web_pool_container_name(i) - try: - c = d.containers.get(name) - except Exception: - continue - attrs = c.attrs or {} - state = (attrs.get("State") or {}).get("Status", c.status) - details.append( - { - "name": c.name, - "status": c.status, - "state": state, - "created": attrs.get("Created", ""), - "image": c.image.tags[0] if c.image.tags else "", - "labels_ok": True, - } - ) - return { - "service_id": service.id, - "slug": service.slug, - "type": service.type.value, - "desired": pool["desired"], - "running": pool["running"], - "total": pool["total"], - "health": pool["health"], - "containers": details, - "updated_at": now_utc().isoformat(), - } - if service_uses_universal_pool(service): - d = docker_client() - pool = get_universal_pool_status() - details = [] - for i in range(max(0, UNIVERSAL_POOL_SIZE)): - name = universal_container_name(i) - try: - c = d.containers.get(name) - except Exception: - continue - attrs = c.attrs or {} - state = (attrs.get("State") or {}).get("Status", c.status) - details.append( - { - "name": c.name, - "status": c.status, - "state": state, - "created": attrs.get("Created", ""), - "image": c.image.tags[0] if c.image.tags else "", - "labels_ok": True, - } - ) - return { - "service_id": service.id, - "slug": service.slug, - "type": service.type.value, - "desired": pool["desired"], - "running": pool["running"], - "total": pool["total"], - "health": pool["health"], - "containers": details, - 
"updated_at": now_utc().isoformat(), - } - containers = get_warm_containers_for_service(service) - pool = get_pool_status_for_service(service) - details = [] - for c in sorted(containers, key=lambda x: x.name): - attrs = c.attrs or {} - state = (attrs.get("State") or {}).get("Status", c.status) - created = attrs.get("Created", "") - labels = attrs.get("Config", {}).get("Labels", {}) or {} - labels_ok = ( - labels.get("portal.warm") == "1" - and labels.get("portal.service.slug") == service.slug - and labels.get("portal.service.type") == service.type.value - ) - details.append( - { - "name": c.name, - "status": c.status, - "state": state, - "created": created, - "image": c.image.tags[0] if c.image.tags else "", - "labels_ok": labels_ok, - } - ) - return { - "service_id": service.id, - "slug": service.slug, - "type": service.type.value, - "desired": pool["desired"], - "running": pool["running"], - "total": pool["total"], - "health": pool["health"], - "containers": details, - "updated_at": now_utc().isoformat(), - } - - -def get_active_sessions_count(db: Session, service_id: int) -> int: - cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS) - q = select(SessionModel).where( - SessionModel.service_id == service_id, - SessionModel.status == SessionStatus.ACTIVE, - SessionModel.last_access_at >= cutoff, - ) - sessions = db.scalars(q).all() - # Avoid inflated stats when pooled slot sessions were duplicated by race: - # for pooled sessions, occupancy is unique container_id. 
- pooled = [s for s in sessions if (s.container_id or "").startswith(("WEBPOOLIDX:", "POOLIDX:", "POOL:"))] - direct = [s for s in sessions if s not in pooled] - unique_pooled = len({s.container_id for s in pooled if s.container_id}) - return unique_pooled + len(direct) - - -def find_active_session_for_service(db: Session, service_id: int) -> Optional[SessionModel]: - cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS) - q = ( - select(SessionModel) - .where( - SessionModel.service_id == service_id, - SessionModel.status == SessionStatus.ACTIVE, - SessionModel.last_access_at >= cutoff, - ) - .order_by(SessionModel.created_at.desc()) - ) - return db.scalars(q).first() - - -def find_active_session_for_user_service(db: Session, user_id: int, service_id: int) -> Optional[SessionModel]: - cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS) - q = ( - select(SessionModel) - .where( - SessionModel.user_id == user_id, - SessionModel.service_id == service_id, - SessionModel.status == SessionStatus.ACTIVE, - SessionModel.last_access_at >= cutoff, - ) - .order_by(SessionModel.created_at.desc()) - ) - return db.scalars(q).first() - - -class LockTimeoutError(Exception): - pass - - -def allocator_lock(db: Session, lock_id: int, timeout_seconds: Optional[float] = None, poll_seconds: float = 0.05): - class _LockCtx: - def __enter__(self_nonlocal): - self_nonlocal._acquired = False - if timeout_seconds is None: - db.execute(text("SELECT pg_advisory_xact_lock(:lid)"), {"lid": lock_id}) - self_nonlocal._acquired = True - return self_nonlocal - - deadline = time.monotonic() + max(0.0, timeout_seconds) - while time.monotonic() <= deadline: - got = db.execute(text("SELECT pg_try_advisory_xact_lock(:lid)"), {"lid": lock_id}).scalar() - if got: - self_nonlocal._acquired = True - return self_nonlocal - time.sleep(max(0.01, poll_seconds)) - raise LockTimeoutError(f"advisory lock timeout lock_id={lock_id} timeout={timeout_seconds}") - - return self_nonlocal - - def 
__exit__(self_nonlocal, exc_type, exc, tb): - return False - - return _LockCtx() - - -def terminate_active_slot_sessions(db: Session, container_id: str) -> None: - if not container_id: - return - db.execute( - text( - """ - UPDATE sessions - SET status = 'TERMINATED' - WHERE container_id = :cid - AND status = 'ACTIVE' - """ - ), - {"cid": container_id}, - ) - - -def session_redirect_url(sess: SessionModel) -> str: - cid = sess.container_id or "" - if cid.startswith("POOL:") or cid.startswith("POOLIDX:") or cid.startswith("WEBPOOLIDX:") or cid.startswith("RDPSLOT:"): - return f"/s/{sess.id}/view" - return f"/s/{sess.id}/" - - -def open_warm_web_url(service: Service, target_url: str) -> None: - if service_uses_universal_pool(service): - return - if service.type != ServiceType.WEB: - return - target_url = normalize_web_target(target_url) - try: - d = docker_client() - containers = d.containers.list( - filters={ - "label": [ - "portal.warm=1", - f"portal.service.slug={service.slug}", - "portal.service.type=WEB", - ] - } - ) - for c in containers: - try: - resp = requests.post( - f"http://{c.name}:7000/open", - json={"url": target_url}, - timeout=2, - ) - resp.raise_for_status() - logger.info("warm_web_open_ok service=%s container=%s url=%s", service.slug, c.name, target_url) - except Exception: - logger.exception("warm_web_open_failed service=%s container=%s", service.slug, c.name) - except Exception: - logger.exception("warm_web_open_dispatch_failed service=%s", service.slug) - - -def cleanup_loop(): - while True: - time.sleep(60) - db = SessionLocal() - try: - ensure_universal_pool() - ensure_web_pool() - for svc in db.scalars( - select(Service).where( - Service.active == True, - Service.type.in_([ServiceType.WEB, ServiceType.RDP]), - ) - ).all(): - if svc.type == ServiceType.WEB and WEB_POOL_SIZE <= 0: - ensure_warm_pool(svc) - cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS) - q = select(SessionModel).where( - SessionModel.status == 
SessionStatus.ACTIVE, - SessionModel.last_access_at < cutoff, - ) - stale = db.scalars(q).all() - rdp_slots_to_restart: list[int] = [] - for sess in stale: - cid = sess.container_id or "" - if cid.startswith("RDPSLOT:"): - try: - rdp_slots_to_restart.append(int(cid.split(":", 1)[1])) - except Exception: - pass - elif cid and not ( - cid.startswith("POOL:") - or cid.startswith("POOLIDX:") - or cid.startswith("WEBPOOLIDX:") - ): - stop_runtime_container(cid) - sess.status = SessionStatus.EXPIRED - if stale: - db.commit() - for slot_id in rdp_slots_to_restart: - threading.Thread(target=_restart_rdp_slot_bg, args=(slot_id,), daemon=True).start() - except Exception: - db.rollback() - logger.exception("cleanup_loop_failed") - finally: - db.close() - - -def bootstrap_admin(): - admin_user = os.getenv("ADMIN_USERNAME", "admin") - admin_password = os.getenv("ADMIN_PASSWORD", "change_me") - ttl_days = int(os.getenv("ADMIN_TTL_DAYS", "3650")) - - db = SessionLocal() - try: - existing = db.scalar(select(User).where(User.username == admin_user)) - if not existing: - db.add( - User( - username=admin_user, - password_hash=hash_password(admin_password), - active=True, - is_admin=True, - expires_at=now_utc() + dt.timedelta(days=ttl_days), - ) - ) - db.commit() - finally: - db.close() - - -def try_acquire_maintenance_leader() -> bool: - global maintenance_lock_file - if maintenance_lock_file is not None: - return True - lock_file = open("/tmp/portal-maintenance.lock", "w") - try: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) - except BlockingIOError: - lock_file.close() - return False - maintenance_lock_file = lock_file - return True - - - -def run_maintenance_service() -> None: - logger.info("maintenance_service_bootstrap_started") - with open("/tmp/portal-schema.lock", "w") as lock_file: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - Base.metadata.create_all(bind=engine) - ensure_schema_compatibility() - fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) - - 
ensure_icons_dir() - bootstrap_admin() - - maintenance_lock = open("/tmp/portal-maintenance.lock", "w") - fcntl.flock(maintenance_lock.fileno(), fcntl.LOCK_EX) - logger.info("maintenance_service_leader_acquired") - - db = SessionLocal() - try: - ensure_universal_pool() - ensure_web_pool() - for svc in db.scalars( - select(Service).where( - Service.active == True, - Service.type.in_([ServiceType.WEB, ServiceType.RDP]), - ) - ).all(): - if svc.type == ServiceType.WEB and WEB_POOL_SIZE <= 0: - ensure_warm_pool(svc) - finally: - db.close() - - logger.info("maintenance_service_loop_started") - cleanup_loop() @app.on_event("startup") def startup_event(): - # Multiple uvicorn workers run startup in parallel. Serialize schema bootstrap - # to avoid DDL races on first run and during schema extension. - with open("/tmp/portal-schema.lock", "w") as lock_file: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) - Base.metadata.create_all(bind=engine) - ensure_schema_compatibility() - fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) - ensure_icons_dir() - bootstrap_admin() - if not try_acquire_maintenance_leader(): - logger.info("maintenance_leader_skipped") - return - - if ENABLE_STARTUP_MAINTENANCE: - db = SessionLocal() - try: - ensure_universal_pool() - ensure_web_pool() - for svc in db.scalars( - select(Service).where( - Service.active == True, - Service.type.in_([ServiceType.WEB, ServiceType.RDP]), - ) - ).all(): - if svc.type == ServiceType.WEB and WEB_POOL_SIZE <= 0: - ensure_warm_pool(svc) - elif svc.type == ServiceType.RDP: - slots = db.scalars(select(RdpSlot).where(RdpSlot.service_id == svc.id)).all() - for slot in slots: - try: - cname = _rdp_slot_container_name(svc.slug, slot.id) - try: - c = docker_client().containers.get(cname) - if c.status != "running": - c.start() - except docker.errors.NotFound: - start_rdp_slot_container(slot, svc) - slot.container_name = cname - except Exception: - logger.exception("startup_rdp_slot_start_failed slot_id=%s", slot.id) - if slots: 
- db.commit() - finally: - db.close() - - thread = threading.Thread(target=cleanup_loop, daemon=True) - thread.start() - logger.info("maintenance_leader_started") + on_startup() @app.get("/", response_class=HTMLResponse) diff --git a/app/maintenance.py b/app/maintenance.py new file mode 100644 index 0000000..7caee2f --- /dev/null +++ b/app/maintenance.py @@ -0,0 +1,196 @@ +import datetime as dt +import fcntl +import logging +import os +import threading +import time + +import docker +from sqlalchemy import select + +from config import ENABLE_STARTUP_MAINTENANCE, SESSION_IDLE_SECONDS, WEB_POOL_SIZE +from database import Base, SessionLocal, engine +from models import RdpSlot, Service, ServiceType, SessionModel, SessionStatus, User +from utils import ensure_icons_dir, now_utc +from auth import hash_password +from runtime import ( + _rdp_slot_container_name, + _restart_rdp_slot_bg, + docker_client, + ensure_schema_compatibility, + ensure_universal_pool, + ensure_warm_pool, + ensure_web_pool, + start_rdp_slot_container, + stop_runtime_container, +) + +logger = logging.getLogger("portal") +maintenance_lock_file = None + + +def cleanup_loop(): + while True: + time.sleep(60) + db = SessionLocal() + try: + ensure_universal_pool() + ensure_web_pool() + for svc in db.scalars( + select(Service).where( + Service.active == True, + Service.type.in_([ServiceType.WEB, ServiceType.RDP]), + ) + ).all(): + if svc.type == ServiceType.WEB and WEB_POOL_SIZE <= 0: + ensure_warm_pool(svc) + cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS) + q = select(SessionModel).where( + SessionModel.status == SessionStatus.ACTIVE, + SessionModel.last_access_at < cutoff, + ) + stale = db.scalars(q).all() + rdp_slots_to_restart: list[int] = [] + for sess in stale: + cid = sess.container_id or "" + if cid.startswith("RDPSLOT:"): + try: + rdp_slots_to_restart.append(int(cid.split(":", 1)[1])) + except Exception: + pass + elif cid and not ( + cid.startswith("POOL:") + or 
cid.startswith("POOLIDX:") + or cid.startswith("WEBPOOLIDX:") + ): + stop_runtime_container(cid) + sess.status = SessionStatus.EXPIRED + if stale: + db.commit() + for slot_id in rdp_slots_to_restart: + threading.Thread(target=_restart_rdp_slot_bg, args=(slot_id,), daemon=True).start() + except Exception: + db.rollback() + logger.exception("cleanup_loop_failed") + finally: + db.close() + + +def bootstrap_admin(): + admin_user = os.getenv("ADMIN_USERNAME", "admin") + admin_password = os.getenv("ADMIN_PASSWORD", "change_me") + ttl_days = int(os.getenv("ADMIN_TTL_DAYS", "3650")) + + db = SessionLocal() + try: + existing = db.scalar(select(User).where(User.username == admin_user)) + if not existing: + db.add( + User( + username=admin_user, + password_hash=hash_password(admin_password), + active=True, + is_admin=True, + expires_at=now_utc() + dt.timedelta(days=ttl_days), + ) + ) + db.commit() + finally: + db.close() + + +def try_acquire_maintenance_leader() -> bool: + global maintenance_lock_file + if maintenance_lock_file is not None: + return True + lock_file = open("/tmp/portal-maintenance.lock", "w") + try: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + except BlockingIOError: + lock_file.close() + return False + maintenance_lock_file = lock_file + return True + + +def run_maintenance_service() -> None: + logger.info("maintenance_service_bootstrap_started") + with open("/tmp/portal-schema.lock", "w") as lock_file: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + Base.metadata.create_all(bind=engine) + ensure_schema_compatibility() + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + + ensure_icons_dir() + bootstrap_admin() + + maintenance_lock = open("/tmp/portal-maintenance.lock", "w") + fcntl.flock(maintenance_lock.fileno(), fcntl.LOCK_EX) + logger.info("maintenance_service_leader_acquired") + + db = SessionLocal() + try: + ensure_universal_pool() + ensure_web_pool() + for svc in db.scalars( + select(Service).where( + Service.active == True, + 
Service.type.in_([ServiceType.WEB, ServiceType.RDP]), + ) + ).all(): + if svc.type == ServiceType.WEB and WEB_POOL_SIZE <= 0: + ensure_warm_pool(svc) + finally: + db.close() + + logger.info("maintenance_service_loop_started") + cleanup_loop() + + +def on_startup() -> None: + with open("/tmp/portal-schema.lock", "w") as lock_file: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX) + Base.metadata.create_all(bind=engine) + ensure_schema_compatibility() + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + ensure_icons_dir() + bootstrap_admin() + if not try_acquire_maintenance_leader(): + logger.info("maintenance_leader_skipped") + return + + if ENABLE_STARTUP_MAINTENANCE: + db = SessionLocal() + try: + ensure_universal_pool() + ensure_web_pool() + for svc in db.scalars( + select(Service).where( + Service.active == True, + Service.type.in_([ServiceType.WEB, ServiceType.RDP]), + ) + ).all(): + if svc.type == ServiceType.WEB and WEB_POOL_SIZE <= 0: + ensure_warm_pool(svc) + elif svc.type == ServiceType.RDP: + slots = db.scalars(select(RdpSlot).where(RdpSlot.service_id == svc.id)).all() + for slot in slots: + try: + cname = _rdp_slot_container_name(svc.slug, slot.id) + try: + c = docker_client().containers.get(cname) + if c.status != "running": + c.start() + except docker.errors.NotFound: + start_rdp_slot_container(slot, svc) + slot.container_name = cname + except Exception: + logger.exception("startup_rdp_slot_start_failed slot_id=%s", slot.id) + if slots: + db.commit() + finally: + db.close() + + thread = threading.Thread(target=cleanup_loop, daemon=True) + thread.start() + logger.info("maintenance_leader_started") diff --git a/app/models.py b/app/models.py new file mode 100644 index 0000000..95f729a --- /dev/null +++ b/app/models.py @@ -0,0 +1,115 @@ +import datetime as dt +import enum +from typing import Optional + +from sqlalchemy import ( + Boolean, DateTime, Enum, ForeignKey, Integer, String, Text, UniqueConstraint, +) +from sqlalchemy.orm import Mapped, mapped_column 
@enum.unique
class ServiceType(str, enum.Enum):
    """Kind of backend a portal service exposes.

    Members are also ``str`` instances, so they compare equal to their
    plain-string value.
    """

    WEB = "WEB"
    VNC = "VNC"
    RDP = "RDP"


@enum.unique
class SessionStatus(str, enum.Enum):
    """Lifecycle state of a session row; new rows default to ``ACTIVE``."""

    ACTIVE = "ACTIVE"
    EXPIRED = "EXPIRED"
    TERMINATED = "TERMINATED"
    ROTATED = "ROTATED"
class RdpSlot(Base):
    """One RDP account of a service, served by its own proxy container.

    Sessions point at a slot through a ``container_id`` of the form
    ``RDPSLOT:<id>``; the Docker container backing the slot is recorded in
    ``container_name``.
    """

    __tablename__ = "rdp_slots"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Slots are deleted together with their service (ON DELETE CASCADE).
    service_id: Mapped[int] = mapped_column(
        ForeignKey("services.id", ondelete="CASCADE"),
        index=True,
    )
    rdp_username: Mapped[str] = mapped_column(String(128))
    # NOTE(review): the value is handed to the proxy container as RDP_PASSWORD,
    # so it is kept in recoverable form - confirm this is acceptable.
    rdp_password: Mapped[str] = mapped_column(String(256), default="")
    # NULL until a container has been started for the slot.
    container_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    created_at: Mapped[dt.datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: dt.datetime.now(tz=dt.timezone.utc),
    )
class AuditLog(Base):
    """Append-only record of an action, optionally attributed to a user."""

    __tablename__ = "audit_logs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Plain integer without a foreign key; NULL is allowed.
    # NOTE(review): presumably so entries survive user deletion - confirm.
    user_id: Mapped[Optional[int]] = mapped_column(
        Integer,
        nullable=True,
        index=True,
    )
    action: Mapped[str] = mapped_column(String(128), index=True)
    details: Mapped[str] = mapped_column(Text)
    created_at: Mapped[dt.datetime] = mapped_column(
        DateTime(timezone=True),
        default=lambda: dt.datetime.now(tz=dt.timezone.utc),
        index=True,
    )
import log_event, normalize_web_target, now_utc, parse_rdp_target, session_closed_reason + +logger = logging.getLogger("portal") + +def docker_client(): + return docker.from_env() + + +def session_router_name(session_id: str) -> str: + return f"sess-{session_id.replace('-', '')[:16]}" + + +def _is_pool_name_conflict(exc: Exception) -> bool: + msg = str(exc).lower() + return ("already in use" in msg) or ("marked for removal" in msg) + + +def _remove_container_by_name(d, name: str) -> None: + try: + old = d.containers.get(name) + old.remove(force=True) + except docker.errors.NotFound: + return + except Exception: + logger.exception("pool_container_remove_failed name=%s", name) + + +def ensure_universal_pool() -> None: + if UNIVERSAL_POOL_SIZE <= 0: + return + d = docker_client() + image = "portal-universal-runtime:latest" + + for i in range(UNIVERSAL_POOL_SIZE, 100): + name = universal_container_name(i) + try: + c = d.containers.get(name) + c.stop(timeout=5) + except docker.errors.NotFound: + break + except Exception: + logger.exception("universal_pool_scale_down_failed slot=%s", i) + + for i in range(UNIVERSAL_POOL_SIZE): + name = universal_container_name(i) + path = f"/u/{i}/" + router = f"upool-{i}" + labels = { + "traefik.enable": "true", + "traefik.docker.network": "portal_net", + f"traefik.http.routers.{router}.rule": f"PathPrefix(`{path}`)", + f"traefik.http.routers.{router}.entrypoints": "websecure", + f"traefik.http.routers.{router}.tls": "true", + f"traefik.http.routers.{router}.priority": "9400", + f"traefik.http.routers.{router}.middlewares": f"{router}-strip", + f"traefik.http.middlewares.{router}-strip.stripprefix.prefixes": path[:-1], + f"traefik.http.services.{router}.loadbalancer.server.port": "6080", + "portal.pool": "1", + "portal.pool.slot": str(i), + } + env = { + "IDLE_TIMEOUT": str(SESSION_IDLE_SECONDS), + "ENABLE_HEARTBEAT": "0", + "SESSION_ID": f"universal-{i}", + "X11VNC_FLAGS": X11VNC_FLAGS, + } + try: + c = d.containers.get(name) + if 
c.status != "running": + c.start() + continue + except docker.errors.NotFound: + pass + except Exception: + logger.exception("universal_pool_check_failed slot=%s", i) + continue + + d.containers.run( + image=image, + name=name, + detach=True, + auto_remove=True, + network="portal_net", + labels=labels, + environment=env, + ) + logger.info("universal_pool_container_started slot=%s", i) + + +def ensure_web_pool(target_size: Optional[int] = None) -> None: + desired = max(0, WEB_POOL_SIZE if target_size is None else target_size) + d = docker_client() + image = "portal-universal-runtime:latest" + + for i in range(desired, 100): + name = web_pool_container_name(i) + try: + c = d.containers.get(name) + c.stop(timeout=5) + except docker.errors.NotFound: + break + except Exception: + logger.exception("web_pool_scale_down_failed slot=%s", i) + + for i in range(desired): + name = web_pool_container_name(i) + path = f"/w/{i}/" + router = f"wpool-{i}" + labels = { + "traefik.enable": "true", + "traefik.docker.network": "portal_net", + f"traefik.http.routers.{router}.rule": f"PathPrefix(`{path}`)", + f"traefik.http.routers.{router}.entrypoints": "websecure", + f"traefik.http.routers.{router}.tls": "true", + f"traefik.http.routers.{router}.priority": "9450", + f"traefik.http.routers.{router}.middlewares": f"{router}-strip", + f"traefik.http.middlewares.{router}-strip.stripprefix.prefixes": path[:-1], + f"traefik.http.services.{router}.loadbalancer.server.port": "6080", + "portal.pool": "1", + "portal.pool.kind": "web", + "portal.pool.slot": str(i), + } + env = { + "IDLE_TIMEOUT": str(SESSION_IDLE_SECONDS), + "ENABLE_HEARTBEAT": "0", + "SESSION_ID": f"webpool-{i}", + "X11VNC_FLAGS": X11VNC_FLAGS, + } + should_create = False + try: + c = d.containers.get(name) + if c.status != "running": + try: + c.start() + except docker.errors.APIError as exc: + if _is_pool_name_conflict(exc): + logger.warning("web_pool_recreate_needed slot=%s reason=name-conflict", i) + 
def get_universal_pool_status() -> dict:
    """Summarise the universal pool: desired size, running/total containers, health.

    Returns a dict with keys ``desired``, ``running``, ``total``, ``names``
    and ``health``. Health is ``"ok"`` as soon as one pool container is
    running and ``"down"`` otherwise (including when the pool is disabled).
    """
    target = max(0, UNIVERSAL_POOL_SIZE)
    if target <= 0:
        return {"desired": 0, "running": 0, "total": 0, "health": "down", "names": []}

    client = docker_client()
    found = []
    for slot_name in [universal_container_name(idx) for idx in range(target)]:
        try:
            found.append(client.containers.get(slot_name))
        except Exception:
            # A slot that cannot be looked up simply does not count.
            continue

    running_count = len([c for c in found if c.status == "running"])
    if running_count >= min(target, 1):
        pool_health = "ok"
    else:
        pool_health = "down"

    return {
        "desired": target,
        "running": running_count,
        "total": len(found),
        "names": sorted(c.name for c in found),
        "health": pool_health,
    }
def acquire_universal_slot(db: Session) -> int:
    """Pick a universal-pool slot index for a new session.

    Slots referenced by ACTIVE sessions (``container_id`` = ``POOLIDX:<n>``)
    that were touched within ``SESSION_IDLE_SECONDS`` count as busy. The
    lowest free index below ``UNIVERSAL_POOL_SIZE`` wins. When every slot is
    busy, the least recently accessed session is marked TERMINATED (and
    committed) and its slot is reused. Falls back to slot ``0``.
    """

    def _slot_index(sess):
        # Index encoded after the first ":"; None when it cannot be parsed.
        try:
            return int((sess.container_id or "").split(":", 1)[1])
        except Exception:
            return None

    idle_cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS)
    stmt = select(SessionModel).where(
        SessionModel.status == SessionStatus.ACTIVE,
        SessionModel.container_id.like("POOLIDX:%"),
        SessionModel.last_access_at >= idle_cutoff,
    )
    live_sessions = db.scalars(stmt).all()

    occupied = set()
    for sess in live_sessions:
        idx = _slot_index(sess)
        if idx is not None:
            occupied.add(idx)

    free_slot = next(
        (idx for idx in range(max(0, UNIVERSAL_POOL_SIZE)) if idx not in occupied),
        None,
    )
    if free_slot is not None:
        return free_slot

    if live_sessions:
        # Pool exhausted: evict the session that has been idle the longest.
        victim = min(live_sessions, key=lambda s: s.last_access_at)
        victim.status = SessionStatus.TERMINATED
        db.commit()
        reclaimed = _slot_index(victim)
        if reclaimed is not None:
            return reclaimed
    return 0
def sanitize_client_resolution(width: Optional[int], height: Optional[int]) -> tuple[Optional[int], Optional[int]]:
    """Clamp a client-reported resolution to the configured bounds.

    Returns ``(None, None)`` unless both dimensions are given; otherwise each
    dimension is converted with ``int()`` and limited to the
    ``WEB_RESOLUTION_MIN_*`` / ``WEB_RESOLUTION_MAX_*`` range.
    """

    def _clamp(value, lower, upper):
        return max(lower, min(int(value), upper))

    if width is not None and height is not None:
        safe_width = _clamp(width, WEB_RESOLUTION_MIN_WIDTH, WEB_RESOLUTION_MAX_WIDTH)
        safe_height = _clamp(height, WEB_RESOLUTION_MIN_HEIGHT, WEB_RESOLUTION_MAX_HEIGHT)
        return safe_width, safe_height
    return None, None
def create_runtime_container(service: Service, session_id: str):
    """Start a dedicated container for one session and return its Docker id.

    The container is attached to ``portal_net`` and labelled so that Traefik
    routes ``/s/<session_id>/`` (prefix stripped) to port 6080. WEB services
    use the kiosk image, RDP services the RDP proxy image.

    Raises:
        HTTPException: 400 when the service type is neither WEB nor RDP.
    """
    client = docker_client()
    router = session_router_name(session_id)
    path = f"/s/{session_id}/"
    router_key = f"traefik.http.routers.{router}"
    labels = {
        "traefik.enable": "true",
        "traefik.docker.network": "portal_net",
        f"{router_key}.rule": f"PathPrefix(`{path}`)",
        f"{router_key}.entrypoints": "websecure",
        f"{router_key}.tls": "true",
        f"{router_key}.priority": "10000",
        f"{router_key}.middlewares": f"{router}-strip",
        f"traefik.http.middlewares.{router}-strip.stripprefix.prefixes": path[:-1],
        f"traefik.http.services.{router}.loadbalancer.server.port": "6080",
    }

    env = {
        "SESSION_ID": session_id,
        "IDLE_TIMEOUT": str(SESSION_IDLE_SECONDS),
        "ENABLE_HEARTBEAT": "1",
        "TOUCH_PATH": f"/api/sessions/{session_id}/touch",
        "X11VNC_FLAGS": X11VNC_FLAGS,
    }

    if service.type == ServiceType.WEB:
        image = "portal-kiosk:latest"
        env["TARGET_URL"] = service.target
        env["HOME_URL"] = f"https://{PUBLIC_HOST}/"
    elif service.type == ServiceType.RDP:
        image = "portal-rdp-proxy:latest"
        cfg = parse_rdp_target(service.target)
        env["RDP_HOST"] = cfg["host"]
        env["RDP_PORT"] = cfg["port"]
        # Optional settings are only exported when the target defines them.
        optional_settings = (
            ("user", "RDP_USER"),
            ("password", "RDP_PASSWORD"),
            ("domain", "RDP_DOMAIN"),
            ("security", "RDP_SECURITY"),
        )
        for cfg_key, env_key in optional_settings:
            if cfg[cfg_key]:
                env[env_key] = cfg[cfg_key]
    else:
        raise HTTPException(status_code=400, detail="Unsupported service type")

    started = client.containers.run(
        image=image,
        name=f"portal-sess-{session_id[:8]}",
        detach=True,
        auto_remove=True,
        network="portal_net",
        labels=labels,
        environment=env,
    )
    logger.info(
        "session_container_started session_id=%s container_id=%s service_type=%s",
        session_id,
        started.id,
        service.type.value,
    )
    return started.id
def wait_for_session_route(session_id: str, timeout_seconds: int = 6) -> bool:
    """Poll Traefik until the session route stops answering 404.

    Args:
        session_id: Session whose ``/s/<session_id>/`` route is probed.
        timeout_seconds: Upper bound on the total time spent polling.

    Returns:
        True as soon as a probe gets any status other than 404, False when
        the deadline passes first.
    """
    target = f"{TRAEFIK_INTERNAL_URL}/s/{session_id}/"
    # Monotonic clock: the deadline must not move when the wall clock is
    # adjusted while we are polling.
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        try:
            resp = requests.get(
                target,
                headers={"Host": PUBLIC_HOST},
                allow_redirects=False,
                timeout=1.5,
            )
            if resp.status_code != 404:
                return True
        except Exception:
            # Best-effort probe: a failed request just means "not ready yet".
            pass
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(0.3, remaining))
    return False
def stop_rdp_slot_container(container_name: str) -> None:
    """Stop and remove the named RDP slot container.

    Does nothing for an empty name or when the container does not exist.
    Any other failure is logged and not re-raised.
    """
    if container_name:
        try:
            target = docker_client().containers.get(container_name)
            target.stop(timeout=5)
            target.remove(force=True)
            logger.info("rdp_slot_container_stopped name=%s", container_name)
        except docker.errors.NotFound:
            # Already gone - nothing left to clean up.
            return
        except Exception:
            logger.exception("rdp_slot_container_stop_failed container=%s", container_name)
def stop_runtime_container(container_id: Optional[str]) -> None:
    """Stop a per-session container; never raises.

    Args:
        container_id: Docker id of the container, or a falsy value to do
            nothing.

    Session containers are started with ``auto_remove=True``, so a container
    that has already exited no longer exists. That case is expected and is
    logged at info level without a traceback, the same way
    ``stop_rdp_slot_container`` treats a missing container as a non-error.
    Every other failure is logged with its traceback.
    """
    if not container_id:
        return
    try:
        d = docker_client()
        c = d.containers.get(container_id)
        c.stop(timeout=5)
    except docker.errors.NotFound:
        logger.info("session_container_already_gone container_id=%s", container_id)
    except Exception:
        logger.exception("session_container_stop_failed container_id=%s", container_id)
requires enum value addition to be committed before usage in constraints. + with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + conn.execute( + text( + """ + DO $$ + BEGIN + BEGIN + ALTER TYPE servicetype ADD VALUE IF NOT EXISTS 'RDP'; + EXCEPTION WHEN undefined_object THEN + NULL; + END; + END $$; + """ + ) + ) + conn.execute( + text( + """ + DO $$ + BEGIN + BEGIN + ALTER TYPE sessionstatus ADD VALUE IF NOT EXISTS 'ROTATED'; + EXCEPTION WHEN undefined_object THEN + NULL; + END; + END $$; + """ + ) + ) + + with engine.begin() as conn: + conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS warm_pool_size INTEGER NOT NULL DEFAULT 0")) + conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS comment TEXT NOT NULL DEFAULT ''")) + conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS svc_login VARCHAR(256) NOT NULL DEFAULT ''")) + conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS svc_password VARCHAR(256) NOT NULL DEFAULT ''")) + conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS svc_cred_hint TEXT NOT NULL DEFAULT ''")) + conn.execute(text("ALTER TABLE services ADD COLUMN IF NOT EXISTS icon_path TEXT NOT NULL DEFAULT ''")) + conn.execute( + text( + """ + CREATE TABLE IF NOT EXISTS categories ( + id SERIAL PRIMARY KEY, + name VARCHAR(128) NOT NULL UNIQUE, + slug VARCHAR(64) NOT NULL UNIQUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() + ) + """ + ) + ) + conn.execute( + text( + """ + CREATE TABLE IF NOT EXISTS service_categories ( + id SERIAL PRIMARY KEY, + service_id INT NOT NULL REFERENCES services(id) ON DELETE CASCADE, + category_id INT NOT NULL REFERENCES categories(id) ON DELETE CASCADE, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + UNIQUE (service_id, category_id) + ) + """ + ) + ) + conn.execute(text("CREATE INDEX IF NOT EXISTS idx_service_categories_service_id ON service_categories(service_id)")) + conn.execute(text("CREATE INDEX IF NOT EXISTS 
def desired_pool_size(service: Service) -> int:
    """Return how many prewarmed containers ``service`` should have.

    * inactive service -> 0
    * service on the universal pool -> ``UNIVERSAL_POOL_SIZE``
    * RDP service outside the universal pool -> 0 (started on demand)
    * otherwise the service's own positive ``warm_pool_size``, falling back
      to ``PREWARM_POOL_SIZE``
    """
    if not service.active:
        return 0
    if service_uses_universal_pool(service):
        return UNIVERSAL_POOL_SIZE
    if service.type == ServiceType.RDP:
        # RDP runs on-demand per user session; no prewarmed pool.
        return 0
    configured = service.warm_pool_size
    if configured and configured > 0:
        return configured
    return PREWARM_POOL_SIZE
[(c.attrs.get("State") or {}).get("Status", c.status) for c in containers] + has_bad = any(s in {"exited", "dead"} for s in states) + total = len(containers) + if running == 0: + health = "down" + elif running >= min(desired, 1) and not has_bad: + health = "ok" + else: + health = "degraded" + return { + "desired": desired, + "running": running, + "total": total, + "names": sorted(c.name for c in containers), + "health": health, + } + + +def get_pool_detailed_status(service: Service) -> dict: + if service.type == ServiceType.WEB: + d = docker_client() + pool = get_web_pool_status() + details = [] + for i in range(max(0, pool["desired"])): + name = web_pool_container_name(i) + try: + c = d.containers.get(name) + except Exception: + continue + attrs = c.attrs or {} + state = (attrs.get("State") or {}).get("Status", c.status) + details.append( + { + "name": c.name, + "status": c.status, + "state": state, + "created": attrs.get("Created", ""), + "image": c.image.tags[0] if c.image.tags else "", + "labels_ok": True, + } + ) + return { + "service_id": service.id, + "slug": service.slug, + "type": service.type.value, + "desired": pool["desired"], + "running": pool["running"], + "total": pool["total"], + "health": pool["health"], + "containers": details, + "updated_at": now_utc().isoformat(), + } + if service_uses_universal_pool(service): + d = docker_client() + pool = get_universal_pool_status() + details = [] + for i in range(max(0, UNIVERSAL_POOL_SIZE)): + name = universal_container_name(i) + try: + c = d.containers.get(name) + except Exception: + continue + attrs = c.attrs or {} + state = (attrs.get("State") or {}).get("Status", c.status) + details.append( + { + "name": c.name, + "status": c.status, + "state": state, + "created": attrs.get("Created", ""), + "image": c.image.tags[0] if c.image.tags else "", + "labels_ok": True, + } + ) + return { + "service_id": service.id, + "slug": service.slug, + "type": service.type.value, + "desired": pool["desired"], + "running": 
def get_active_sessions_count(db: Session, service_id: int) -> int:
    """Count occupied session slots for a service within the idle window.

    Only ACTIVE sessions accessed during the last ``SESSION_IDLE_SECONDS``
    are considered. Sessions bound to a pooled slot (container id starting
    with ``WEBPOOLIDX:``, ``POOLIDX:`` or ``POOL:``) are counted once per
    distinct container id; all other sessions are counted individually.
    """
    cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS)
    q = select(SessionModel).where(
        SessionModel.service_id == service_id,
        SessionModel.status == SessionStatus.ACTIVE,
        SessionModel.last_access_at >= cutoff,
    )
    pooled_prefixes = ("WEBPOOLIDX:", "POOLIDX:", "POOL:")
    # Avoid inflated stats when pooled slot sessions were duplicated by race:
    # for pooled sessions, occupancy is unique container_id.
    pooled_slots = set()
    direct = 0
    # Single pass; the old `s not in pooled` list scan was quadratic.
    for sess in db.scalars(q).all():
        cid = sess.container_id or ""
        if cid.startswith(pooled_prefixes):
            pooled_slots.add(cid)
        else:
            direct += 1
    return len(pooled_slots) + direct
def _latest_active_session(db: Session, *criteria) -> Optional[SessionModel]:
    """Return the newest ACTIVE, non-idle session matching ``criteria``, if any."""
    cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS)
    q = (
        select(SessionModel)
        .where(
            *criteria,
            SessionModel.status == SessionStatus.ACTIVE,
            SessionModel.last_access_at >= cutoff,
        )
        .order_by(SessionModel.created_at.desc())
    )
    return db.scalars(q).first()


def find_active_session_for_service(db: Session, service_id: int) -> Optional[SessionModel]:
    """Return the most recently created live session of a service, or ``None``."""
    return _latest_active_session(db, SessionModel.service_id == service_id)


def find_active_session_for_user_service(db: Session, user_id: int, service_id: int) -> Optional[SessionModel]:
    """Return the user's most recently created live session of a service, or ``None``."""
    return _latest_active_session(
        db,
        SessionModel.user_id == user_id,
        SessionModel.service_id == service_id,
    )


class LockTimeoutError(Exception):
    """Raised when an advisory lock could not be obtained before the timeout."""


def allocator_lock(db: Session, lock_id: int, timeout_seconds: Optional[float] = None, poll_seconds: float = 0.05):
    """Return a context manager that takes a PostgreSQL advisory lock.

    Args:
        db: Session whose current transaction will own the lock.
        lock_id: Advisory lock key.
        timeout_seconds: ``None`` blocks until the lock is granted. Otherwise
            the lock is polled until the deadline; negative values act as 0.
        poll_seconds: Pause between attempts, never less than 0.01 s.

    Raises:
        LockTimeoutError: on entering the context, if the lock was not
            obtained in time.

    The ``*_xact_lock`` functions are transaction-scoped, so ``__exit__`` has
    nothing to release and never suppresses exceptions.
    """

    class _LockCtx:
        def __enter__(self):
            self._acquired = False
            if timeout_seconds is None:
                db.execute(text("SELECT pg_advisory_xact_lock(:lid)"), {"lid": lock_id})
                self._acquired = True
                return self

            deadline = time.monotonic() + max(0.0, timeout_seconds)
            pause = max(0.01, poll_seconds)
            # Always try at least once, so a zero timeout means "try now".
            while True:
                got = db.execute(text("SELECT pg_try_advisory_xact_lock(:lid)"), {"lid": lock_id}).scalar()
                if got:
                    self._acquired = True
                    return self
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                # Never sleep past the deadline.
                time.sleep(min(pause, remaining))
            raise LockTimeoutError(f"advisory lock timeout lock_id={lock_id} timeout={timeout_seconds}")

        def __exit__(self, exc_type, exc, tb):
            return False

    return _LockCtx()
def terminate_active_slot_sessions(db: Session, container_id: str) -> None:
    """Mark every ACTIVE session bound to ``container_id`` as TERMINATED.

    Does nothing for an empty ``container_id``. The UPDATE is executed on
    ``db`` but not committed here.
    """
    if not container_id:
        return
    db.execute(
        text(
            """
            UPDATE sessions
            SET status = 'TERMINATED'
            WHERE container_id = :cid
              AND status = 'ACTIVE'
            """
        ),
        {"cid": container_id},
    )


def session_redirect_url(sess: SessionModel) -> str:
    """Return the portal URL for a session.

    Sessions whose container id carries a slot prefix (``POOL:``,
    ``POOLIDX:``, ``WEBPOOLIDX:`` or ``RDPSLOT:``) go to ``/s/<id>/view``;
    all others go to ``/s/<id>/``.
    """
    slot_prefixes = ("POOL:", "POOLIDX:", "WEBPOOLIDX:", "RDPSLOT:")
    if (sess.container_id or "").startswith(slot_prefixes):
        return f"/s/{sess.id}/view"
    return f"/s/{sess.id}/"


def open_warm_web_url(service: Service, target_url: str) -> None:
    """Ask every warm WEB container of ``service`` to open ``target_url``.

    No-op for services on the universal pool and for non-WEB services. Best
    effort: failures are logged per container and never raised.
    """
    if service_uses_universal_pool(service) or service.type != ServiceType.WEB:
        return
    target_url = normalize_web_target(target_url)
    try:
        containers = docker_client().containers.list(
            filters={
                "label": [
                    "portal.warm=1",
                    f"portal.service.slug={service.slug}",
                    "portal.service.type=WEB",
                ]
            }
        )
        for c in containers:
            try:
                resp = requests.post(
                    f"http://{c.name}:7000/open",
                    json={"url": target_url},
                    timeout=2,
                )
                resp.raise_for_status()
                logger.info("warm_web_open_ok service=%s container=%s url=%s", service.slug, c.name, target_url)
            except Exception:
                # One unreachable container must not stop the others.
                logger.exception("warm_web_open_failed service=%s container=%s", service.slug, c.name)
    except Exception:
        logger.exception("warm_web_open_dispatch_failed service=%s", service.slug)


def service_uses_universal_pool(service: Service) -> bool:
    """Return True for RDP services when the universal pool is enabled (size > 0)."""
    return UNIVERSAL_POOL_SIZE > 0 and service.type == ServiceType.RDP


def universal_container_name(slot: int) -> str:
    """Return the container name of universal pool slot ``slot``."""
    return f"portal-universal-{slot}"


def web_pool_container_name(slot: int) -> str:
    """Return the container name of web pool slot ``slot``."""
    return f"portal-webpool-{slot}"
logger = logging.getLogger("portal")
# Request id attached to every structured log line; "-" when none was set.
request_id_ctx = contextvars.ContextVar("request_id", default="-")


def _normalize_log_value(value):
    """Convert ``value`` into something ``json.dumps`` emits as valid JSON.

    Strings, ints, bools, finite floats and ``None`` pass through, datetimes
    become ISO-8601 strings, and everything else is stringified.
    """
    import math  # local import: the module import block does not provide it

    if isinstance(value, float) and not math.isfinite(value):
        # json.dumps would write NaN/Infinity, which is not valid JSON.
        return str(value)
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, dt.datetime):
        return value.isoformat()
    return str(value)


def log_event(event: str, level: int = logging.INFO, **fields) -> None:
    """Emit one compact JSON log line for ``event`` on the portal logger.

    The payload always carries ``event`` and ``req_id`` (from
    ``request_id_ctx``); ``fields`` are added after normalisation.
    """
    if not logger.isEnabledFor(level):
        # Skip serialisation for messages that would be dropped anyway.
        return
    payload = {"event": event, "req_id": request_id_ctx.get()}
    payload.update((key, _normalize_log_value(value)) for key, value in fields.items())
    logger.log(level, json.dumps(payload, ensure_ascii=False, separators=(",", ":")))


def now_utc() -> dt.datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return dt.datetime.now(dt.timezone.utc)


def session_closed_reason(sess: Optional[SessionModel], db: Session) -> str:
    """Classify why a session is no longer usable: ``"idle"`` or ``"limit"``.

    A missing or EXPIRED session is "idle" and a ROTATED one is "limit". A
    TERMINATED session is "limit" only when the user currently has at least
    ``MAX_ACTIVE_SERVICES_PER_USER`` distinct live services and this
    session's service is not among them. Anything else is "idle".
    """
    if not sess or sess.status == SessionStatus.EXPIRED:
        return "idle"
    if sess.status == SessionStatus.ROTATED:
        return "limit"
    if sess.status == SessionStatus.TERMINATED:
        cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS)
        # Only the service ids are needed, so do not load whole rows.
        active_service_ids = set(
            db.scalars(
                select(SessionModel.service_id)
                .where(
                    SessionModel.user_id == sess.user_id,
                    SessionModel.status == SessionStatus.ACTIVE,
                    SessionModel.last_access_at >= cutoff,
                )
                .distinct()
            ).all()
        )
        if len(active_service_ids) >= MAX_ACTIVE_SERVICES_PER_USER and sess.service_id not in active_service_ids:
            return "limit"
    return "idle"


def normalize_web_target(url: str) -> str:
    """Return ``url`` stripped, with ``http://`` prepended when it has no http(s) scheme.

    Empty or ``None`` input yields ``""``. The scheme check ignores case, so
    ``HTTPS://host`` is returned unchanged instead of being prefixed again.
    """
    raw = (url or "").strip()
    if not raw or raw.lower().startswith(("http://", "https://")):
        return raw
    return f"http://{raw}"
def format_service_comment(raw_text: str) -> Markup:
    """Render a service comment from Markdown to ``Markup``.

    Line endings are normalised to ``\\n`` and surrounding whitespace is
    stripped first; empty input yields empty markup.
    """
    raw = (raw_text or "").replace("\r\n", "\n").replace("\r", "\n").strip()
    if not raw:
        return Markup("")
    return Markup(_md(raw))


def parse_rdp_target(target: str) -> dict:
    """Parse an RDP target string into connection settings.

    Accepts ``host``, ``host:port`` or ``rdp://user:pass@host:port``, with
    optional query parameters: ``u``/``user``, ``p``/``password``,
    ``domain``/``d`` and ``sec``/``security``. Credentials in the authority
    part take precedence over the query string.

    Returns:
        Dict of strings with keys ``host``, ``port`` (default ``"3389"``),
        ``user``, ``password``, ``domain`` and ``security``.

    Raises:
        HTTPException: 400 for an empty target, a missing host, a malformed
            port or address, or an unsupported security mode.
    """
    raw = (target or "").strip()
    if not raw:
        raise HTTPException(status_code=400, detail="Empty RDP target")

    invalid_target = "Invalid RDP target. Use host:port or rdp://user:pass@host:port"
    try:
        # The "//" prefix makes a bare "host:port" parse as a network location.
        parsed = urlparse(raw if "://" in raw else f"//{raw}")
        host = parsed.hostname
        # `.port` raises ValueError for non-numeric or out-of-range ports.
        port = parsed.port or 3389
    except ValueError as err:
        raise HTTPException(status_code=400, detail=invalid_target) from err
    if not host:
        raise HTTPException(status_code=400, detail=invalid_target)

    query = parse_qs(parsed.query or "")

    def first_param(*names: str) -> str:
        """Return the first non-empty value among ``names``, stripped."""
        for name in names:
            value = query.get(name, [""])[0]
            if value:
                return value.strip()
        return ""

    username = unquote(parsed.username) if parsed.username else ""
    password = unquote(parsed.password) if parsed.password else ""
    if not username:
        username = first_param("u", "user")
    if not password:
        password = first_param("p", "password")

    domain = first_param("domain", "d")
    security = first_param("sec", "security").lower()
    if security and security not in {"nla", "tls", "rdp"}:
        raise HTTPException(status_code=400, detail="Invalid RDP security. Use one of: nla, tls, rdp")

    return {
        "host": host,
        "port": str(port),
        "user": username,
        "password": password,
        "domain": domain,
        "security": security,
    }
def set_service_categories(db: Session, service_id: int, category_ids: list[int]) -> None:
    """Make the category links of a service match ``category_ids``.

    Missing links are added and surplus links deleted on ``db``; nothing is
    committed here.

    Raises:
        HTTPException: 400 when any requested category id does not exist.
    """
    wanted = {int(x) for x in (category_ids or [])}
    if wanted:
        known = set(db.scalars(select(Category.id).where(Category.id.in_(sorted(wanted)))).all())
        missing = sorted(wanted - known)
        if missing:
            raise HTTPException(status_code=400, detail=f"Unknown category ids: {missing}")

    links = db.scalars(select(ServiceCategory).where(ServiceCategory.service_id == service_id)).all()
    linked = {row.category_id: row for row in links}

    for cat_id in sorted(wanted - linked.keys()):
        db.add(ServiceCategory(service_id=service_id, category_id=cat_id))
    for cat_id, row in linked.items():
        if cat_id not in wanted:
            db.delete(row)


def audit(db: Session, action: str, details: str, user_id: Optional[int] = None) -> None:
    """Add an ``AuditLog`` row and commit the session immediately."""
    db.add(AuditLog(user_id=user_id, action=action, details=details))
    db.commit()


def ensure_icons_dir() -> None:
    """Create ``SERVICE_ICONS_DIR`` (and parents) if it does not exist yet."""
    SERVICE_ICONS_DIR.mkdir(parents=True, exist_ok=True)


def remove_icon_file(icon_path: str) -> None:
    """Delete the stored icon behind a ``/static/service-icons/...`` URL path.

    Paths outside that prefix are ignored. Only the final path component is
    used, so the deletion cannot leave ``SERVICE_ICONS_DIR``. Failures are
    logged, not raised.
    """
    if not icon_path or not icon_path.startswith("/static/service-icons/"):
        return
    filename = icon_path.rsplit("/", 1)[-1]
    if filename in ("", ".", ".."):
        # These would resolve to the icons directory itself or its parent.
        return
    candidate = SERVICE_ICONS_DIR / filename
    try:
        candidate.unlink(missing_ok=True)
    except Exception:
        logger.exception("icon_delete_failed path=%s", candidate)


async def store_service_icon(service, upload: UploadFile) -> str:
    """Validate an uploaded icon, save it, and return its static URL path.

    Raises:
        HTTPException: 400 for an unsupported content type, a payload larger
            than ``ICON_UPLOAD_MAX_BYTES``, or an empty file.
    """
    ensure_icons_dir()
    content_type = (upload.content_type or "").lower().strip()
    ext = ICON_UPLOAD_TYPES.get(content_type)
    if not ext:
        raise HTTPException(status_code=400, detail="Unsupported file type. Use PNG/JPG/WEBP")

    # Read one byte past the limit so oversize uploads are detected without
    # reading the whole body.
    payload = await upload.read(ICON_UPLOAD_MAX_BYTES + 1)
    if len(payload) > ICON_UPLOAD_MAX_BYTES:
        raise HTTPException(status_code=400, detail="File too large. Max 2MB")
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")

    # Microseconds keep two uploads in the same second from sharing a filename.
    # NOTE(review): a caller that removes the previous icon after storing the
    # new one would otherwise delete the fresh file -- confirm against callers.
    stamp = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
    filename = f"svc_{service.id}_{stamp}.{ext}"
    target = SERVICE_ICONS_DIR / filename
    target.write_bytes(payload)
    return f"/static/service-icons/{filename}"