Spaces:
Paused
Paused
| """ | |
| Database initialization for the application. | |
| This script checks if the database is initialized and creates tables if needed. | |
| It's meant to be imported and run at application startup. | |
| """ | |
| import os | |
| import logging | |
| import asyncio | |
| from sqlalchemy.ext.asyncio import create_async_engine | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from sqlalchemy.orm import sessionmaker | |
| from sqlalchemy.future import select | |
| import subprocess | |
| import sys | |
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | |
| ) | |
| logger = logging.getLogger(__name__) | |
| # Database URL from environment | |
| db_url = os.getenv("DATABASE_URL", "") | |
| if db_url.startswith("postgresql://"): | |
| # Remove sslmode parameter if present which causes issues with asyncpg | |
| if "?" in db_url: | |
| base_url, params = db_url.split("?", 1) | |
| param_list = params.split("&") | |
| filtered_params = [p for p in param_list if not p.startswith("sslmode=")] | |
| if filtered_params: | |
| db_url = f"{base_url}?{'&'.join(filtered_params)}" | |
| else: | |
| db_url = base_url | |
| ASYNC_DATABASE_URL = db_url.replace("postgresql://", "postgresql+asyncpg://", 1) | |
| else: | |
| ASYNC_DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost:5432/postgres" | |
| async def check_db_initialized(): | |
| """Check if the database is initialized with required tables.""" | |
| try: | |
| engine = create_async_engine( | |
| ASYNC_DATABASE_URL, | |
| echo=False, | |
| ) | |
| # Create session factory | |
| async_session = sessionmaker( | |
| engine, | |
| class_=AsyncSession, | |
| expire_on_commit=False | |
| ) | |
| async with async_session() as session: | |
| # Try to query tables | |
| # Replace with actual table names once you've defined them | |
| try: | |
| # Check if the 'users' table exists | |
| from sqlalchemy import text | |
| query = text("SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'users')") | |
| result = await session.execute(query) | |
| exists = result.scalar() | |
| if exists: | |
| logger.info("Database is initialized.") | |
| return True | |
| else: | |
| logger.warning("Database tables are not initialized.") | |
| return False | |
| except Exception as e: | |
| logger.error(f"Error checking tables: {e}") | |
| return False | |
| except Exception as e: | |
| logger.error(f"Failed to connect to database: {e}") | |
| return False | |
| def initialize_database(): | |
| """Initialize the database with required tables.""" | |
| try: | |
| # Call the init_db.py script | |
| logger.info("Initializing database...") | |
| # Get the current directory | |
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |
| script_path = os.path.join(current_dir, "scripts", "init_db.py") | |
| # Run the script using the current Python interpreter | |
| result = subprocess.run([sys.executable, script_path], capture_output=True, text=True) | |
| if result.returncode == 0: | |
| logger.info("Database initialized successfully.") | |
| logger.debug(result.stdout) | |
| return True | |
| else: | |
| logger.error(f"Failed to initialize database: {result.stderr}") | |
| return False | |
| except Exception as e: | |
| logger.error(f"Error initializing database: {e}") | |
| return False | |
| def ensure_database_initialized(): | |
| """Ensure the database is initialized with required tables.""" | |
| is_initialized = asyncio.run(check_db_initialized()) | |
| if not is_initialized: | |
| return initialize_database() | |
| return True |