Files
ai-software-factory/ai_software_factory/database.py

126 lines
4.5 KiB
Python

"""Database connection and session management."""
from typing import Callable, Iterator

from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, Session

from config import settings
from models import Base
def get_engine() -> Engine:
    """Create and return a new SQLAlchemy engine for the configured database.

    When ``settings.USE_SQLITE`` is truthy, a SQLite engine is built for the
    file at ``settings.SQLITE_DB_PATH`` (falling back to
    ``/tmp/ai_software_factory_test.db``) with SQLAlchemy's default pooling.

    Otherwise a pooled engine is built from ``settings.POSTGRES_URL`` (falling
    back to ``settings.database_url``). Pool size, overflow and timeout come
    from settings, with defaults of 10, 20 and 30. Checkout/checkin listeners
    are attached that print pool activity:
    - checkouts are printed at ``LOG_LEVEL`` DEBUG or INFO;
    - checkins are printed at DEBUG only.

    SQL echoing is enabled when ``settings.LOG_LEVEL == "DEBUG"``.

    Note: every call creates a brand-new engine with its own connection pool;
    nothing is cached here.

    Returns:
        A newly created SQLAlchemy ``Engine``.
    """
    debug = settings.LOG_LEVEL == "DEBUG"

    if settings.USE_SQLITE:
        db_path = settings.SQLITE_DB_PATH or "/tmp/ai_software_factory_test.db"
        # check_same_thread=False allows the sqlite3 connection to be used from
        # threads other than the one that created it.
        return create_engine(
            f"sqlite:///{db_path}",
            connect_args={"check_same_thread": False},
            echo=debug,
        )

    db_url = settings.POSTGRES_URL or settings.database_url
    engine = create_engine(
        db_url,
        pool_size=settings.DB_POOL_SIZE or 10,
        max_overflow=settings.DB_MAX_OVERFLOW or 20,
        # Pre-ping checks each pooled connection on checkout and transparently
        # replaces dead ones (e.g. after a database restart). This is a
        # liveness check, unrelated to log verbosity, so it is always enabled
        # rather than only when LOG_LEVEL is DEBUG.
        pool_pre_ping=True,
        echo=debug,
        pool_timeout=settings.DB_POOL_TIMEOUT or 30,
    )

    # Pool audit listeners (PostgreSQL only). LOG_LEVEL is read at event time,
    # so a runtime change of the level takes effect without a new engine.
    @event.listens_for(engine, "checkout")
    def receive_checkout(dbapi_connection, connection_record, connection_proxy):
        """Print a notice when a connection is checked out of the pool."""
        if settings.LOG_LEVEL in ("DEBUG", "INFO"):
            print("DB Connection checked out from pool")

    @event.listens_for(engine, "checkin")
    def receive_checkin(dbapi_connection, connection_record):
        """Print a notice when a connection is returned to the pool."""
        if settings.LOG_LEVEL == "DEBUG":
            print("DB Connection returned to pool")

    return engine
def get_session() -> Callable[[], Iterator[Session]]:
    """Build a generator function that provides one transactional session.

    Each call creates a fresh engine via ``get_engine()`` and a
    ``sessionmaker`` bound to it. It then returns ``session_factory``: a
    zero-argument *generator function*, not a ``Session`` and not an iterator.
    The caller must call it and then drive the resulting generator.

    Driving that generator:
    - yields exactly one new ``Session``;
    - commits the session when the generator is resumed normally after the
      yield;
    - rolls back and re-raises if an ``Exception`` is thrown in at the yield
      (or if the commit itself fails);
    - always closes the session.

    NOTE: if the generator is closed without being resumed, ``GeneratorExit``
    (not an ``Exception`` subclass) is raised at the yield. In that case
    neither commit nor rollback runs; only ``close`` does.

    Returns:
        The ``session_factory`` generator function.
    """
    engine = get_engine()
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

    def session_factory() -> Iterator[Session]:
        """Yield one session, committing on success and rolling back on error."""
        session = SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            # Undo the partial transaction, then let the caller see the error.
            session.rollback()
            raise
        finally:
            session.close()

    return session_factory
