import logging
import math
from typing import Any, List, NoReturn, Optional, Tuple

from fastapi import HTTPException
from sqlalchemy.orm import Session

from app.common.models import User

# Module-level logger named after this module, so its output can be filtered per module.
logger = logging.getLogger(__name__)

class BaseRepo:
    """Base repository providing shared query/create helpers for all repos.

    Every helper takes the SQLAlchemy model class to operate on as an argument
    and converts unexpected database failures into an HTTP 500 via
    :meth:`handle_db_error`.
    """

    def __init__(self, db: Session, current_user=None):
        """
        Args:
            db: SQLAlchemy session used by every helper in this class.
            current_user: Optional user context. It is stored but not read by
                any helper in this class; presumably consumed by subclasses
                (TODO confirm).
        """
        self.db = db
        self.current_user = current_user

    def handle_db_error(self, error: Exception, operation: str) -> NoReturn:
        """
        Log a database error and re-raise it as an HTTP 500.

        Args:
            error: The exception that was raised.
            operation: Human-readable description of the failed operation;
                it appears in both the log line and the HTTP error detail.

        Raises:
            HTTPException: Always, with status 500. ``error`` is chained as the
                cause so the original traceback is preserved.
        """
        # exc_info keeps the full traceback in the logs, while the client only
        # sees the generic detail message below (no internals leaked).
        logger.error("Database error during %s: %s", operation, error, exc_info=error)
        raise HTTPException(
            status_code=500,
            detail=f"Database error occurred during {operation}",
        ) from error

    def get_by_field(self, model_class: Any, field: str, value: Any) -> Optional[Any]:
        """
        Return the first ``model_class`` entity whose ``field`` equals ``value``.

        Args:
            model_class: The SQLAlchemy model class to query.
            field: Name of the model attribute to filter on.
            value: Value the attribute must equal.

        Returns:
            The first matching entity, or None if nothing matches.

        Raises:
            HTTPException: 500 if the attribute lookup or the query fails.
        """
        try:
            column = getattr(model_class, field)
            return self.db.query(model_class).filter(column == value).first()
        except Exception as e:
            self.handle_db_error(e, f"getting {model_class.__name__} by {field}")

    def get_all_paginated(
        self,
        model_class: Any,
        page: int,
        page_size: int,
        order_by: str = "created_at",
        order_direction: str = "desc",
        filter_condition=None,
    ) -> Tuple[List[Any], int]:
        """
        Generic method to get all entities with pagination.

        Args:
            model_class: The SQLAlchemy model class.
            page: Page number (1-based, must be >= 1).
            page_size: Number of items per page (must be >= 1).
            order_by: Name of the model attribute to order by.
            order_direction: 'desc' (case-insensitive) for descending order;
                any other value sorts ascending.
            filter_condition: Optional SQLAlchemy filter condition.

        Returns:
            A tuple ``(entities, total_pages)``. ``entities`` holds the rows for
            the requested page. ``total_pages`` is
            ``ceil(matching_rows / page_size)``, so it is 0 when nothing matches.

        Raises:
            HTTPException: 400 for an invalid ``page``/``page_size`` or an
                unknown ``order_by`` attribute; 500 if the query fails.
        """
        # Validate before touching the DB: page_size == 0 used to divide by
        # zero, and page < 1 produced a negative OFFSET.
        if page < 1:
            raise HTTPException(status_code=400, detail="page must be >= 1")
        if page_size < 1:
            raise HTTPException(status_code=400, detail="page_size must be >= 1")

        column = getattr(model_class, order_by, None)
        if column is None:
            raise HTTPException(
                status_code=400,
                detail=f"Cannot order {model_class.__name__} by '{order_by}'",
            )

        try:
            # str() keeps the old behavior for non-string input (e.g. None -> asc).
            descending = str(order_direction).lower() == "desc"
            ordering = column.desc() if descending else column.asc()

            query = self.db.query(model_class)
            if filter_condition is not None:
                query = query.filter(filter_condition)

            # Count the filtered (unordered, unpaginated) query; the ORDER BY
            # is irrelevant for counting.
            total_entities = query.count()
            entities = (
                query.order_by(ordering)
                .offset((page - 1) * page_size)
                .limit(page_size)
                .all()
            )
        except Exception as e:
            self.handle_db_error(e, f"listing {model_class.__name__}")

        total_pages = math.ceil(total_entities / page_size)
        return entities, total_pages

    def get_by_id(self, model_class: Any, id: int) -> Optional[Any]:
        """
        Generic method to get an entity by ID.

        Args:
            model_class: The SQLAlchemy model class.
            id: The ID to look up.

        Returns:
            The found entity or None.

        Raises:
            HTTPException: 500 if the query fails (raised by get_by_field).
        """
        # get_by_field already converts DB failures into an HTTP 500. Wrapping
        # it in another try/except would catch that HTTPException and log the
        # same error a second time.
        return self.get_by_field(model_class, "id", id)

    def create(self, model_class: Any, data: dict) -> Any:
        """
        Generic method to create and persist an entity.

        Args:
            model_class: The SQLAlchemy model class.
            data: Keyword arguments passed to the model constructor.

        Returns:
            The created entity, refreshed from the database after commit.

        Raises:
            HTTPException: 500 if construction, commit or refresh fails; the
                session is rolled back first.
        """
        try:
            entity = model_class(**data)
            self.db.add(entity)
            self.db.commit()
            self.db.refresh(entity)
        except Exception as e:
            # Roll back so the session stays usable after a failed flush/commit.
            self.db.rollback()
            self.handle_db_error(e, f"creating {model_class.__name__}")
        return entity