329 lines
10 KiB
Python
329 lines
10 KiB
Python
from fastapi import FastAPI, Request, status
|
|
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
|
from fastapi.exceptions import RequestValidationError
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
import time
|
|
import os
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from typing import Dict, Any
|
|
from routes import auth
|
|
from utils.logger import setup_logger
|
|
from utils.exceptions import APIException, handle_api_exception
|
|
from models.user import AdminUser
|
|
from models.coupon import Coupon
|
|
from utils.auth import engine
|
|
from init_db import initialize_database
|
|
|
|
# Setup logging
# Module-level logger created by the project's utils.logger.setup_logger helper;
# the startup, request-middleware and exception-handler code in this module logs through it.
logger = setup_logger(__name__)
|
|
|
|
# Application configuration
class AppConfig:
    """Application settings, read once from environment variables at import time.

    Attributes:
        APP_NAME: API title (env ``APP_NAME``); falls back to a fixed name.
        VERSION: API version string (env ``APP_VERSION``); falls back to "1.0.0".
        DEBUG: True only when env ``DEBUG`` equals "true" (case-insensitive).
        ENVIRONMENT: Deployment environment name (env ``ENVIRONMENT``,
            default "development").
        CORS_ORIGINS: Allowed CORS origins parsed from the comma-separated
            env ``CORS_ORIGINS``; empty list when unset or blank.
        TRUSTED_HOSTS: Allowed Host header values parsed from the
            comma-separated env ``TRUSTED_HOSTS``; ``["*"]`` when unset or "*".
    """

    # FastAPI asserts that title and version are truthy when it publishes an
    # OpenAPI schema, so ``None`` (env var unset) would abort app construction.
    APP_NAME = os.getenv("APP_NAME") or "Ebook Coupon Management System"
    VERSION = os.getenv("APP_VERSION") or "1.0.0"
    DEBUG = os.getenv("DEBUG", "false").lower() == "true"
    ENVIRONMENT = os.getenv("ENVIRONMENT", "development")

    # CORS settings - parse comma-separated string; blank entries are dropped,
    # so an unset/empty variable yields an empty list.
    _cors_origins_str = os.getenv("CORS_ORIGINS", "")
    CORS_ORIGINS = [origin.strip() for origin in _cors_origins_str.split(",") if origin.strip()]

    # Trusted hosts for production.
    # NOTE(review): an explicitly empty TRUSTED_HOSTS yields [], which makes
    # TrustedHostMiddleware reject every request — confirm that is intended.
    _trusted_hosts_str = os.getenv("TRUSTED_HOSTS", "*")
    TRUSTED_HOSTS = (
        [host.strip() for host in _trusted_hosts_str.split(",") if host.strip()]
        if _trusted_hosts_str != "*"
        else ["*"]
    )