def init_db() -> None:
    """Create the tables registered on ``Base.metadata`` using a new engine.

    Prints a confirmation line once ``create_all`` returns.
    """
    metadata = Base.metadata
    metadata.create_all(bind=get_engine())
    print("Database tables created successfully.")
def drop_db() -> None:
    """Drop every table registered on ``Base.metadata`` (destructive!).

    Uses a new engine from ``get_engine()`` and prints a confirmation line
    once ``drop_all`` returns.
    """
    metadata = Base.metadata
    metadata.drop_all(bind=get_engine())
    print("Database tables dropped successfully.")
def get_db() -> Iterator[Session]:
    """Yield a database session for a FastAPI route, then clean up.

    A new engine and session are created for each call. The session is
    always closed after the consumer finishes. No commit or rollback is
    performed here, so the route is responsible for committing its own work.

    Yields:
        An open SQLAlchemy ``Session``.
    """
    engine = get_engine()
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
        # The engine was created just for this call, so nothing else shares
        # its pool. Dispose it so the pooled DB connection is closed now,
        # instead of lingering until the engine is garbage-collected.
        # If get_engine() is ever changed to return a shared/cached engine,
        # this dispose() must be removed.
        engine.dispose()
def get_db_session() -> Session:
    """Return a new, open database session for use outside FastAPI.

    The caller owns the returned session. Nothing is committed, rolled back,
    or closed automatically, so call ``session.close()`` when done.

    Returns:
        An open SQLAlchemy ``Session`` bound to a new engine.
    """
    # Build the session directly rather than pulling it from get_session().
    # That function returns a generator *function*, not an iterator, so
    # ``next(get_session())`` raises TypeError. Even ``next(get_session()())``
    # would leave an abandoned generator; when it is finalized, its
    # ``finally`` closes the session out from under the caller.
    engine = get_engine()
    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    return SessionLocal()
def create_migration_script() -> str:
    """Return a fixed placeholder SQL migration script as one string.

    The script consists of ``--`` comment lines plus ``CREATE INDEX IF NOT
    EXISTS`` statements for the ``audit_trail``, ``user_actions``,
    ``project_logs`` and ``system_logs`` tables. Every line, including the
    last, is newline-terminated. The SQL is only returned, never executed.
    """
    statements = (
        "-- Migration script for AI Software Factory database",
        "-- Generated automatically - review before applying",
        "-- Add new columns to existing tables if needed",
        "-- This is a placeholder for future migrations",
        "-- Example: Add audit_trail_index for better query performance",
        "CREATE INDEX IF NOT EXISTS idx_audit_trail_timestamp ON audit_trail(timestamp);",
        "CREATE INDEX IF NOT EXISTS idx_audit_trail_action ON audit_trail(action);",
        "CREATE INDEX IF NOT EXISTS idx_audit_trail_project ON audit_trail(project_id);",
        "-- Example: Add user_actions_index for better query performance",
        "CREATE INDEX IF NOT EXISTS idx_user_actions_timestamp ON user_actions(timestamp);",
        "CREATE INDEX IF NOT EXISTS idx_user_actions_actor ON user_actions(actor_type, actor_name);",
        "CREATE INDEX IF NOT EXISTS idx_user_actions_history ON user_actions(history_id);",
        "-- Example: Add project_logs_index for better query performance",
        "CREATE INDEX IF NOT EXISTS idx_project_logs_timestamp ON project_logs(timestamp);",
        "CREATE INDEX IF NOT EXISTS idx_project_logs_level ON project_logs(log_level);",
        "-- Example: Add system_logs_index for better query performance",
        "CREATE INDEX IF NOT EXISTS idx_system_logs_timestamp ON system_logs(timestamp);",
        "CREATE INDEX IF NOT EXISTS idx_system_logs_component ON system_logs(component);",
    )
    # Terminate each line (including the final one) with a newline.
    return "".join(f"{line}\n" for line in statements)