refactor: split main.py into modules (config, database, models, utils, auth, runtime, maintenance)
main.py was ~3000 lines with models, routes, Docker ops, maintenance all mixed. Split into 7 focused modules: - config.py: env vars and constants - database.py: SQLAlchemy engine, SessionLocal, Base, get_db - models.py: ORM models and enums - utils.py: logging, formatting, icon handling, misc helpers - auth.py: password hashing, cookies, CSRF, user dependency - runtime.py: all Docker operations, pool management, session lifecycle - maintenance.py: cleanup loop, schema bootstrap, startup logic - main.py: FastAPI app, middleware, all route handlers only
This commit is contained in:
+102
@@ -0,0 +1,102 @@
|
||||
import secrets
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from itsdangerous import BadSignature, URLSafeTimedSerializer
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import COOKIE_MAX_AGE, COOKIE_NAME, CSRF_COOKIE
|
||||
from database import get_db
|
||||
from models import User, UserServiceAccess
|
||||
from utils import now_utc
|
||||
from sqlalchemy import select
|
||||
|
||||
import os
|
||||
|
||||
_SIGNING_KEY = os.getenv("SIGNING_KEY", secrets.token_urlsafe(32))
|
||||
serializer = URLSafeTimedSerializer(_SIGNING_KEY, salt="portal-auth")
|
||||
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
return pwd_context.verify(password, password_hash)
|
||||
|
||||
|
||||
def user_is_valid(user: User) -> bool:
|
||||
return bool(user.active and user.expires_at > now_utc())
|
||||
|
||||
|
||||
def issue_auth_cookie(response: RedirectResponse, user: User) -> None:
|
||||
token = serializer.dumps({"user_id": user.id})
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
def issue_csrf_cookie(response: RedirectResponse) -> str:
|
||||
token = secrets.token_urlsafe(24)
|
||||
response.set_cookie(
|
||||
key=CSRF_COOKIE,
|
||||
value=token,
|
||||
httponly=False,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
path="/",
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
def get_current_user(request: Request, db: Session = Depends(get_db)) -> Optional[User]:
|
||||
raw = request.cookies.get(COOKIE_NAME)
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
payload = serializer.loads(raw, max_age=COOKIE_MAX_AGE)
|
||||
except BadSignature:
|
||||
return None
|
||||
user = db.get(User, int(payload["user_id"]))
|
||||
if not user or not user_is_valid(user):
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
def require_user(user: Optional[User] = Depends(get_current_user)) -> User:
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized")
|
||||
return user
|
||||
|
||||
|
||||
def require_admin(user: User = Depends(require_user)) -> User:
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin only")
|
||||
return user
|
||||
|
||||
|
||||
def validate_csrf(request: Request) -> None:
|
||||
cookie = request.cookies.get(CSRF_COOKIE)
|
||||
form_val = request.headers.get("X-CSRF-Token")
|
||||
if request.headers.get("content-type", "").startswith("application/x-www-form-urlencoded"):
|
||||
return
|
||||
if not cookie or not form_val or cookie != form_val:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="CSRF failed")
|
||||
|
||||
|
||||
def has_access(db: Session, user_id: int, service_id: int) -> bool:
|
||||
q = select(UserServiceAccess).where(
|
||||
UserServiceAccess.user_id == user_id,
|
||||
UserServiceAccess.service_id == service_id,
|
||||
)
|
||||
return db.scalar(q) is not None
|
||||
@@ -0,0 +1,35 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql+psycopg2://portal:portal@db:5432/portal")
|
||||
COOKIE_NAME = "portal_auth"
|
||||
CSRF_COOKIE = "csrf_token"
|
||||
COOKIE_MAX_AGE = 8 * 60 * 60
|
||||
SESSION_IDLE_SECONDS = int(os.getenv("SESSION_IDLE_SECONDS", "7200"))
|
||||
PUBLIC_HOST = os.getenv("PUBLIC_HOST", "stend.4mont.ru")
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
LOG_SLOW_REQUEST_MS = int(os.getenv("LOG_SLOW_REQUEST_MS", "2000"))
|
||||
GO_USER_LOCK_TIMEOUT_SECONDS = float(os.getenv("GO_USER_LOCK_TIMEOUT_SECONDS", "8.0"))
|
||||
GO_POOL_LOCK_TIMEOUT_SECONDS = float(os.getenv("GO_POOL_LOCK_TIMEOUT_SECONDS", "20.0"))
|
||||
POOL_DISPATCH_RETRIES = int(os.getenv("POOL_DISPATCH_RETRIES", "6"))
|
||||
POOL_DISPATCH_REQUEST_TIMEOUT_SECONDS = float(os.getenv("POOL_DISPATCH_REQUEST_TIMEOUT_SECONDS", "2.0"))
|
||||
POOL_DISPATCH_SLEEP_SECONDS = float(os.getenv("POOL_DISPATCH_SLEEP_SECONDS", "0.3"))
|
||||
TRAEFIK_INTERNAL_URL = os.getenv("TRAEFIK_INTERNAL_URL", "http://traefik")
|
||||
PREWARM_POOL_SIZE = int(os.getenv("PREWARM_POOL_SIZE", "2"))
|
||||
UNIVERSAL_POOL_SIZE = int(os.getenv("UNIVERSAL_POOL_SIZE", "0"))
|
||||
WEB_POOL_SIZE = int(os.getenv("WEB_POOL_SIZE", "20"))
|
||||
WEB_POOL_BUFFER = int(os.getenv("WEB_POOL_BUFFER", "2"))
|
||||
X11VNC_FLAGS = os.getenv("X11VNC_FLAGS", "-wait 5 -defer 5 -threads")
|
||||
MAX_ACTIVE_SERVICES_PER_USER = int(os.getenv("MAX_ACTIVE_SERVICES_PER_USER", "4"))
|
||||
WEB_RESOLUTION_MIN_WIDTH = int(os.getenv("WEB_RESOLUTION_MIN_WIDTH", "1024"))
|
||||
WEB_RESOLUTION_MIN_HEIGHT = int(os.getenv("WEB_RESOLUTION_MIN_HEIGHT", "720"))
|
||||
WEB_RESOLUTION_MAX_WIDTH = int(os.getenv("WEB_RESOLUTION_MAX_WIDTH", "3840"))
|
||||
WEB_RESOLUTION_MAX_HEIGHT = int(os.getenv("WEB_RESOLUTION_MAX_HEIGHT", "2160"))
|
||||
ENABLE_STARTUP_MAINTENANCE = os.getenv("ENABLE_STARTUP_MAINTENANCE", "1") == "1"
|
||||
ICON_UPLOAD_MAX_BYTES = 2 * 1024 * 1024
|
||||
ICON_UPLOAD_TYPES = {
|
||||
"image/png": "png",
|
||||
"image/jpeg": "jpg",
|
||||
"image/webp": "webp",
|
||||
}
|
||||
SERVICE_ICONS_DIR = Path("static/service-icons")
|
||||
@@ -0,0 +1,20 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||
|
||||
from config import DATABASE_URL
|
||||
|
||||
|
||||
engine = create_engine(DATABASE_URL, pool_pre_ping=True)
|
||||
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
+42
-1700
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,196 @@
|
||||
import datetime as dt
|
||||
import fcntl
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
import docker
|
||||
from sqlalchemy import select
|
||||
|
||||
from config import ENABLE_STARTUP_MAINTENANCE, SESSION_IDLE_SECONDS, WEB_POOL_SIZE
|
||||
from database import Base, SessionLocal, engine
|
||||
from models import RdpSlot, Service, ServiceType, SessionModel, SessionStatus, User
|
||||
from utils import ensure_icons_dir, now_utc
|
||||
from auth import hash_password
|
||||
from runtime import (
|
||||
_rdp_slot_container_name,
|
||||
_restart_rdp_slot_bg,
|
||||
docker_client,
|
||||
ensure_schema_compatibility,
|
||||
ensure_universal_pool,
|
||||
ensure_warm_pool,
|
||||
ensure_web_pool,
|
||||
start_rdp_slot_container,
|
||||
stop_runtime_container,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("portal")
|
||||
maintenance_lock_file = None
|
||||
|
||||
|
||||
def cleanup_loop():
|
||||
while True:
|
||||
time.sleep(60)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ensure_universal_pool()
|
||||
ensure_web_pool()
|
||||
for svc in db.scalars(
|
||||
select(Service).where(
|
||||
Service.active == True,
|
||||
Service.type.in_([ServiceType.WEB, ServiceType.RDP]),
|
||||
)
|
||||
).all():
|
||||
if svc.type == ServiceType.WEB and WEB_POOL_SIZE <= 0:
|
||||
ensure_warm_pool(svc)
|
||||
cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS)
|
||||
q = select(SessionModel).where(
|
||||
SessionModel.status == SessionStatus.ACTIVE,
|
||||
SessionModel.last_access_at < cutoff,
|
||||
)
|
||||
stale = db.scalars(q).all()
|
||||
rdp_slots_to_restart: list[int] = []
|
||||
for sess in stale:
|
||||
cid = sess.container_id or ""
|
||||
if cid.startswith("RDPSLOT:"):
|
||||
try:
|
||||
rdp_slots_to_restart.append(int(cid.split(":", 1)[1]))
|
||||
except Exception:
|
||||
pass
|
||||
elif cid and not (
|
||||
cid.startswith("POOL:")
|
||||
or cid.startswith("POOLIDX:")
|
||||
or cid.startswith("WEBPOOLIDX:")
|
||||
):
|
||||
stop_runtime_container(cid)
|
||||
sess.status = SessionStatus.EXPIRED
|
||||
if stale:
|
||||
db.commit()
|
||||
for slot_id in rdp_slots_to_restart:
|
||||
threading.Thread(target=_restart_rdp_slot_bg, args=(slot_id,), daemon=True).start()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
logger.exception("cleanup_loop_failed")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def bootstrap_admin():
|
||||
admin_user = os.getenv("ADMIN_USERNAME", "admin")
|
||||
admin_password = os.getenv("ADMIN_PASSWORD", "change_me")
|
||||
ttl_days = int(os.getenv("ADMIN_TTL_DAYS", "3650"))
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
existing = db.scalar(select(User).where(User.username == admin_user))
|
||||
if not existing:
|
||||
db.add(
|
||||
User(
|
||||
username=admin_user,
|
||||
password_hash=hash_password(admin_password),
|
||||
active=True,
|
||||
is_admin=True,
|
||||
expires_at=now_utc() + dt.timedelta(days=ttl_days),
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def try_acquire_maintenance_leader() -> bool:
|
||||
global maintenance_lock_file
|
||||
if maintenance_lock_file is not None:
|
||||
return True
|
||||
lock_file = open("/tmp/portal-maintenance.lock", "w")
|
||||
try:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
except BlockingIOError:
|
||||
lock_file.close()
|
||||
return False
|
||||
maintenance_lock_file = lock_file
|
||||
return True
|
||||
|
||||
|
||||
def run_maintenance_service() -> None:
|
||||
logger.info("maintenance_service_bootstrap_started")
|
||||
with open("/tmp/portal-schema.lock", "w") as lock_file:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
ensure_schema_compatibility()
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
ensure_icons_dir()
|
||||
bootstrap_admin()
|
||||
|
||||
maintenance_lock = open("/tmp/portal-maintenance.lock", "w")
|
||||
fcntl.flock(maintenance_lock.fileno(), fcntl.LOCK_EX)
|
||||
logger.info("maintenance_service_leader_acquired")
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ensure_universal_pool()
|
||||
ensure_web_pool()
|
||||
for svc in db.scalars(
|
||||
select(Service).where(
|
||||
Service.active == True,
|
||||
Service.type.in_([ServiceType.WEB, ServiceType.RDP]),
|
||||
)
|
||||
).all():
|
||||
if svc.type == ServiceType.WEB and WEB_POOL_SIZE <= 0:
|
||||
ensure_warm_pool(svc)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info("maintenance_service_loop_started")
|
||||
cleanup_loop()
|
||||
|
||||
|
||||
def on_startup() -> None:
|
||||
with open("/tmp/portal-schema.lock", "w") as lock_file:
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
ensure_schema_compatibility()
|
||||
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
|
||||
ensure_icons_dir()
|
||||
bootstrap_admin()
|
||||
if not try_acquire_maintenance_leader():
|
||||
logger.info("maintenance_leader_skipped")
|
||||
return
|
||||
|
||||
if ENABLE_STARTUP_MAINTENANCE:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ensure_universal_pool()
|
||||
ensure_web_pool()
|
||||
for svc in db.scalars(
|
||||
select(Service).where(
|
||||
Service.active == True,
|
||||
Service.type.in_([ServiceType.WEB, ServiceType.RDP]),
|
||||
)
|
||||
).all():
|
||||
if svc.type == ServiceType.WEB and WEB_POOL_SIZE <= 0:
|
||||
ensure_warm_pool(svc)
|
||||
elif svc.type == ServiceType.RDP:
|
||||
slots = db.scalars(select(RdpSlot).where(RdpSlot.service_id == svc.id)).all()
|
||||
for slot in slots:
|
||||
try:
|
||||
cname = _rdp_slot_container_name(svc.slug, slot.id)
|
||||
try:
|
||||
c = docker_client().containers.get(cname)
|
||||
if c.status != "running":
|
||||
c.start()
|
||||
except docker.errors.NotFound:
|
||||
start_rdp_slot_container(slot, svc)
|
||||
slot.container_name = cname
|
||||
except Exception:
|
||||
logger.exception("startup_rdp_slot_start_failed slot_id=%s", slot.id)
|
||||
if slots:
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
thread = threading.Thread(target=cleanup_loop, daemon=True)
|
||||
thread.start()
|
||||
logger.info("maintenance_leader_started")
|
||||
+115
@@ -0,0 +1,115 @@
|
||||
import datetime as dt
|
||||
import enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean, DateTime, Enum, ForeignKey, Integer, String, Text, UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from database import Base
|
||||
|
||||
|
||||
class ServiceType(str, enum.Enum):
|
||||
WEB = "WEB"
|
||||
VNC = "VNC"
|
||||
RDP = "RDP"
|
||||
|
||||
|
||||
class SessionStatus(str, enum.Enum):
|
||||
ACTIVE = "ACTIVE"
|
||||
EXPIRED = "EXPIRED"
|
||||
TERMINATED = "TERMINATED"
|
||||
ROTATED = "ROTATED"
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
username: Mapped[str] = mapped_column(String(64), unique=True, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255))
|
||||
expires_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), index=True)
|
||||
active: Mapped[bool] = mapped_column(Boolean, default=True, index=True)
|
||||
is_admin: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc))
|
||||
|
||||
|
||||
class Service(Base):
|
||||
__tablename__ = "services"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(128))
|
||||
slug: Mapped[str] = mapped_column(String(64), unique=True, index=True)
|
||||
type: Mapped[ServiceType] = mapped_column(Enum(ServiceType), index=True)
|
||||
target: Mapped[str] = mapped_column(Text)
|
||||
comment: Mapped[str] = mapped_column(Text, default="")
|
||||
svc_login: Mapped[str] = mapped_column(String(256), default="")
|
||||
svc_password: Mapped[str] = mapped_column(String(256), default="")
|
||||
svc_cred_hint: Mapped[str] = mapped_column(Text, default="")
|
||||
icon_path: Mapped[str] = mapped_column(Text, default="")
|
||||
active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
warm_pool_size: Mapped[int] = mapped_column(Integer, default=0)
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc))
|
||||
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = "categories"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(128), unique=True, index=True)
|
||||
slug: Mapped[str] = mapped_column(String(64), unique=True, index=True)
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc))
|
||||
|
||||
|
||||
class ServiceCategory(Base):
|
||||
__tablename__ = "service_categories"
|
||||
__table_args__ = (UniqueConstraint("service_id", "category_id", name="uq_service_category"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
service_id: Mapped[int] = mapped_column(ForeignKey("services.id", ondelete="CASCADE"), index=True)
|
||||
category_id: Mapped[int] = mapped_column(ForeignKey("categories.id", ondelete="CASCADE"), index=True)
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc))
|
||||
|
||||
|
||||
class UserServiceAccess(Base):
|
||||
__tablename__ = "user_service_access"
|
||||
__table_args__ = (UniqueConstraint("user_id", "service_id", name="uq_user_service"),)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||
service_id: Mapped[int] = mapped_column(ForeignKey("services.id", ondelete="CASCADE"), index=True)
|
||||
granted_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc))
|
||||
|
||||
|
||||
class RdpSlot(Base):
|
||||
__tablename__ = "rdp_slots"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
service_id: Mapped[int] = mapped_column(ForeignKey("services.id", ondelete="CASCADE"), index=True)
|
||||
rdp_username: Mapped[str] = mapped_column(String(128))
|
||||
rdp_password: Mapped[str] = mapped_column(String(256), default="")
|
||||
container_name: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc))
|
||||
|
||||
|
||||
class SessionModel(Base):
|
||||
__tablename__ = "sessions"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), index=True)
|
||||
service_id: Mapped[int] = mapped_column(ForeignKey("services.id", ondelete="CASCADE"), index=True)
|
||||
status: Mapped[SessionStatus] = mapped_column(Enum(SessionStatus), default=SessionStatus.ACTIVE, index=True)
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc), index=True)
|
||||
last_access_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc), index=True)
|
||||
container_id: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
|
||||
|
||||
class AuditLog(Base):
|
||||
__tablename__ = "audit_logs"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
user_id: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, index=True)
|
||||
action: Mapped[str] = mapped_column(String(128), index=True)
|
||||
details: Mapped[str] = mapped_column(Text)
|
||||
created_at: Mapped[dt.datetime] = mapped_column(DateTime(timezone=True), default=lambda: dt.datetime.now(dt.timezone.utc), index=True)
|
||||
+1131
File diff suppressed because it is too large
Load Diff
+181
@@ -0,0 +1,181 @@
|
||||
import datetime as dt
|
||||
import json
|
||||
import logging
|
||||
import contextvars
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, unquote, urlparse
|
||||
|
||||
import mistune
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from markupsafe import Markup
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from config import (
|
||||
ICON_UPLOAD_MAX_BYTES, ICON_UPLOAD_TYPES, MAX_ACTIVE_SERVICES_PER_USER,
|
||||
SERVICE_ICONS_DIR, SESSION_IDLE_SECONDS,
|
||||
)
|
||||
from models import AuditLog, Category, ServiceCategory, SessionModel, SessionStatus
|
||||
|
||||
|
||||
logger = logging.getLogger("portal")
|
||||
request_id_ctx = contextvars.ContextVar("request_id", default="-")
|
||||
|
||||
|
||||
def _normalize_log_value(value):
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, dt.datetime):
|
||||
return value.isoformat()
|
||||
return str(value)
|
||||
|
||||
|
||||
def log_event(event: str, level: int = logging.INFO, **fields) -> None:
|
||||
payload = {"event": event, "req_id": request_id_ctx.get()}
|
||||
for key, value in fields.items():
|
||||
payload[key] = _normalize_log_value(value)
|
||||
logger.log(level, json.dumps(payload, ensure_ascii=False, separators=(",", ":")))
|
||||
|
||||
|
||||
def now_utc() -> dt.datetime:
|
||||
return dt.datetime.now(dt.timezone.utc)
|
||||
|
||||
|
||||
def session_closed_reason(sess: SessionModel, db: Session) -> str:
|
||||
if not sess:
|
||||
return "idle"
|
||||
if sess.status == SessionStatus.EXPIRED:
|
||||
return "idle"
|
||||
if sess.status == SessionStatus.ROTATED:
|
||||
return "limit"
|
||||
if sess.status == SessionStatus.TERMINATED:
|
||||
cutoff = now_utc() - dt.timedelta(seconds=SESSION_IDLE_SECONDS)
|
||||
active_rows = db.scalars(
|
||||
select(SessionModel).where(
|
||||
SessionModel.user_id == sess.user_id,
|
||||
SessionModel.status == SessionStatus.ACTIVE,
|
||||
SessionModel.last_access_at >= cutoff,
|
||||
)
|
||||
).all()
|
||||
active_service_ids = {row.service_id for row in active_rows}
|
||||
if len(active_service_ids) >= MAX_ACTIVE_SERVICES_PER_USER and sess.service_id not in active_service_ids:
|
||||
return "limit"
|
||||
return "idle"
|
||||
|
||||
|
||||
def normalize_web_target(url: str) -> str:
|
||||
raw = (url or "").strip()
|
||||
if not raw:
|
||||
return raw
|
||||
if raw.startswith(("http://", "https://")):
|
||||
return raw
|
||||
return f"http://{raw}"
|
||||
|
||||
|
||||
_md = mistune.create_markdown(
|
||||
escape=True,
|
||||
plugins=["strikethrough", "table", "task_lists"],
|
||||
)
|
||||
|
||||
|
||||
def format_service_comment(raw_text: str) -> Markup:
|
||||
raw = (raw_text or "").replace("\r\n", "\n").replace("\r", "\n").strip()
|
||||
if not raw:
|
||||
return Markup("")
|
||||
return Markup(_md(raw))
|
||||
|
||||
|
||||
def parse_rdp_target(target: str) -> dict:
|
||||
raw = (target or "").strip()
|
||||
if not raw:
|
||||
raise HTTPException(status_code=400, detail="Empty RDP target")
|
||||
|
||||
parsed = urlparse(raw if "://" in raw else f"//{raw}")
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise HTTPException(status_code=400, detail="Invalid RDP target. Use host:port or rdp://user:pass@host:port")
|
||||
port = parsed.port or 3389
|
||||
|
||||
username = unquote(parsed.username) if parsed.username else ""
|
||||
password = unquote(parsed.password) if parsed.password else ""
|
||||
|
||||
query = parse_qs(parsed.query or "")
|
||||
if not username:
|
||||
username = (query.get("u", [""])[0] or query.get("user", [""])[0] or "").strip()
|
||||
if not password:
|
||||
password = (query.get("p", [""])[0] or query.get("password", [""])[0] or "").strip()
|
||||
|
||||
domain = (query.get("domain", [""])[0] or query.get("d", [""])[0] or "").strip()
|
||||
security = (query.get("sec", [""])[0] or query.get("security", [""])[0] or "").strip().lower()
|
||||
if security and security not in {"nla", "tls", "rdp"}:
|
||||
raise HTTPException(status_code=400, detail="Invalid RDP security. Use one of: nla, tls, rdp")
|
||||
|
||||
return {
|
||||
"host": host,
|
||||
"port": str(port),
|
||||
"user": username,
|
||||
"password": password,
|
||||
"domain": domain,
|
||||
"security": security,
|
||||
}
|
||||
|
||||
|
||||
def set_service_categories(db: Session, service_id: int, category_ids: list[int]) -> None:
|
||||
normalized = sorted({int(x) for x in (category_ids or [])})
|
||||
if normalized:
|
||||
existing_ids = set(db.scalars(select(Category.id).where(Category.id.in_(normalized))).all())
|
||||
missing = sorted(set(normalized) - existing_ids)
|
||||
if missing:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown category ids: {missing}")
|
||||
|
||||
existing_links = db.scalars(select(ServiceCategory).where(ServiceCategory.service_id == service_id)).all()
|
||||
current = {row.category_id: row for row in existing_links}
|
||||
wanted = set(normalized)
|
||||
|
||||
for cat_id in wanted:
|
||||
if cat_id not in current:
|
||||
db.add(ServiceCategory(service_id=service_id, category_id=cat_id))
|
||||
for cat_id, row in current.items():
|
||||
if cat_id not in wanted:
|
||||
db.delete(row)
|
||||
|
||||
|
||||
def audit(db: Session, action: str, details: str, user_id: Optional[int] = None) -> None:
|
||||
db.add(AuditLog(user_id=user_id, action=action, details=details))
|
||||
db.commit()
|
||||
|
||||
|
||||
def ensure_icons_dir() -> None:
|
||||
SERVICE_ICONS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def remove_icon_file(icon_path: str) -> None:
|
||||
if not icon_path or not icon_path.startswith("/static/service-icons/"):
|
||||
return
|
||||
filename = icon_path.rsplit("/", 1)[-1]
|
||||
candidate = SERVICE_ICONS_DIR / filename
|
||||
try:
|
||||
candidate.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
logger.exception("icon_delete_failed path=%s", candidate)
|
||||
|
||||
|
||||
async def store_service_icon(service, upload: UploadFile) -> str:
|
||||
ensure_icons_dir()
|
||||
content_type = (upload.content_type or "").lower().strip()
|
||||
ext = ICON_UPLOAD_TYPES.get(content_type)
|
||||
if not ext:
|
||||
raise HTTPException(status_code=400, detail="Unsupported file type. Use PNG/JPG/WEBP")
|
||||
|
||||
payload = await upload.read(ICON_UPLOAD_MAX_BYTES + 1)
|
||||
if len(payload) > ICON_UPLOAD_MAX_BYTES:
|
||||
raise HTTPException(status_code=400, detail="File too large. Max 2MB")
|
||||
if not payload:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
stamp = dt.datetime.now(dt.timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"svc_{service.id}_{stamp}.{ext}"
|
||||
target = SERVICE_ICONS_DIR / filename
|
||||
target.write_bytes(payload)
|
||||
return f"/static/service-icons/{filename}"
|
||||
Reference in New Issue
Block a user