# Listing metadata captured with this file: 259 lines, 10 KiB, Python.
import importlib
import os
import time
from contextlib import contextmanager
from unittest.mock import patch, MagicMock

import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
from sqlalchemy.exc import SQLAlchemyError

import main


class TestMainApp:
    """Test cases for main application functionality."""

    # ------------------------------------------------------------------
    # Helpers (not collected by pytest: names do not start with ``test``)
    # ------------------------------------------------------------------

    @staticmethod
    @contextmanager
    def _reloaded_main(env):
        """Reload ``main`` under a temporarily patched environment.

        The environment-specific tests below rely on ``main`` reading its
        configuration from ``os.environ`` when the module is (re)executed.
        The ``main`` module object is shared by the whole test session. If it
        were not reloaded a second time after the environment is restored,
        the patched configuration (e.g. DEBUG docs URLs, or
        ``ENVIRONMENT=production`` with ``TrustedHostMiddleware``) would leak
        into every later test that reads ``main.app`` or ``main.AppConfig``.

        Args:
            env: Environment variables overlaid on ``os.environ`` while the
                reloaded module is in use.

        Yields:
            The reloaded ``main`` module.
        """
        try:
            with patch.dict(os.environ, env):
                importlib.reload(main)
                yield main
        finally:
            # patch.dict has already restored os.environ at this point, so
            # this reload rebuilds ``main`` with the original configuration.
            importlib.reload(main)

    @staticmethod
    @contextmanager
    def _temporary_routes(app):
        """Restore ``app``'s route table when the block exits.

        The exception-handler tests register throw-away endpoints on the
        shared application. Snapshotting and restoring the route list keeps
        those endpoints from leaking into other tests.

        Args:
            app: The FastAPI application behind the test client.

        Yields:
            ``app`` itself, so endpoints can be registered on it.
        """
        saved_routes = list(app.router.routes)
        try:
            yield app
        finally:
            # Slice-assign so the router keeps referring to the same list object.
            app.router.routes[:] = saved_routes

    # ------------------------------------------------------------------
    # Endpoints
    # ------------------------------------------------------------------

    def test_root_endpoint(self, client):
        """Test root endpoint returns correct information"""
        response = client.get("/")
        assert response.status_code == 200
        # The auth router overrides the main app's root endpoint, so we get HTML
        assert "text/html" in response.headers["content-type"]
        # Check that it's the admin dashboard or login page
        content = response.text.lower()
        assert "admin" in content or "login" in content

    def test_health_check_success(self, client, test_db):
        """Test health check endpoint when database is connected"""
        response = client.get("/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
        assert "timestamp" in data
        assert "version" in data
        assert "environment" in data
        assert data["database_status"] == "connected"

    @patch('utils.auth.get_db')
    def test_health_check_database_failure(self, mock_get_db, client):
        """Test health check endpoint when database is disconnected"""
        # NOTE(review): patching ``utils.auth.get_db`` only takes effect if the
        # /health handler looks ``get_db`` up through that module at call
        # time. If it is injected via ``Depends(get_db)``, FastAPI already holds
        # the original function and ``app.dependency_overrides`` would be
        # needed instead. Confirm against main.py.
        mock_db = MagicMock()
        mock_db.execute.side_effect = SQLAlchemyError("Database connection failed")
        mock_get_db.return_value = iter([mock_db])

        response = client.get("/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "unhealthy"
        assert data["database_status"] == "disconnected"

    # ------------------------------------------------------------------
    # Middleware
    # ------------------------------------------------------------------

    def test_middleware_process_time_header(self, client):
        """Test that middleware adds process time header"""
        response = client.get("/health")
        assert "X-Process-Time" in response.headers
        assert "X-Request-ID" in response.headers
        process_time = float(response.headers["X-Process-Time"])
        assert process_time >= 0

    def test_middleware_request_id(self, client):
        """Test that middleware generates unique request IDs"""
        response1 = client.get("/health")
        response2 = client.get("/health")

        request_id1 = response1.headers["X-Request-ID"]
        request_id2 = response2.headers["X-Request-ID"]

        assert request_id1 != request_id2
        assert request_id1.isdigit()
        assert request_id2.isdigit()

    # ------------------------------------------------------------------
    # Exception handlers
    # ------------------------------------------------------------------

    def test_api_exception_handler(self, client):
        """Test custom API exception handler"""
        from utils.exceptions import APIException

        with self._temporary_routes(client.app) as app:

            # Test-only endpoint that raises APIException
            @app.get("/test-api-exception")
            def raise_api_exception():
                raise APIException(
                    status_code=400,
                    detail="Test API exception",
                    error_code="TEST_ERROR",
                )

            response = client.get("/test-api-exception")

        assert response.status_code == 400
        data = response.json()
        assert data["success"] is False
        assert data["error"] == "Test API exception"
        assert data["error_code"] == "TEST_ERROR"
        assert "timestamp" in data
        assert "path" in data

    def test_validation_exception_handler(self, client):
        """Test validation exception handler"""
        from pydantic import BaseModel

        class Payload(BaseModel):
            required_field: str

        with self._temporary_routes(client.app) as app:

            # Test-only endpoint whose body fails validation when empty
            @app.post("/test-validation")
            def validate_payload(payload: Payload):
                return {"message": "success"}

            response = client.post("/test-validation", json={})

        assert response.status_code == 422
        data = response.json()
        assert data["success"] is False
        assert data["error"] == "Validation Error"
        assert data["error_code"] == "VALIDATION_ERROR"
        assert "details" in data

    def test_http_exception_handler(self, client):
        """Test HTTP exception handler"""
        with self._temporary_routes(client.app) as app:

            # Test-only endpoint that raises a plain HTTPException
            @app.get("/test-http-exception")
            def raise_http_exception():
                raise HTTPException(status_code=404, detail="Not found")

            response = client.get("/test-http-exception")

        assert response.status_code == 404
        data = response.json()
        assert data["success"] is False
        assert data["error"] == "HTTP Error"
        assert data["detail"] == "Not found"

    def test_generic_exception_handler(self, client):
        """Test generic exception handler"""
        # Test that the exception handler is properly registered
        # by checking if it exists in the app's exception handlers
        assert Exception in client.app.exception_handlers
        assert client.app.exception_handlers[Exception] is not None

        # Test that the handler function exists and is callable
        handler = client.app.exception_handlers[Exception]
        assert callable(handler)

        # Test that the handler has the expected signature
        import inspect
        sig = inspect.signature(handler)
        assert len(sig.parameters) == 2  # request and exc parameters

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------

    def test_app_config_environment_variables(self):
        """Test application configuration with environment variables"""
        env = {
            'APP_NAME': 'Test App',
            'APP_VERSION': '1.0.0',
            'DEBUG': 'true',
            'ENVIRONMENT': 'test',
            'CORS_ORIGINS': 'http://localhost:3000,http://localhost:8080',
            'TRUSTED_HOSTS': 'localhost,test.com',
        }
        with self._reloaded_main(env) as reloaded:
            config = reloaded.AppConfig
            assert config.APP_NAME == "Test App"
            assert config.VERSION == "1.0.0"
            assert config.DEBUG is True
            assert config.ENVIRONMENT == "test"
            assert "http://localhost:3000" in config.CORS_ORIGINS
            assert "http://localhost:8080" in config.CORS_ORIGINS
            assert "localhost" in config.TRUSTED_HOSTS
            assert "test.com" in config.TRUSTED_HOSTS

    def test_app_config_defaults(self):
        """Test that AppConfig exposes the expected settings with the expected types.

        Concrete default values are not asserted because the test
        configuration may already have set some of them (e.g. ENVIRONMENT).
        """
        assert hasattr(main.AppConfig, 'CORS_ORIGINS')
        assert hasattr(main.AppConfig, 'TRUSTED_HOSTS')

        # Test that the AppConfig class has the expected attributes
        assert hasattr(main.AppConfig, 'ENVIRONMENT')
        assert hasattr(main.AppConfig, 'DEBUG')
        assert hasattr(main.AppConfig, 'APP_NAME')
        assert hasattr(main.AppConfig, 'VERSION')

        # Test that the values are of the expected types
        assert isinstance(main.AppConfig.CORS_ORIGINS, list)
        assert isinstance(main.AppConfig.TRUSTED_HOSTS, list)
        assert isinstance(main.AppConfig.ENVIRONMENT, str)
        assert isinstance(main.AppConfig.DEBUG, bool)

    # ------------------------------------------------------------------
    # Lifespan and startup helpers
    # ------------------------------------------------------------------

    @patch('main.ensure_directories')
    @patch('main.AdminUser.__table__.create')
    @patch('main.Coupon.__table__.create')
    @pytest.mark.asyncio
    async def test_lifespan_startup_success(self, mock_coupon_create, mock_user_create, mock_ensure_dirs):
        """Test application lifespan startup success"""
        from main import lifespan

        mock_app = MagicMock()

        # Startup work has completed by the time the context body runs
        async with lifespan(mock_app):
            mock_ensure_dirs.assert_called_once()
            mock_user_create.assert_called_once()
            mock_coupon_create.assert_called_once()

    @patch('main.ensure_directories')
    @patch('main.AdminUser.__table__.create')
    @pytest.mark.asyncio
    async def test_lifespan_startup_failure(self, mock_user_create, mock_ensure_dirs):
        """Test application lifespan startup failure"""
        from main import lifespan

        mock_app = MagicMock()
        mock_user_create.side_effect = Exception("Database error")

        # Test startup failure
        with pytest.raises(Exception, match="Database error"):
            async with lifespan(mock_app):
                pass

    @patch('os.makedirs')
    def test_ensure_directories(self, mock_makedirs):
        """Test ensure_directories function"""
        from main import ensure_directories

        ensure_directories()

        # Should be called twice for translation_upload and logs
        assert mock_makedirs.call_count == 2
        mock_makedirs.assert_any_call("translation_upload", exist_ok=True)
        mock_makedirs.assert_any_call("logs", exist_ok=True)

    # ------------------------------------------------------------------
    # Environment-dependent app construction
    # ------------------------------------------------------------------

    def test_app_creation_with_debug(self):
        """Test FastAPI app creation with debug mode"""
        with self._reloaded_main({'DEBUG': 'true'}) as reloaded:
            # Check if docs are enabled in debug mode
            assert reloaded.app.docs_url == "/docs"
            assert reloaded.app.redoc_url == "/redoc"

    def test_app_creation_without_debug(self):
        """Test FastAPI app creation without debug mode"""
        with self._reloaded_main({'DEBUG': 'false'}) as reloaded:
            # Check if docs are disabled in non-debug mode
            assert reloaded.app.docs_url is None
            assert reloaded.app.redoc_url is None

    def test_production_middleware(self):
        """Test production middleware configuration"""
        from fastapi.middleware.trustedhost import TrustedHostMiddleware

        with self._reloaded_main({'ENVIRONMENT': 'production'}) as reloaded:
            # user_middleware entries wrap the middleware class in ``.cls``
            assert any(
                isinstance(middleware.cls, type)
                and issubclass(middleware.cls, TrustedHostMiddleware)
                for middleware in reloaded.app.user_middleware
            )