"""Prediction agent"""
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
import logging
import numpy as np
from app.agents.base import BaseAgent
from app.core.db.repositories.match_repository import MatchRepository
from app.core.db.repositories.prediction_repository import PredictionRepository
from app.core.db.repositories.model_repository import ModelRepository
from app.core.db.models.match import MatchStatus
from app.core.db.models.feature import Feature
from app.ml.models import MatchPredictor
from app.ml.features import FeatureBuilder

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


class PredictionAgent(BaseAgent):
    """Agent that scores upcoming matches with the latest active model.

    Each run:
      1. loads the model returned by ``ModelRepository.get_latest()``;
      2. fetches up to ``MAX_MATCHES`` matches via ``MatchRepository.get_upcoming``
         in the window ``[now, now + PREDICTION_WINDOW_DAYS days]``;
      3. ensures every match has a ``Feature`` row, building one if missing;
      4. upserts a home/draw/away prediction through
         ``PredictionRepository.update_or_create``.
    """

    # How far ahead (in days) to look for upcoming matches.
    PREDICTION_WINDOW_DAYS = 7
    # Maximum number of upcoming matches fetched per run.
    MAX_MATCHES = 500

    def __init__(self, db: Session):
        super().__init__(db, "PredictionAgent")
        self.match_repo = MatchRepository(db)
        self.prediction_repo = PredictionRepository(db)
        self.model_repo = ModelRepository(db)
        self.feature_builder = FeatureBuilder(db)

    def run(self) -> bool:
        """Predict outcomes for upcoming matches and record the run.

        Returns:
            True if every fetched match was scored and the run was finished
            successfully; False if no model is available, the model fails to
            load, or any later step raises. In every case ``finish_run`` is
            called with the outcome.
        """
        run_id = self.start_run()
        predictions_made = 0

        try:
            # Get latest active model
            model_record = self.model_repo.get_latest()
            if not model_record:
                error_msg = "No active model found"
                logger.warning(error_msg)
                self.finish_run(success=False, error=error_msg)
                return False

            try:
                predictor = MatchPredictor.load(model_record.artifact_path)
            except Exception as e:
                error_msg = f"Failed to load model: {e}"
                logger.error(error_msg)
                self.finish_run(success=False, error=error_msg)
                return False

            from_date = datetime.now()
            to_date = from_date + timedelta(days=self.PREDICTION_WINDOW_DAYS)
            matches = self.match_repo.get_upcoming(
                from_date=from_date,
                to_date=to_date,
                limit=self.MAX_MATCHES,
            )

            logger.info("Making predictions for %d matches", len(matches))

            for match in matches:
                feature = self._get_or_build_feature(match)

                # The predictor expects a 2-D batch; score a single row.
                X = np.array([self.feature_builder.build_feature_vector(feature)])
                proba = predictor.predict_proba(X)[0]
                p_home, p_draw, p_away = self._outcome_probabilities(predictor, proba)

                self.prediction_repo.update_or_create(
                    match_id=match.id,
                    model_id=model_record.id,
                    prob_home_win=float(p_home),
                    prob_draw=float(p_draw),
                    prob_away_win=float(p_away),
                    confidence=float(max(p_home, p_draw, p_away)),
                )
                predictions_made += 1

            self.db.commit()

            # Update run statistics
            run = self.run_repo.get_by_id(run_id)
            if run:
                run.records_created = predictions_made
                self.db.commit()

            self.finish_run(success=True)
            logger.info("Made %d predictions", predictions_made)
            return True

        except Exception as e:
            error_msg = str(e)
            logger.error("Prediction failed: %s", error_msg, exc_info=True)
            # If the failure came from a flush/commit, the SQLAlchemy session
            # refuses further work until it is rolled back. finish_run
            # presumably records the outcome through this same session, so
            # roll back first; otherwise it would raise too and the run would
            # never be marked as failed.
            self.db.rollback()
            self.finish_run(success=False, error=error_msg)
            return False

    def _get_or_build_feature(self, match) -> Feature:
        """Return the stored ``Feature`` row for *match*, building one if absent.

        A newly built feature is added to the session and committed
        immediately. Note that ``commit()`` commits the whole session, not
        just this feature.
        """
        feature = (
            self.db.query(Feature)
            .filter(Feature.match_id == match.id)
            .first()
        )
        if not feature:
            feature = self.feature_builder.build_features_for_match(match)
            self.db.add(feature)
            self.db.commit()
        return feature

    @staticmethod
    def _outcome_probabilities(predictor, proba):
        """Map one row of ``predict_proba`` output onto (home, draw, away).

        Outcome indices are 0 = home win, 1 = draw, 2 = away win.

        When the underlying ``predictor.model`` exposes a non-empty
        ``classes_``, each column of *proba* is assigned to ``int(label)``.
        This lets a model trained on only a subset of outcomes (e.g. two
        classes) still land in the right slots. Otherwise columns are taken
        positionally. Missing outcomes get 0.0. The three values are
        renormalised to sum to 1 when their sum is positive.

        Returns:
            A ``(p_home, p_draw, p_away)`` tuple of floats.
        """
        # Best-effort probe: not every predictor wraps a model exposing classes_.
        try:
            classes = list(getattr(predictor.model, "classes_", []))
        except Exception:
            logger.debug("Predictor exposes no usable classes_; using column order")
            classes = None

        probs = {}
        if classes:
            for idx, cls in enumerate(classes):
                try:
                    probs[int(cls)] = float(proba[idx])
                except (TypeError, ValueError, OverflowError, IndexError):
                    # Skip labels that cannot be mapped to an outcome index,
                    # but make the skip visible instead of silent.
                    logger.warning(
                        "Ignoring model class %r at column %d: not mappable to an outcome",
                        cls,
                        idx,
                    )
        else:
            probs = {i: float(proba[i]) for i in range(min(len(proba), 3))}

        p_home = probs.get(0, 0.0)
        p_draw = probs.get(1, 0.0)
        p_away = probs.get(2, 0.0)

        total = p_home + p_draw + p_away
        if total > 0:
            p_home /= total
            p_draw /= total
            p_away /= total
        else:
            logger.warning(
                "Model produced no mappable outcome probabilities; storing zeros"
            )

        return p_home, p_draw, p_away