|
|
|
|
# Application lifespan manager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests, and log shutdown.

    Startup logs the active configuration, creates the required directories
    via ``ensure_directories`` and calls ``initialize_database`` (table
    creation and default admin user). A database initialization error is
    logged with its traceback and re-raised, so the server does not start.

    Args:
        app: The FastAPI application (not used here; part of the lifespan signature).
    """
    # Startup
    logger.info(
        "Application starting up",
        extra={
            "app_name": AppConfig.APP_NAME,
            "version": AppConfig.VERSION,
            "environment": AppConfig.ENVIRONMENT,
            "debug": AppConfig.DEBUG,
        },
    )

    # Ensure required directories exist
    ensure_directories()

    # Initialize database: create tables and default admin user
    try:
        initialize_database()
    except Exception as e:
        # logger.exception also records the traceback, which logger.error dropped
        logger.exception("Error initializing database: %s", e)
        raise

    try:
        yield
    finally:
        # Shutdown - logged even if the lifespan is torn down by an exception
        logger.info("Application shutting down")
|
|
|
|
def ensure_directories():
    """Create the application's required working directories if missing.

    The paths are relative, so they resolve against the current working
    directory. Existing directories are left as they are.
    """
    for required_dir in ("translation_upload", "logs"):
        os.makedirs(required_dir, exist_ok=True)
        logger.debug(f"Ensured directory exists: {required_dir}")
|
|
|
|
# Create FastAPI application with enterprise features.
# Interactive docs AND the raw OpenAPI schema are only served in DEBUG mode:
# leaving openapi_url at FastAPI's default would still publish /openapi.json
# (the full API description) in production even though /docs is disabled.
app = FastAPI(
    # Fallbacks keep FastAPI's "title/version must be provided" check from
    # failing when the corresponding environment variables are unset.
    title=AppConfig.APP_NAME or "Ebook Coupon Management System",
    version=AppConfig.VERSION or "1.0.0",
    description="Enterprise-grade Ebook Coupon Management System API",
    docs_url="/docs" if AppConfig.DEBUG else None,
    redoc_url="/redoc" if AppConfig.DEBUG else None,
    openapi_url="/openapi.json" if AppConfig.DEBUG else None,
    lifespan=lifespan,
)
|
|
|
|
# Locate the admin frontend: it lives in "admin-frontend", a sibling of the
# directory that contains this file.
BASE_DIR = os.path.dirname(__file__)
PARENT_DIR = os.path.abspath(
    os.path.join(BASE_DIR, "..")
)
ADMIN_PANEL_DIR = os.path.join(PARENT_DIR, "admin-frontend")

# The same directory serves as both the /static file root and the Jinja2
# template root.
app.mount(
    "/static",
    StaticFiles(directory=ADMIN_PANEL_DIR),
    name="static",
)
templates = Jinja2Templates(directory=ADMIN_PANEL_DIR)
|
|
|
|
# Middleware registration. In Starlette, middleware added later wraps
# middleware added earlier, so CORS (registered after TrustedHost) runs
# outside the host check.
if AppConfig.ENVIRONMENT == "production":
    # Only in production: reject requests whose Host header is not allow-listed
    app.add_middleware(TrustedHostMiddleware, allowed_hosts=AppConfig.TRUSTED_HOSTS)

# Cross-origin requests: configured origins, credentials, any method and header.
# NOTE(review): with allow_credentials=True, a "*" entry in CORS_ORIGINS lets
# Starlette echo arbitrary origins for credentialed requests — confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=AppConfig.CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Request timing and logging middleware
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Log every request/response and attach tracing headers.

    Assigns a request ID and stores it on ``request.state.request_id``; this
    file's exception handlers read it back with ``getattr``. Logs the incoming
    request, then forwards it downstream. On success, it sets ``X-Process-Time``
    (seconds, 4 decimals) and ``X-Request-ID`` on the response and logs the
    completion. If the downstream app raises, it logs the failure with its
    traceback and re-raises the exception unchanged.
    """
    import uuid  # function-scope import, as health_check in this file also does

    # perf_counter is monotonic, so durations are unaffected by wall-clock changes
    start_time = time.perf_counter()

    # A random UUID is unique per request; the previous millisecond timestamp
    # collided whenever two requests started within the same millisecond.
    request_id = uuid.uuid4().hex
    request.state.request_id = request_id

    # request.client is None when the ASGI server provides no peer address
    client_ip = request.client.host if request.client else None

    # Log incoming request
    logger.info(
        f"Incoming request: {request.method} {request.url.path}",
        extra={
            "request_id": request_id,
            "method": request.method,
            "path": request.url.path,
            "client_ip": client_ip,
            "user_agent": request.headers.get("user-agent", ""),
        },
    )

    try:
        response = await call_next(request)
    except Exception as e:
        process_time = time.perf_counter() - start_time
        logger.error(
            f"Request failed: {request.method} {request.url.path}",
            extra={
                "request_id": request_id,
                "error": str(e),
                "process_time": process_time,
            },
            exc_info=True,
        )
        raise

    process_time = time.perf_counter() - start_time

    # Add headers for monitoring
    response.headers["X-Process-Time"] = f"{process_time:.4f}"
    response.headers["X-Request-ID"] = request_id

    # Log successful response
    logger.info(
        f"Request completed: {request.method} {request.url.path}",
        extra={
            "request_id": request_id,
            "status_code": response.status_code,
            "process_time": process_time,
        },
    )

    return response
|
|
|
|
# Exception handlers for proper error responses
@app.exception_handler(APIException)
async def api_exception_handler(request: Request, exc: APIException):
    """Convert an ``APIException`` into the standard JSON error envelope.

    Logs a warning tagged with the request ID ("unknown" if none was set),
    then responds with ``exc.status_code`` and a body that carries
    ``exc.detail`` and ``exc.error_code``.
    """
    log_message = f"API Exception: {exc.detail}"
    log_context = {
        "request_id": getattr(request.state, "request_id", "unknown"),
        "status_code": exc.status_code,
        "path": request.url.path,
    }
    logger.warning(log_message, extra=log_context)

    envelope = {
        "success": False,
        "error": exc.detail,
        "error_code": exc.error_code,
        "timestamp": time.time(),
        "path": str(request.url.path),
    }
    return JSONResponse(status_code=exc.status_code, content=envelope)
