254 lines
7.1 KiB
Python
254 lines
7.1 KiB
Python
"""
|
||
Database Initialization Script
|
||
|
||
This script automatically initializes the database on application startup:
|
||
- Creates all required tables if they don't exist
|
||
- Creates default admin user if no admin exists
|
||
- Runs automatically when the application starts
|
||
- Safe to run multiple times (idempotent)
|
||
|
||
Usage:
|
||
This file is automatically called from main.py lifespan event.
|
||
No manual execution required.
|
||
"""
|
||
|
||
import os
|
||
import logging
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy.exc import IntegrityError
|
||
from dotenv import load_dotenv
|
||
|
||
from utils.auth import engine, SessionLocal, Base, hash_password
|
||
from models.user import AdminUser
|
||
from models.coupon import Coupon
|
||
|
||
# Load environment variables
|
||
load_dotenv()
|
||
|
||
# Setup logger
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def create_tables():
    """
    Ensure the tables registered on ``Base.metadata`` exist.

    The model modules (AdminUser, Coupon) are imported first so that their
    tables are registered with ``Base`` before
    ``Base.metadata.create_all(bind=engine)`` is called.

    Returns:
        bool: True when table creation completes, False if any exception
        was raised along the way (the error is logged with its traceback
        and not re-raised).
    """
    created = False
    try:
        # Imported only for the side effect of registering the models with
        # Base; the names themselves are not used in this function.
        from models.user import AdminUser  # noqa: F401
        from models.coupon import Coupon  # noqa: F401

        Base.metadata.create_all(bind=engine)
        logger.info("✅ Database tables created/verified successfully")
        created = True
    except Exception as exc:
        logger.error(f"❌ Error creating database tables: {exc}", exc_info=True)
    return created
|
||
|
||
|
||
def create_default_admin(db: Session) -> bool:
    """
    Seed the database with an admin account when it has none.

    If any AdminUser row is already present, nothing is written. Otherwise
    the credentials are read from environment variables:
    - ADMIN_USERNAME (falls back to 'admin')
    - ADMIN_PASSWORD (falls back to 'admin123')

    NOTE(review): the fallback password is a well-known literal; deployments
    should always set ADMIN_PASSWORD explicitly.

    Args:
        db (Session): Database session used for the lookup and the insert.

    Returns:
        bool: True if an admin already existed, was created, or the insert
        hit an IntegrityError; False if a credential is empty or any other
        error occurred (the session is rolled back in both error cases).
    """
    try:
        current_admin = db.query(AdminUser).first()
        if current_admin:
            logger.info(f"ℹ️ Admin user already exists: {current_admin.username}")
            return True

        username = os.getenv("ADMIN_USERNAME", "admin")
        password = os.getenv("ADMIN_PASSWORD", "admin123")

        # Both lookups have non-empty fallbacks, so this only triggers when
        # a variable is explicitly set to an empty string.
        if not (username and password):
            logger.error("❌ ADMIN_USERNAME or ADMIN_PASSWORD not set in environment variables")
            return False

        # Only the hash of the password is stored on the model.
        new_admin = AdminUser(
            username=username,
            password_hash=hash_password(password),
        )
        db.add(new_admin)
        db.commit()
        db.refresh(new_admin)

        logger.info(f"✅ Default admin user created successfully: {username}")
        logger.warning("⚠️ Please change the default admin password in production!")
        return True

    except IntegrityError as exc:
        # Treated as non-critical: the admin row may already exist.
        db.rollback()
        logger.warning(f"⚠️ Admin user might already exist: {exc}")
        return True

    except Exception as exc:
        db.rollback()
        logger.error(f"❌ Error creating default admin user: {exc}", exc_info=True)
        return False
