import datetime as dt
import json
import logging
import contextvars
from pathlib import Path
from typing import Optional
from urllib.parse import parse_qs, unquote, urlparse

import mistune
from fastapi import HTTPException, UploadFile
from markupsafe import Markup
from sqlalchemy import select
from sqlalchemy.orm import Session

from config import (
    ICON_UPLOAD_MAX_BYTES,
    ICON_UPLOAD_TYPES,
    MAX_ACTIVE_SERVICES_PER_USER,
    SERVICE_ICONS_DIR,
    SESSION_IDLE_SECONDS,
)
from models import AuditLog, Category, ServiceCategory, SessionModel, SessionStatus

logger = logging.getLogger("portal")

# Request id read by log_event() for every structured log line; "-" when unset.
request_id_ctx = contextvars.ContextVar("request_id", default="-")

# URL prefix under which stored service icons are served; shared by
# store_service_icon() (which produces such paths) and remove_icon_file()
# (which only accepts such paths).
_ICON_URL_PREFIX = "/static/service-icons/"


def _normalize_log_value(value):
    """Coerce *value* into something ``json.dumps`` can serialize.

    JSON scalars (str, int, float, bool, None) pass through unchanged,
    ``datetime`` values become ISO-8601 strings, and anything else falls back
    to ``str(value)``.
    """
    if isinstance(value, (str, int, float, bool)) or value is None:
        return value
    if isinstance(value, dt.datetime):
        return value.isoformat()
    return str(value)


def log_event(event: str, level: int = logging.INFO, **fields) -> None:
    """Emit one compact JSON log line on the ``portal`` logger.

    The payload always starts with ``event`` and the current request id
    (``req_id``, from ``request_id_ctx``); each extra keyword field is added
    after passing through ``_normalize_log_value``.

    Args:
        event: short event name stored under the ``event`` key.
        level: ``logging`` level to log at (default INFO).
        **fields: additional key/value pairs for the payload.
    """
    payload = {"event": event, "req_id": request_id_ctx.get()}
    for key, value in fields.items():
        payload[key] = _normalize_log_value(value)
    # Compact separators keep each event on a single, grep-friendly line.
    logger.log(level, json.dumps(payload, ensure_ascii=False, separators=(",", ":")))


def now_utc() -> dt.datetime:
    """Return the current time as a timezone-aware UTC ``datetime``."""
    return dt.datetime.now(dt.timezone.utc)


def session_closed_reason(sess: Optional[SessionModel], db: Session) -> str:
    """Classify why *sess* is no longer usable: ``"idle"`` or ``"limit"``.

    Rules, in order:

    * missing session, or status EXPIRED -> ``"idle"``
    * status ROTATED -> ``"limit"``
    * status TERMINATED -> ``"limit"`` when the same user currently has at
      least ``MAX_ACTIVE_SERVICES_PER_USER`` distinct services with ACTIVE
      sessions accessed within the last ``SESSION_IDLE_SECONDS`` and this
      session's service is not one of them; otherwise ``"idle"``
    * any other status -> ``"idle"``

    Args:
        sess: the session to classify; may be ``None``.
        db: SQLAlchemy session used for the TERMINATED lookup.
    """
    if not sess:
        return "idle"
    if sess.status == SessionStatus.EXPIRED:
        return "idle"
    if sess.status == SessionStatus.ROTATED:
        return "limit"
    if sess.status == SessionStatus.TERMINATED:
        # Sessions untouched for longer than the idle window do not count as active.
        # NOTE(review): cutoff is timezone-aware UTC; confirm last_access_at is
        # stored in a form that compares correctly against it.
        cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS)
        active_service_ids = set(
            db.scalars(
                select(SessionModel.service_id).where(
                    SessionModel.user_id == sess.user_id,
                    SessionModel.status == SessionStatus.ACTIVE,
                    SessionModel.last_access_at >= cutoff,
                )
            ).all()
        )
        if (
            len(active_service_ids) >= MAX_ACTIVE_SERVICES_PER_USER
            and sess.service_id not in active_service_ids
        ):
            return "limit"
    return "idle"


