refactor: split main.py into modules (config, database, models, utils, auth, runtime, maintenance)
main.py was ~3000 lines with models, routes, Docker ops, maintenance all mixed. Split into 7 focused modules: - config.py: env vars and constants - database.py: SQLAlchemy engine, SessionLocal, Base, get_db - models.py: ORM models and enums - utils.py: logging, formatting, icon handling, misc helpers - auth.py: password hashing, cookies, CSRF, user dependency - runtime.py: all Docker operations, pool management, session lifecycle - maintenance.py: cleanup loop, schema bootstrap, startup logic - main.py: FastAPI app, middleware, all route handlers only
This commit is contained in:
+181
@@ -0,0 +1,181 @@
|
||||
import datetime as dt
|
||||
import json
|
||||
import logging
|
||||
import contextvars
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, unquote, urlparse
|
||||
|
||||
import mistune
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from markupsafe import Markup
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import (
|
||||
ICON_UPLOAD_MAX_BYTES, ICON_UPLOAD_TYPES, MAX_ACTIVE_SERVICES_PER_USER,
|
||||
SERVICE_ICONS_DIR, SESSION_IDLE_SECONDS,
|
||||
)
|
||||
from models import AuditLog, Category, ServiceCategory, SessionModel, SessionStatus
|
||||
|
||||
|
||||
logger = logging.getLogger("portal")
|
||||
request_id_ctx = contextvars.ContextVar("request_id", default="-")
|
||||
|
||||
|
||||
def _normalize_log_value(value):
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, dt.datetime):
|
||||
return value.isoformat()
|
||||
return str(value)
|
||||
|
||||
|
||||
def log_event(event: str, level: int = logging.INFO, **fields) -> None:
|
||||
payload = {"event": event, "req_id": request_id_ctx.get()}
|
||||
for key, value in fields.items():
|
||||
payload[key] = _normalize_log_value(value)
|
||||
logger.log(level, json.dumps(payload, ensure_ascii=False, separators=(",", ":")))
|
||||
|
||||
|
||||
def now_utc() -> dt.datetime:
|
||||
return dt.datetime.now(dt.timezone.utc)
|
||||
|
||||
|
||||
def session_closed_reason(sess: SessionModel, db: Session) -> str:
|
||||
if not sess:
|
||||
return "idle"
|
||||
if sess.status == SessionStatus.EXPIRED:
|
||||
return "idle"
|
||||
if sess.status == SessionStatus.ROTATED:
|
||||
return "limit"
|
||||
if sess.status == SessionStatus.TERMINATED:
|
||||
cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS)
|
||||
active_rows = db.scalars(
|
||||
select(SessionModel).where(
|
||||
SessionModel.user_id == sess.user_id,
|
||||
SessionModel.status == SessionStatus.ACTIVE,
|
||||
SessionModel.last_access_at >= cutoff,
|
||||
)
|
||||
).all()
|
||||
active_service_ids = {row.service_id for row in active_rows}
|
||||
if len(active_service_ids) >= MAX_ACTIVE_SERVICES_PER_USER and sess.service_id not in active_service_ids:
|
||||
return "limit"
|
||||
return "idle"
|
||||
|
||||
|
||||
def normalize_web_target(url: str) -> str:
|
||||
raw = (url or "").strip()
|
||||
if not raw:
|
||||
return raw
|
||||
if raw.startswith(("http://", "https://")):
|
||||
return raw
|
||||
return f"http://{raw}"
|
||||
|
||||
|
||||
_md = mistune.create_markdown(
|
||||
escape=True,
|
||||
plugins=["strikethrough", "table", "task_lists"],
|
||||
)
|
||||
|
||||
|
||||
def format_service_comment(raw_text: str) -> Markup:
|
||||
raw = (raw_text or "").replace("\r\n", "\n").replace("\r", "\n").strip()
|
||||
if not raw:
|
||||
return Markup("")
|
||||
return Markup(_md(raw))
|
||||
|
||||
|
||||
def parse_rdp_target(target: str) -> dict:
|
||||
raw = (target or "").strip()
|
||||
if not raw:
|
||||
raise HTTPException(status_code=400, detail="Empty RDP target")
|
||||
|
||||
parsed = urlparse(raw if "://" in raw else f"//{raw}")
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise HTTPException(status_code=400, detail="Invalid RDP target. Use host:port or rdp://user:pass@host:port")
|
||||
port = parsed.port or 3389
|
||||
|
||||
username = unquote(parsed.username) if parsed.username else ""
|
||||
password = unquote(parsed.password) if parsed.password else ""
|
||||
|
||||
query = parse_qs(parsed.query or "")
|
||||
if not username:
|
||||
username = (query.get("u", [""])[0] or query.get("user", [""])[0] or "").strip()
|
||||
if not password:
|
||||
password = (query.get("p", [""])[0] or query.get("password", [""])[0] or "").strip()
|
||||
|
||||
domain = (query.get("domain", [""])[0] or query.get("d", [""])[0] or "").strip()
|
||||
security = (query.get("sec", [""])[0] or query.get("security", [""])[0] or "").strip().lower()
|
||||
if security and security not in {"nla", "tls", "rdp"}:
|
||||
raise HTTPException(status_code=400, detail="Invalid RDP security. Use one of: nla, tls, rdp")
|
||||
|
||||
return {
|
||||
"host": host,
|
||||
"port": str(port),
|
||||
"user": username,
|
||||
"password": password,
|
||||
"domain": domain,
|
||||
"security": security,
|
||||
}
|
||||
|
||||
|
||||
def set_service_categories(db: Session, service_id: int, category_ids: list[int]) -> None:
|
||||
normalized = sorted({int(x) for x in (category_ids or [])})
|
||||
if normalized:
|
||||
existing_ids = set(db.scalars(select(Category.id).where(Category.id.in_(normalized))).all())
|
||||
missing = sorted(set(normalized) - existing_ids)
|
||||
if missing:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown category ids: {missing}")
|
||||
|
||||
existing_links = db.scalars(select(ServiceCategory).where(ServiceCategory.service_id == service_id)).all()
|
||||
current = {row.category_id: row for row in existing_links}
|
||||
wanted = set(normalized)
|
||||
|
||||
for cat_id in wanted:
|
||||
if cat_id not in current:
|
||||
db.add(ServiceCategory(service_id=service_id, category_id=cat_id))
|
||||
for cat_id, row in current.items():
|
||||
if cat_id not in wanted:
|
||||
db.delete(row)
|
||||
|
||||
|
||||
def audit(db: Session, action: str, details: str, user_id: Optional[int] = None) -> None:
|
||||
db.add(AuditLog(user_id=user_id, action=action, details=details))
|
||||
db.commit()
|
||||
|
||||
|
||||
def ensure_icons_dir() -> None:
|
||||
SERVICE_ICONS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def remove_icon_file(icon_path: str) -> None:
|
||||
if not icon_path or not icon_path.startswith("/static/service-icons/"):
|
||||
return
|
||||
filename = icon_path.rsplit("/", 1)[-1]
|
||||
candidate = SERVICE_ICONS_DIR / filename
|
||||
try:
|
||||
candidate.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
logger.exception("icon_delete_failed path=%s", candidate)
|
||||
|
||||
|
||||
async def store_service_icon(service, upload: UploadFile) -> str:
|
||||
ensure_icons_dir()
|
||||
content_type = (upload.content_type or "").lower().strip()
|
||||
ext = ICON_UPLOAD_TYPES.get(content_type)
|
||||
if not ext:
|
||||
raise HTTPException(status_code=400, detail="Unsupported file type. Use PNG/JPG/WEBP")
|
||||
|
||||
payload = await upload.read(ICON_UPLOAD_MAX_BYTES + 1)
|
||||
if len(payload) > ICON_UPLOAD_MAX_BYTES:
|
||||
raise HTTPException(status_code=400, detail="File too large. Max 2MB")
|
||||
if not payload:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
stamp = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"svc_{service.id}_{stamp}.{ext}"
|
||||
target = SERVICE_ICONS_DIR / filename
|
||||
target.write_bytes(payload)
|
||||
return f"/static/service-icons/{filename}"
|
||||
Reference in New Issue
Block a user