|
||
|
||
|
||
def initialize_database():
    """
    Run the full database setup sequence.

    Steps:
    1. Create the required tables via ``create_tables()``.
    2. Seed the default admin user via ``create_default_admin()``.
    3. Log each stage so startup can be monitored.

    A failure to seed the admin user is only logged as a warning; the
    function still reports success so the application can start.

    Returns:
        bool: True once the sequence has run. This function never returns
        False; a fatal problem is signalled by raising instead.

    Raises:
        Exception: If table creation fails, or if an unexpected error
        escapes the admin-seeding step (logged, then re-raised).
    """
    logger.info("🚀 Starting database initialization...")

    # Stage 1: tables are mandatory, so a failure here aborts startup.
    tables_ready = create_tables()
    if not tables_ready:
        logger.error("❌ Failed to create database tables")
        raise Exception("Database table creation failed")

    # Stage 2: seed the default admin using a dedicated session.
    session = SessionLocal()
    try:
        admin_ready = create_default_admin(session)
        if not admin_ready:
            # Deliberately non-fatal: the app can still run without it.
            logger.warning("⚠️ Failed to create default admin user")

        logger.info("✅ Database initialization completed successfully")
        return True

    except Exception as exc:
        logger.error(f"❌ Database initialization failed: {exc}", exc_info=True)
        raise

    finally:
        session.close()
|
||
|
||
|
||
def verify_database_connection():
    """
    Verify that the database connection is working.

    Opens a session from ``SessionLocal``, executes ``SELECT 1`` and closes
    the session again. The session is closed even when the query fails, so
    a failed check does not leave a session open.

    Returns:
        bool: True if the query ran and the session closed cleanly, False
        if anything raised (the error is logged with its traceback and not
        re-raised).
    """
    try:
        from sqlalchemy import text

        db = SessionLocal()
        try:
            db.execute(text("SELECT 1"))
        finally:
            # Always release the session, including on a failed query.
            db.close()

        logger.info("✅ Database connection verified")
        return True

    except Exception as e:
        logger.error(f"❌ Database connection failed: {e}", exc_info=True)
        return False
|
||
|
||
|
||
def get_admin_stats(db: Session) -> dict:
    """
    Collect simple row counts from the database for logging purposes.

    Args:
        db (Session): Database session used to run the count queries.

    Returns:
        dict: On success, ``admin_users`` and ``total_coupons`` (row counts
        for AdminUser and Coupon) plus ``database_healthy`` set to True.
        If either query raises, ``database_healthy`` set to False and
        ``error`` holding the exception text.
    """
    try:
        # Entries are evaluated in order: admin count first, then coupons.
        stats = {
            "admin_users": db.query(AdminUser).count(),
            "total_coupons": db.query(Coupon).count(),
        }
        stats["database_healthy"] = True
        return stats

    except Exception as exc:
        logger.error(f"Error getting database stats: {exc}")
        return {"database_healthy": False, "error": str(exc)}
|
||
|
||
|
||
if __name__ == "__main__":
|
||
"""
|
||
Allow manual execution for testing purposes.
|
||
|
||
Usage:
|
||
python init_db.py
|
||
"""
|
||
# Setup basic logging for standalone execution
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||
)
|
||
|
||
print("=" * 60)
|
||
print("DATABASE INITIALIZATION SCRIPT")
|
||
print("=" * 60)
|
||
print()
|
||
|
||
# Verify connection
|
||
if not verify_database_connection():
|
||
print("❌ Cannot connect to database. Please check your DATABASE_URL")
|
||
exit(1)
|
||
|
||
# Initialize database
|
||
try:
|
||
initialize_database() # noqa: E722
|
||
|
||
# Show stats
|
||
db = SessionLocal()
|
||
stats = get_admin_stats(db)
|
||
db.close()
|
||
|
||
print()
|
||
print("=" * 60)
|
||
print("DATABASE STATISTICS")
|
||
print("=" * 60)
|
||
print(f"Admin Users: {stats.get('admin_users', 0)}")
|
||
print(f"Total Coupons: {stats.get('total_coupons', 0)}")
|
||
print(f"Status: {'✅ Healthy' if stats.get('database_healthy') else '❌ Unhealthy'}")
|
||
print("=" * 60)
|
||
print()
|
||
print("✅ Database initialization completed successfully!")
|
||
print()
|
||
|
||
except Exception as e:
|
||
print(f"\n❌ Initialization failed: {e}\n")
|
||
exit(1)
|
||
|