def normalize_web_target(url: str) -> str:
    """Return *url* stripped, prefixed with ``http://`` if it lacks an HTTP(S) scheme.

    Empty or ``None`` input yields ``""``. The scheme check is
    case-insensitive (URL schemes are case-insensitive), so ``HTTPS://host``
    is returned unchanged rather than becoming ``http://HTTPS://host``.
    """
    raw = (url or "").strip()
    if not raw:
        return raw
    if raw.lower().startswith(("http://", "https://")):
        return raw
    return f"http://{raw}"


# Markdown renderer for service comments. Created with escape=True so raw HTML
# in the input is escaped instead of being passed through to the page.
_md = mistune.create_markdown(
    escape=True,
    plugins=["strikethrough", "table", "task_lists"],
)


def format_service_comment(raw_text: str) -> Markup:
    """Render a service comment from Markdown to HTML.

    Line endings are normalized to ``\\n`` and surrounding whitespace is
    stripped; empty/``None`` input yields an empty ``Markup``. The result is
    wrapped in ``Markup`` so templates do not escape it again.
    """
    raw = (raw_text or "").replace("\r\n", "\n").replace("\r", "\n").strip()
    if not raw:
        return Markup("")
    return Markup(_md(raw))


def _first_query_value(query: dict, *names: str) -> str:
    """Return the first truthy value among *names* in a ``parse_qs`` result, stripped; else ``""``."""
    for name in names:
        value = query.get(name, [""])[0]
        if value:
            return value.strip()
    return ""


def parse_rdp_target(target: str) -> dict:
    """Parse an RDP target string into connection parameters.

    Accepted forms include ``host``, ``host:port`` and
    ``rdp://user:pass@host:port?domain=...&sec=...``. Credentials in the
    userinfo part are percent-decoded; when absent they fall back to the
    ``u``/``user`` and ``p``/``password`` query parameters. ``domain``/``d``
    and ``sec``/``security`` are read from the query string.

    Returns:
        dict of strings with keys ``host``, ``port`` (``"3389"`` when no port
        is given), ``user``, ``password``, ``domain`` and ``security``
        (lower-cased; ``""`` or one of ``nla``, ``tls``, ``rdp``).

    Raises:
        HTTPException: status 400 for an empty target, a missing host, a
            malformed URL or port (non-numeric / out of range), or an
            unknown security mode.
    """
    raw = (target or "").strip()
    if not raw:
        raise HTTPException(status_code=400, detail="Empty RDP target")

    invalid_target = "Invalid RDP target. Use host:port or rdp://user:pass@host:port"
    try:
        # Without a scheme, a leading "//" makes urlparse treat the whole
        # string as the network location (userinfo@host:port).
        parsed = urlparse(raw if "://" in raw else f"//{raw}")
        host = parsed.hostname
        # urlparse raises ValueError for malformed IPv6 brackets, and .port
        # raises ValueError for non-numeric or out-of-range ports; report
        # these as client errors instead of letting them surface as a 500.
        port = parsed.port or 3389
    except ValueError:
        raise HTTPException(status_code=400, detail=invalid_target) from None
    if not host:
        raise HTTPException(status_code=400, detail=invalid_target)

    username = unquote(parsed.username) if parsed.username else ""
    password = unquote(parsed.password) if parsed.password else ""
    query = parse_qs(parsed.query or "")
    if not username:
        username = _first_query_value(query, "u", "user")
    if not password:
        password = _first_query_value(query, "p", "password")
    domain = _first_query_value(query, "domain", "d")
    security = _first_query_value(query, "sec", "security").lower()
    if security and security not in {"nla", "tls", "rdp"}:
        raise HTTPException(status_code=400, detail="Invalid RDP security. Use one of: nla, tls, rdp")
    return {
        "host": host,
        "port": str(port),
        "user": username,
        "password": password,
        "domain": domain,
        "security": security,
    }