|
|
|
|
def _sanitize_validation_error(error: Dict[str, Any]) -> Dict[str, Any]:
    """Reduce one validation error dict to JSON-serializable primitives.

    ``msg`` and ``input`` are stringified, and so is every ``ctx`` value,
    because ``ctx`` may hold objects that JSONResponse cannot serialize.
    """
    safe_error = {
        "type": error.get("type", "unknown"),
        "loc": error.get("loc", []),
        "msg": str(error.get("msg", "Unknown error")),
        "input": str(error.get("input", "Unknown input")),
    }
    if error.get("ctx"):
        safe_error["ctx"] = {k: str(v) for k, v in error["ctx"].items()}
    return safe_error


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Return a 422 JSON envelope describing request validation failures.

    Each error is reduced to JSON-safe fields. If that reduction fails for
    any reason, a single generic error entry is returned instead. The
    response echoes the offending ``input`` values to the client. The log
    record omits them, because for endpoints such as login they can contain
    secrets (e.g. passwords).
    """
    # Safely extract error details; fall back to a generic entry on any surprise
    try:
        error_details = [_sanitize_validation_error(error) for error in exc.errors()]
    except Exception:
        error_details = [{"type": "validation_error", "msg": "Request validation failed"}]

    # Strip raw user input before logging so credentials never reach log files
    logged_errors = [
        {key: value for key, value in detail.items() if key != "input"}
        for detail in error_details
    ]
    logger.warning(
        "Validation error",
        extra={
            "request_id": getattr(request.state, "request_id", "unknown"),
            "errors": logged_errors,
            "path": request.url.path,
        },
    )

    # NOTE(review): echoing "input" to the client mirrors FastAPI's default
    # handler; consider dropping it for sensitive endpoints.
    return JSONResponse(
        status_code=422,
        content={
            "success": False,
            "error": "Validation Error",
            "error_code": "VALIDATION_ERROR",
            "detail": "Request validation failed",
            "timestamp": time.time(),
            "path": str(request.url.path),
            "details": error_details,
        },
    )
|
|
|
|
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render an HTTP exception as the JSON error envelope, keeping its headers.

    Logs a warning with the request ID, status and detail. It then responds
    with ``exc.status_code`` and ``exc.detail`` in the body, plus any headers
    the raiser attached to the exception.
    """
    logger.warning(
        f"HTTP Exception: {exc.status_code}",
        extra={
            "request_id": getattr(request.state, "request_id", "unknown"),
            "status_code": exc.status_code,
            "detail": exc.detail,
            "path": request.url.path,
        },
    )

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "success": False,
            "error": "HTTP Error",
            "detail": exc.detail,
            "timestamp": time.time(),
            "path": str(request.url.path),
        },
        # Forward exception headers (e.g. WWW-Authenticate on 401, Allow on
        # 405); the previous version silently dropped them. getattr keeps this
        # working with Starlette versions whose HTTPException lacks .headers.
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the unhandled error and return an opaque 500.

    The log record carries the exception type and message. ``exc_info=True``
    adds the traceback when called while the exception is being handled.
    The client only receives a generic message, so internal details stay hidden.
    """
    error_context = {
        "request_id": getattr(request.state, "request_id", "unknown"),
        "exception_type": type(exc).__name__,
        "exception_message": str(exc),
        "path": request.url.path,
    }
    logger.error("Unhandled exception", extra=error_context, exc_info=True)

    opaque_body = {
        "success": False,
        "error": "Internal Server Error",
        "error_code": "INTERNAL_ERROR",
        "detail": "An unexpected error occurred",
        "timestamp": time.time(),
        "path": str(request.url.path),
    }
    return JSONResponse(status_code=500, content=opaque_body)
|
|
|
|
# Health check endpoint
@app.get("/health", tags=["Health"])
async def health_check() -> Dict[str, Any]:
    """Report service health, including a live database round-trip.

    Runs ``SELECT 1`` on a session obtained from ``utils.auth.get_db``. It
    reports ``database_status`` as "connected"/"disconnected" and an overall
    status of "healthy"/"unhealthy". Database errors are logged, never raised.

    NOTE(review): the endpoint always answers HTTP 200, so monitors that only
    inspect the status code will not notice a database outage. The query is
    also synchronous inside an ``async def``, so it blocks the event loop
    while it runs.
    """
    from utils.auth import get_db
    from sqlalchemy import text

    # Check database connection
    db_status = "connected"
    try:
        # Assumes get_db is a generator dependency (it is consumed with next()).
        db_gen = get_db()
        db = next(db_gen)
        try:
            db.execute(text("SELECT 1"))
        finally:
            # Release the session even when the query fails (previously leaked),
            # then finish the generator so its own cleanup code runs now.
            db.close()
            db_gen.close()
    except Exception as e:
        db_status = "disconnected"
        logger.error("Database health check failed", extra={"error": str(e)})

    return {
        "status": "healthy" if db_status == "connected" else "unhealthy",
        "timestamp": time.time(),
        "version": AppConfig.VERSION,
        "environment": AppConfig.ENVIRONMENT,
        "database_status": db_status,
    }
|
|
|
|
# Include routers.
# The auth router is registered twice, so every auth route is reachable both
# under the /auth prefix and without any prefix.
# NOTE(review): presumably the un-prefixed copy exists for older clients — confirm.
app.include_router(auth.router, tags=["Auth"], prefix="/auth")
app.include_router(auth.router, tags=["Auth"], prefix="")
|
|
|
|
# Root endpoint
@app.get("/", tags=["Root"])
async def root() -> Dict[str, Any]:
    """Describe the API: its name, version and environment.

    Also points to the docs (only when DEBUG is on, otherwise ``None``) and
    to the health-check route.
    """
    docs_location = "/docs" if AppConfig.DEBUG else None
    api_info: Dict[str, Any] = {
        "message": AppConfig.APP_NAME,
        "version": AppConfig.VERSION,
        "environment": AppConfig.ENVIRONMENT,
    }
    api_info["docs_url"] = docs_location
    api_info["health_check"] = "/health"
    return api_info
|