216 lines
8.5 KiB
Python
216 lines
8.5 KiB
Python
"""Database connection and session management."""
|
|
|
|
from collections.abc import Generator
|
|
from pathlib import Path
|
|
from urllib.parse import urlparse
|
|
|
|
from alembic import command
|
|
from alembic.config import Config
|
|
from sqlalchemy import create_engine, text
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
try:
|
|
from .config import settings
|
|
from .models import Base
|
|
except ImportError:
|
|
from config import settings
|
|
from models import Base
|
|
|
|
|
|
def get_database_runtime_summary() -> dict[str, str]:
    """Describe the database backend that the current settings resolve to.

    Returns:
        A dict with three string entries:

        * ``backend`` -- ``"sqlite"`` in SQLite mode, otherwise the URL scheme
          with any ``+driver`` suffix removed (``"postgresql"`` if empty).
        * ``target`` -- the absolute SQLite file path, or ``host:port/name``
          (port defaults to 5432, host to ``"unknown-host"``).
        * ``database`` -- the SQLite file path, or the database name from the
          URL path (``"unknown"`` if absent).
    """
    if settings.use_sqlite:
        raw_path = settings.SQLITE_DB_PATH or "/tmp/ai_software_factory_test.db"
        resolved = str(Path(raw_path).expanduser().resolve())
        return {"backend": "sqlite", "target": resolved, "database": resolved}

    url_parts = urlparse(settings.database_url)
    name = url_parts.path.lstrip("/") or "unknown"
    hostname = url_parts.hostname or "unknown-host"
    port_text = str(url_parts.port or 5432)
    backend = url_parts.scheme.partition("+")[0] or "postgresql"
    return {
        "backend": backend,
        "target": f"{hostname}:{port_text}/{name}",
        "database": name,
    }
|
|
|
|
|
|
# Pooled (non-SQLite) engines keyed by their full configuration, so repeated
# get_engine() calls share one connection pool instead of building a new one
# each time.
_POSTGRES_ENGINES: dict[tuple, Engine] = {}


def get_engine() -> Engine:
    """Return a SQLAlchemy engine for the configured database.

    SQLite mode (used for tests): a new engine is created on each call for the
    expanded, absolute ``SQLITE_DB_PATH`` (default
    ``/tmp/ai_software_factory_test.db``). The file's parent directory is
    created if missing. No explicit pool options are passed.

    Otherwise: ``settings.database_url`` is used with the configured pool
    options (defaults: size 10, overflow 20, timeout 30s). The engine is cached
    per configuration so its connection pool is actually reused across calls.

    SQL echo is enabled when ``settings.LOG_LEVEL == "DEBUG"``.
    """
    echo = settings.LOG_LEVEL == "DEBUG"

    if settings.use_sqlite:
        db_path = Path(
            settings.SQLITE_DB_PATH or "/tmp/ai_software_factory_test.db"
        ).expanduser().resolve()
        db_path.parent.mkdir(parents=True, exist_ok=True)
        # Build the URL from the resolved path so "~" or relative paths point
        # at the same file whose directory was just created. SQLite engines
        # stay uncached (as before): a long-lived pool could keep connections
        # to a database file that has since been deleted or replaced.
        return create_engine(
            f"sqlite:///{db_path}",
            # Allow the connection to be used from threads other than its creator.
            connect_args={"check_same_thread": False},
            echo=echo,
        )

    db_url = settings.database_url
    pool_size = settings.DB_POOL_SIZE or 10
    max_overflow = settings.DB_MAX_OVERFLOW or 20
    pool_timeout = settings.DB_POOL_TIMEOUT or 30
    cache_key = (db_url, pool_size, max_overflow, pool_timeout, echo)

    engine = _POSTGRES_ENGINES.get(cache_key)
    if engine is None:
        engine = create_engine(
            db_url,
            pool_size=pool_size,
            max_overflow=max_overflow,
            # Always validate pooled connections on checkout. This used to be
            # tied to DEBUG logging, which left production pools unchecked.
            pool_pre_ping=True,
            echo=echo,
            pool_timeout=pool_timeout,
        )
        _POSTGRES_ENGINES[cache_key] = engine
    return engine