def set_service_categories(db: Session, service_id: int, category_ids: list[int]) -> None:
    """Make the category links of *service_id* exactly *category_ids*.

    Each id is converted with ``int()`` (a non-integer raises from ``int``)
    and de-duplicated. All ids must exist in ``Category``. Missing
    ``ServiceCategory`` rows are added and links to categories not in the
    list are deleted. Changes are only staged on *db*; this function does not
    commit.

    Raises:
        HTTPException: status 400 listing any category ids that do not exist.
    """
    normalized = sorted({int(x) for x in (category_ids or [])})
    if normalized:
        existing_ids = set(db.scalars(select(Category.id).where(Category.id.in_(normalized))).all())
        missing = sorted(set(normalized) - existing_ids)
        if missing:
            raise HTTPException(status_code=400, detail=f"Unknown category ids: {missing}")

    existing_links = db.scalars(select(ServiceCategory).where(ServiceCategory.service_id == service_id)).all()
    current = {row.category_id: row for row in existing_links}
    wanted = set(normalized)
    for cat_id in wanted:
        if cat_id not in current:
            db.add(ServiceCategory(service_id=service_id, category_id=cat_id))
    for cat_id, row in current.items():
        if cat_id not in wanted:
            db.delete(row)


def audit(db: Session, action: str, details: str, user_id: Optional[int] = None) -> None:
    """Record an ``AuditLog`` row and commit.

    Note that ``db.commit()`` commits everything pending on *db*, not only
    the audit row.
    """
    db.add(AuditLog(user_id=user_id, action=action, details=details))
    db.commit()


def ensure_icons_dir() -> None:
    """Create ``SERVICE_ICONS_DIR`` (and parents) if it does not exist yet."""
    SERVICE_ICONS_DIR.mkdir(parents=True, exist_ok=True)


def remove_icon_file(icon_path: str) -> None:
    """Best-effort delete of a stored service icon referenced by URL path.

    Only paths under ``/static/service-icons/`` are considered, and only the
    final path segment is used to locate the file in ``SERVICE_ICONS_DIR``.
    A missing file is not an error; other failures are logged, not raised.
    """
    if not icon_path or not icon_path.startswith(_ICON_URL_PREFIX):
        return
    filename = icon_path.rsplit("/", 1)[-1]
    # "", "." and ".." would resolve to the icons directory itself or its
    # parent, and a backslash acts as a separator on Windows; never pass such
    # names to unlink.
    if filename in {"", ".", ".."} or "\\" in filename:
        return
    candidate = SERVICE_ICONS_DIR / filename
    try:
        candidate.unlink(missing_ok=True)
    except Exception:
        # Deliberately best-effort: a stale icon must not break the caller.
        logger.exception("icon_delete_failed path=%s", candidate)


async def store_service_icon(service, upload: UploadFile) -> str:
    """Validate an uploaded icon and write it into ``SERVICE_ICONS_DIR``.

    Args:
        service: object whose ``id`` is embedded in the stored filename.
        upload: the uploaded file. Only its declared ``content_type`` is
            checked (it must be a key of ``ICON_UPLOAD_TYPES``, which supplies
            the file extension); the bytes themselves are not inspected.

    Returns:
        The public URL path of the stored file,
        ``/static/service-icons/svc_<id>_<YYYYmmdd_HHMMSS>.<ext>``.

    Raises:
        HTTPException: status 400 for an unsupported content type, a payload
            larger than ``ICON_UPLOAD_MAX_BYTES``, or an empty payload.
    """
    ensure_icons_dir()
    content_type = (upload.content_type or "").lower().strip()
    ext = ICON_UPLOAD_TYPES.get(content_type)
    if not ext:
        raise HTTPException(status_code=400, detail="Unsupported file type. Use PNG/JPG/WEBP")
    # Read one byte past the limit: enough to detect an oversize upload
    # without buffering the whole thing.
    payload = await upload.read(ICON_UPLOAD_MAX_BYTES + 1)
    if len(payload) > ICON_UPLOAD_MAX_BYTES:
        # NOTE(review): the message hard-codes 2MB; confirm it matches ICON_UPLOAD_MAX_BYTES.
        raise HTTPException(status_code=400, detail="File too large. Max 2MB")
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")
    # NOTE(review): the stamp has one-second resolution, so two uploads for the
    # same service within the same second map to the same file and the later
    # one overwrites the earlier.
    stamp = now_utc().strftime("%Y%m%d_%H%M%S")
    filename = f"svc_{service.id}_{stamp}.{ext}"
    target = SERVICE_ICONS_DIR / filename
    target.write_bytes(payload)
    return f"{_ICON_URL_PREFIX}{filename}"