Files
ruslan c8c77048c7 refactor: split main.py into modules (config, database, models, utils, auth, runtime, maintenance)
main.py had grown to ~3000 lines, mixing models, routes, Docker operations and maintenance logic.
Split it into 7 focused modules, leaving main.py with the app and routes only:
- config.py: env vars and constants
- database.py: SQLAlchemy engine, SessionLocal, Base, get_db
- models.py: ORM models and enums
- utils.py: logging, formatting, icon handling, misc helpers
- auth.py: password hashing, cookies, CSRF, user dependency
- runtime.py: all Docker operations, pool management, session lifecycle
- maintenance.py: cleanup loop, schema bootstrap, startup logic
- main.py: FastAPI app, middleware, all route handlers only
2026-05-01 09:40:06 +00:00

182 lines
6.2 KiB
Python

import contextvars
import datetime as dt
import json
import logging
import secrets
from pathlib import Path
from typing import Optional
from urllib.parse import parse_qs, unquote, urlparse

import mistune
from fastapi import HTTPException, UploadFile
from markupsafe import Markup
from sqlalchemy import select
from sqlalchemy.orm import Session

from config import (
    ICON_UPLOAD_MAX_BYTES,
    ICON_UPLOAD_TYPES,
    MAX_ACTIVE_SERVICES_PER_USER,
    SERVICE_ICONS_DIR,
    SESSION_IDLE_SECONDS,
)
from models import AuditLog, Category, ServiceCategory, SessionModel, SessionStatus
# Application logger named "portal"; log_event() writes its JSON records here.
logger = logging.getLogger("portal")
# Request id copied into every log_event() payload as "req_id"; "-" when unset.
request_id_ctx: contextvars.ContextVar[str] = contextvars.ContextVar(
    "request_id",
    default="-",
)
def _normalize_log_value(value):
    """Coerce ``value`` into something ``json.dumps`` can serialise.

    ``None`` and JSON scalars (str, int, float, bool) are returned unchanged,
    ``datetime`` instances become ISO-8601 strings, and every other object
    falls back to ``str(value)``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    return value.isoformat() if isinstance(value, dt.datetime) else str(value)
def log_event(event: str, level: int = logging.INFO, **fields) -> None:
    """Emit one structured log record as a compact JSON object.

    The payload always starts with ``event`` and ``req_id`` (read from
    ``request_id_ctx``). Each extra keyword argument is added after passing
    through ``_normalize_log_value``. A keyword named ``event`` or ``req_id``
    overrides the built-in value.

    Args:
        event: event name stored under the ``"event"`` key.
        level: standard ``logging`` level for the record.
        **fields: additional key/value pairs to include in the payload.
    """
    # Logger.log would drop a disabled record anyway; bail out first so the
    # payload is not built and JSON-serialised for nothing.
    if not logger.isEnabledFor(level):
        return
    payload = {"event": event, "req_id": request_id_ctx.get()}
    payload.update((key, _normalize_log_value(value)) for key, value in fields.items())
    logger.log(level, json.dumps(payload, ensure_ascii=False, separators=(",", ":")))
def now_utc() -> dt.datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    utc = dt.timezone.utc
    return dt.datetime.now(tz=utc)
def session_closed_reason(sess: Optional[SessionModel], db: Session) -> str:
    """Return the closure reason for ``sess``: ``"limit"`` or ``"idle"``.

    * missing session or ``EXPIRED`` status -> ``"idle"``
    * ``ROTATED`` status -> ``"limit"``
    * ``TERMINATED`` status -> ``"limit"`` when the user currently has at least
      ``MAX_ACTIVE_SERVICES_PER_USER`` distinct services with ``ACTIVE``
      sessions accessed within the last ``SESSION_IDLE_SECONDS`` and this
      session's service is not among them; ``"idle"`` otherwise
    * any other status -> ``"idle"``
    """
    if not sess or sess.status == SessionStatus.EXPIRED:
        return "idle"
    if sess.status == SessionStatus.ROTATED:
        return "limit"
    if sess.status != SessionStatus.TERMINATED:
        return "idle"

    cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS)
    # Only the distinct service ids matter, so select that column instead of
    # loading whole session rows.
    active_service_ids = set(
        db.scalars(
            select(SessionModel.service_id)
            .where(
                SessionModel.user_id == sess.user_id,
                SessionModel.status == SessionStatus.ACTIVE,
                SessionModel.last_access_at >= cutoff,
            )
            .distinct()
        ).all()
    )
    at_capacity = len(active_service_ids) >= MAX_ACTIVE_SERVICES_PER_USER
    if at_capacity and sess.service_id not in active_service_ids:
        return "limit"
    return "idle"
def normalize_web_target(url: str) -> str:
    """Return ``url`` stripped, prefixing ``http://`` when it has no web scheme.

    Empty or ``None`` input yields ``""``. URL schemes are case-insensitive
    (RFC 3986), so the ``http``/``https`` check ignores case. As a result,
    ``"HTTPS://host"`` is kept as-is rather than turned into
    ``"http://HTTPS://host"``.
    """
    raw = (url or "").strip()
    if not raw:
        return raw
    if raw.lower().startswith(("http://", "https://")):
        return raw
    return f"http://{raw}"
# Markdown extensions enabled for service comments.
_MD_PLUGINS = ["strikethrough", "table", "task_lists"]
# Shared renderer used by format_service_comment(), which wraps its output in
# Markup. escape=True is relied on to HTML-escape raw HTML in the input
# (per the mistune docs); keep it enabled.
_md = mistune.create_markdown(escape=True, plugins=_MD_PLUGINS)
def format_service_comment(raw_text: str) -> Markup:
    """Render a Markdown service comment to HTML wrapped in ``Markup``.

    CRLF and lone CR line endings are converted to LF and surrounding
    whitespace is trimmed first. Missing or blank text yields an empty
    ``Markup``. Otherwise the module-level ``_md`` renderer produces the HTML,
    which is wrapped in ``Markup`` without further escaping.
    """
    text = raw_text or ""
    for line_break in ("\r\n", "\r"):
        text = text.replace(line_break, "\n")
    text = text.strip()
    return Markup(_md(text)) if text else Markup("")
def _first_query_value(query: dict[str, list[str]], *keys: str) -> str:
    """Return the first non-empty value among ``keys`` in a parse_qs() result, stripped."""
    for key in keys:
        value = query.get(key, [""])[0]
        if value:
            return value.strip()
    return ""


def parse_rdp_target(target: str) -> dict:
    """Parse an RDP target string into connection parameters.

    Accepted forms are ``host``, ``host:port`` and
    ``rdp://user:pass@host:port``, each optionally followed by a query string.
    Credentials embedded in the URL win. Otherwise they are taken from the
    ``u``/``user`` and ``p``/``password`` query parameters. ``domain`` (alias
    ``d``) and ``sec`` (alias ``security``) always come from the query string.

    Returns:
        dict with string values for ``host``, ``port`` (default ``"3389"``),
        ``user``, ``password``, ``domain`` and ``security`` (lower-cased).

    Raises:
        HTTPException(400): empty target; missing or unparseable host;
            non-numeric or out-of-range port; or a security mode other than
            nla/tls/rdp.
    """
    invalid_target = "Invalid RDP target. Use host:port or rdp://user:pass@host:port"
    raw = (target or "").strip()
    if not raw:
        raise HTTPException(status_code=400, detail="Empty RDP target")
    try:
        # urlparse only recognises a netloc after "//", so bare host[:port] gets it prepended.
        parsed = urlparse(raw if "://" in raw else f"//{raw}")
        host = parsed.hostname
        # .port raises ValueError for a non-numeric or out-of-range port; without
        # this guard that surfaced as a 500 instead of a client error.
        port = parsed.port or 3389
    except ValueError as err:
        raise HTTPException(status_code=400, detail=invalid_target) from err
    if not host:
        raise HTTPException(status_code=400, detail=invalid_target)

    query = parse_qs(parsed.query or "")
    username = unquote(parsed.username) if parsed.username else ""
    password = unquote(parsed.password) if parsed.password else ""
    if not username:
        username = _first_query_value(query, "u", "user")
    if not password:
        password = _first_query_value(query, "p", "password")
    domain = _first_query_value(query, "domain", "d")
    security = _first_query_value(query, "sec", "security").lower()
    if security and security not in {"nla", "tls", "rdp"}:
        raise HTTPException(status_code=400, detail="Invalid RDP security. Use one of: nla, tls, rdp")
    return {
        "host": host,
        "port": str(port),
        "user": username,
        "password": password,
        "domain": domain,
        "security": security,
    }
def set_service_categories(db: Session, service_id: int, category_ids: list[int]) -> None:
    """Make the service's category links match ``category_ids`` exactly.

    Ids are coerced with ``int()`` and de-duplicated. Every id must exist in
    ``Category``; otherwise HTTP 400 listing the unknown ids (sorted) is raised
    before any link is touched. Missing links are added and links to
    categories not requested are deleted on ``db``. Nothing is committed
    here; the caller owns the transaction.
    """
    requested = sorted({int(cid) for cid in (category_ids or [])})
    if requested:
        known = set(db.scalars(select(Category.id).where(Category.id.in_(requested))).all())
        unknown = sorted(set(requested) - known)
        if unknown:
            raise HTTPException(status_code=400, detail=f"Unknown category ids: {unknown}")

    link_query = select(ServiceCategory).where(ServiceCategory.service_id == service_id)
    links_by_category = {link.category_id: link for link in db.scalars(link_query).all()}
    wanted = set(requested)

    for cid in wanted:
        if cid not in links_by_category:
            db.add(ServiceCategory(service_id=service_id, category_id=cid))
    stale_links = [link for cid, link in links_by_category.items() if cid not in wanted]
    for link in stale_links:
        db.delete(link)
def audit(db: Session, action: str, details: str, user_id: Optional[int] = None) -> None:
    """Add an ``AuditLog`` entry and commit ``db`` immediately.

    NOTE: the commit applies to the whole session, so any other pending
    changes on ``db`` are committed along with the audit row.
    """
    entry = AuditLog(user_id=user_id, action=action, details=details)
    db.add(entry)
    db.commit()
def ensure_icons_dir() -> None:
    """Create ``SERVICE_ICONS_DIR`` and any missing parents; no-op if it already exists."""
    icons_dir = SERVICE_ICONS_DIR
    icons_dir.mkdir(exist_ok=True, parents=True)
def remove_icon_file(icon_path: str) -> None:
    """Best-effort delete of a stored service icon.

    Only paths starting with ``/static/service-icons/`` are considered; any
    other value, including an empty one, is ignored. The last path segment is
    looked up inside ``SERVICE_ICONS_DIR``. Segments that cannot name a file
    there are ignored: empty (which would resolve to the directory itself),
    ``.``, ``..``, or containing a backslash (a separator on Windows). A file
    that is already gone is not an error. Other filesystem errors are logged,
    not raised.
    """
    if not icon_path or not icon_path.startswith("/static/service-icons/"):
        return
    filename = icon_path.rsplit("/", 1)[-1]
    if filename in {"", ".", ".."} or "\\" in filename:
        return
    candidate = SERVICE_ICONS_DIR / filename
    try:
        candidate.unlink(missing_ok=True)
    except OSError:
        # Deliberately best-effort: a failed cleanup must not break the caller.
        logger.exception("icon_delete_failed path=%s", candidate)
async def store_service_icon(service, upload: UploadFile) -> str:
    """Validate an uploaded icon, write it into ``SERVICE_ICONS_DIR`` and return its URL path.

    The file extension comes from ``ICON_UPLOAD_TYPES`` keyed by the upload's
    declared content type. At most ``ICON_UPLOAD_MAX_BYTES + 1`` bytes are
    read, which is enough to tell whether the file is too large.

    Args:
        service: object whose ``id`` is embedded in the stored filename.
        upload: the incoming file.

    Returns:
        Path of the form ``/static/service-icons/<filename>``.

    Raises:
        HTTPException(400): unsupported content type, file too large, or empty file.
    """
    ensure_icons_dir()
    content_type = (upload.content_type or "").lower().strip()
    # NOTE(review): content_type is client-declared; the bytes are not checked against it.
    ext = ICON_UPLOAD_TYPES.get(content_type)
    if not ext:
        raise HTTPException(status_code=400, detail="Unsupported file type. Use PNG/JPG/WEBP")
    payload = await upload.read(ICON_UPLOAD_MAX_BYTES + 1)
    if len(payload) > ICON_UPLOAD_MAX_BYTES:
        # NOTE(review): "2MB" is hard-coded; keep it in sync with ICON_UPLOAD_MAX_BYTES.
        raise HTTPException(status_code=400, detail="File too large. Max 2MB")
    if not payload:
        raise HTTPException(status_code=400, detail="Empty file")
    stamp = now_utc().strftime("%Y%m%d_%H%M%S")
    # The timestamp only has one-second resolution; the random suffix keeps two
    # uploads for the same service within one second from overwriting each other.
    filename = f"svc_{service.id}_{stamp}_{secrets.token_hex(4)}.{ext}"
    target = SERVICE_ICONS_DIR / filename
    target.write_bytes(payload)
    return f"/static/service-icons/{filename}"