|
|
|
|
|
|
def get_session() -> Generator[Session, None, None]:
    """Yield a session bound to ``get_engine()`` inside a transactional scope.

    When the consumer resumes the generator after the ``yield``, the session
    is committed. Any ``Exception`` thrown into the generator rolls the
    session back and is re-raised. The session is closed in every case,
    including when the generator is closed without being resumed (then no
    commit happens).
    """
    session_factory = sessionmaker(bind=get_engine(), autoflush=False, autocommit=False)
    db = session_factory()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
|
|
|
|
|
|
def get_db() -> Generator[Session, None, None]:
    """FastAPI dependency that yields a session managed by ``get_session``.

    ``yield from`` delegates ``throw()``/``close()`` to the inner generator,
    so exceptions raised in the route reach ``get_session``'s rollback.
    """
    managed_scope = get_session()
    yield from managed_scope
|
|
|
|
|
|
def get_db_sync() -> Session:
    """Open an unmanaged session bound to ``get_engine()``.

    Intended for code outside FastAPI/NiceGUI dependency injection. The caller
    owns the session and is responsible for committing and closing it.
    """
    return sessionmaker(bind=get_engine(), autoflush=False, autocommit=False)()
|
|
|
|
|
|
def get_db_session() -> Session:
    """Open a database session directly (for non-FastAPI usage).

    Returns:
        A new session bound to ``get_engine()`` with autoflush disabled. The
        caller owns it and must commit and close it.
    """
    # Not ``next(get_session())``: that leaves the generator suspended at its
    # ``yield``. Its commit never runs, and its ``finally`` closes the session
    # as soon as the discarded generator is finalized.
    engine = get_engine()
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    return SessionLocal()
|
|
|
|
|
|
def get_alembic_config(database_url: str | None = None) -> Config:
    """Return an Alembic config bound to the active database URL.

    Args:
        database_url: URL to migrate. Defaults to ``settings.database_url``.

    Returns:
        A ``Config`` loaded from this package's ``alembic.ini``, with
        ``script_location`` set to the package's ``alembic`` directory and
        ``sqlalchemy.url`` set to the chosen URL.
    """
    package_root = Path(__file__).resolve().parent
    config = Config(str(package_root / "alembic.ini"))
    url = database_url or settings.database_url
    # set_main_option() passes values through ConfigParser interpolation, so a
    # literal "%" (e.g. a percent-encoded password, or a "%" in the install
    # path) must be escaped as "%%". Reading the option back yields the
    # original value.
    config.set_main_option(
        "script_location", str(package_root / "alembic").replace("%", "%%")
    )
    config.set_main_option("sqlalchemy.url", url.replace("%", "%%"))
    return config
|
|
|
|
|
|
def run_migrations(database_url: str | None = None) -> dict:
    """Upgrade the database schema to Alembic's ``head`` revision.

    Args:
        database_url: Optional URL override passed to ``get_alembic_config``.

    Returns:
        ``{"status": "success", ...}`` on success. Otherwise
        ``{"status": "error", "message": <exception text>}``: any
        ``Exception`` raised while configuring or upgrading is caught rather
        than propagated.
    """
    try:
        command.upgrade(get_alembic_config(database_url), "head")
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
    return {"status": "success", "message": "Database migrations applied."}
|
|
|
|
|
|
def init_db() -> dict:
    """Initialize database tables and database if needed.

    SQLite:
        Apply Alembic migrations. If they fail, create the tables from
        ``Base.metadata``.

    PostgreSQL:
        1. Best-effort: create the target database if it is missing.
        2. Apply migrations.
        3. If migrations fail, create the tables from ``Base.metadata`` in the
           target database.

    Returns:
        A dict ``{'status': 'success' | 'error', 'message': str}``.
    """
    if settings.use_sqlite:
        return _init_sqlite_db()
    return _init_postgres_db()


def _init_sqlite_db() -> dict:
    """Migrate the SQLite database, falling back to ``create_all``."""
    result = run_migrations()
    if result["status"] == "success":
        print("SQLite database migrations applied successfully.")
        return {"status": "success", "message": "SQLite database initialized via migrations."}

    print(f"SQLite migrations failed ({result['message']}); falling back to metadata.")
    engine = get_engine()
    try:
        Base.metadata.create_all(bind=engine)
        print("SQLite database tables created successfully.")
        return {"status": "success", "message": "SQLite database initialized with metadata fallback."}
    except Exception as e:
        print(f"Error initializing SQLite database: {str(e)}")
        return {'status': 'error', 'message': f'Error: {str(e)}'}
    finally:
        # Release the pool opened just for this initialization.
        engine.dispose()


def _ensure_postgres_database(db_url: str) -> None:
    """Create the database named in *db_url* if it does not exist yet.

    Connects to the ``postgres`` maintenance database on the same server in
    AUTOCOMMIT mode, because ``CREATE DATABASE`` cannot run inside a
    transaction block. Exceptions (unreachable server, missing privileges)
    propagate to the caller.
    """
    from sqlalchemy.engine import make_url

    url = make_url(db_url)
    db_name = url.database
    if not db_name:
        return  # no explicit database in the URL; nothing to create

    admin_engine = create_engine(url.set(database="postgres"), isolation_level="AUTOCOMMIT")
    try:
        with admin_engine.connect() as conn:
            exists = conn.execute(
                text("SELECT 1 FROM pg_database WHERE datname = :name"),
                {"name": db_name},
            ).scalar() is not None
            if exists:
                print(f"PostgreSQL database '{db_name}' already exists.")
                return
            # Identifiers cannot be bound parameters; quote via the dialect.
            quoted_name = conn.dialect.identifier_preparer.quote(db_name)
            conn.execute(text(f"CREATE DATABASE {quoted_name}"))
            print(f"PostgreSQL database '{db_name}' created.")
    finally:
        admin_engine.dispose()


def _init_postgres_db() -> dict:
    """Ensure the PostgreSQL database exists, then migrate (or ``create_all``)."""
    db_url = settings.database_url
    # Use the URL path only, so a query string such as "?sslmode=require" is
    # not treated as part of the database name.
    db_name = urlparse(db_url).path.lstrip("/") or "ai_software_factory"

    try:
        try:
            _ensure_postgres_database(db_url)
        except Exception as db_error:
            # Best-effort: the database may already exist even if this role
            # cannot reach the maintenance database or create databases.
            print(f"Could not verify or create database: {str(db_error)}")

        migration_result = run_migrations(db_url)
        if migration_result["status"] == "success":
            print(f"PostgreSQL database '{db_name}' migrations applied successfully.")
            return {'status': 'success', 'message': f'PostgreSQL database "{db_name}" initialized via migrations.'}

        print(f"PostgreSQL migrations failed ({migration_result['message']}); falling back to metadata.")
        # Fallback runs against the target database, not the maintenance one.
        engine = create_engine(db_url)
        try:
            Base.metadata.create_all(bind=engine)
        finally:
            engine.dispose()
        print(f"PostgreSQL database '{db_name}' tables created successfully.")
        return {'status': 'success', 'message': f'PostgreSQL database "{db_name}" initialized with metadata fallback.'}

    except Exception as e:
        print(f"Error initializing PostgreSQL database: {str(e)}")
        return {'status': 'error', 'message': f'Error: {str(e)}'}
|
|
|
|
|
|
def drop_db() -> dict:
    """Drop all tables known to ``Base.metadata`` (use with caution!).

    Returns:
        ``{'status': 'success', 'message': ...}``, or
        ``{'status': 'error', 'message': <exception text>}``. Errors are
        printed and reported in the dict rather than raised.
    """
    if settings.use_sqlite:
        engine = get_engine()
        try:
            Base.metadata.drop_all(bind=engine)
            print("SQLite database tables dropped successfully.")
            return {'status': 'success', 'message': 'SQLite tables dropped.'}
        except Exception as e:
            print(f"Error dropping SQLite tables: {str(e)}")
            return {'status': 'error', 'message': str(e)}
        finally:
            # Release the pool opened just for this operation.
            engine.dispose()

    db_url = settings.database_url
    # Use the URL path only, so a query string such as "?sslmode=require" is
    # not reported as part of the database name.
    db_name = urlparse(db_url).path.lstrip("/") or 'ai_software_factory'

    try:
        engine = create_engine(db_url)
        try:
            Base.metadata.drop_all(bind=engine)
        finally:
            engine.dispose()
        print(f"PostgreSQL database '{db_name}' tables dropped successfully.")
        return {'status': 'success', 'message': f'PostgreSQL "{db_name}" tables dropped.'}
    except Exception as e:
        print(f"Error dropping PostgreSQL tables: {str(e)}")
        return {'status': 'error', 'message': str(e)}
|
|
|
|
|
|
def create_migration_script() -> str:
    """Return a pointer to the Alembic-managed migration scripts.

    Nothing is generated here. The returned message names the directory that
    holds the Alembic revision files.
    """
    message = "See ai_software_factory/alembic/versions for managed schema migrations."
    return message