"""Model training agent"""
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
import logging
import numpy as np
from app.agents.base import BaseAgent
from app.core.db.repositories.match_repository import MatchRepository
from app.core.db.repositories.model_repository import ModelRepository
from app.core.db.models.match import MatchStatus
from app.core.db.models.feature import Feature
from app.core.db.models.model import Model
from app.ml.models import MatchPredictor, prepare_target
from app.ml.features import FeatureBuilder
from app.core.config import settings
import os

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class TrainingAgent(BaseAgent):
    """Agent that trains a match predictor and registers it as the active model.

    A run loads finished matches from the last two years, pairs them with
    their stored ``Feature`` rows, trains a ``MatchPredictor`` on a
    chronological train/validation split, saves the artifact under
    ``self.models_dir`` and stores a ``Model`` record flagged as active.
    """

    def __init__(self, db: Session):
        """Set up repositories, the feature builder and the artifact directory.

        Args:
            db: SQLAlchemy session shared by the repositories, the feature
                builder and the direct ``Feature`` query.
        """
        super().__init__(db, "TrainingAgent")
        self.match_repo = MatchRepository(db)
        self.model_repo = ModelRepository(db)
        self.feature_builder = FeatureBuilder(db)
        # Relative path: artifacts land in "models" under the current working directory.
        self.models_dir = "models"
        os.makedirs(self.models_dir, exist_ok=True)

    def run(self) -> bool:
        """Run model training.

        Returns:
            True when a model was trained and stored, or when training was
            skipped because fewer than ``settings.min_matches_for_training``
            usable matches exist. False when any exception was raised; the
            error is logged and reported through ``finish_run``.
        """
        self.start_run()

        try:
            # Two-year training window; max_date=None so matches finished
            # 'today' are not cut off.
            min_date = datetime.now() - timedelta(days=365 * 2)
            matches = self.match_repo.get_matches_for_training(
                min_date=min_date,
                max_date=None
            )

            logger.info(f"Found {len(matches)} matches for training")

            required = settings.min_matches_for_training
            if len(matches) < required:
                return self._skip_training(
                    f"Not enough matches for training: {len(matches)} < {required}"
                )

            rows, targets = self._build_dataset(matches)

            if len(rows) < required:
                return self._skip_training(
                    f"Not enough matches with features: {len(rows)} < {required}"
                )

            X = np.array(rows)
            y = np.array(targets)

            # Time-based split: rows are in chronological order, so the
            # leading fraction is the training set and the most recent
            # matches form the validation set.
            split_idx = int(len(X) * settings.train_test_split)
            X_train, X_val = X[:split_idx], X[split_idx:]
            y_train, y_val = y[:split_idx], y[split_idx:]

            logger.info(f"Training on {len(X_train)} samples, validating on {len(X_val)} samples")

            predictor = MatchPredictor(model_type=settings.model_type)
            metrics = predictor.train(X_train, y_train, X_val, y_val)

            logger.info(f"Training metrics: {metrics}")

            # The timestamp doubles as the version label and the file name.
            version = f"v{datetime.now().strftime('%Y%m%d_%H%M%S')}"
            artifact_path = os.path.join(self.models_dir, f"model_{version}.joblib")
            predictor.save(artifact_path)

            self._register_model(version, artifact_path, metrics, len(X))

            logger.info(f"Model {version} trained and saved")

            self.finish_run(success=True)
            return True

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Training failed: {error_msg}", exc_info=True)
            # Discard pending changes (e.g. old models already flagged
            # inactive) so a later commit cannot leave the registry without
            # an active model, and so the session is usable again after a
            # database error.
            self._rollback_quietly()
            self.finish_run(success=False, error=error_msg)
            return False

    def _skip_training(self, message: str) -> bool:
        """Log why training is skipped and close the run as successful."""
        logger.warning(message)
        # Too little data is an expected condition, not a hard failure.
        self.finish_run(success=True)
        return True

    def _build_dataset(self, matches) -> tuple:
        """Build chronologically ordered feature rows and targets.

        Matches without a stored ``Feature`` row or without both scores are
        dropped.

        Args:
            matches: Match objects returned by the match repository.

        Returns:
            A ``(rows, targets)`` pair of equally long lists: feature vectors
            from ``FeatureBuilder.build_feature_vector`` and labels from
            ``prepare_target``.
        """
        match_ids = [m.id for m in matches]
        features = self.db.query(Feature).filter(
            Feature.match_id.in_(match_ids)
        ).all()
        feature_by_match = {f.match_id: f for f in features}

        usable = [
            m for m in matches
            if m.id in feature_by_match
            and m.home_score is not None
            and m.away_score is not None
        ]

        # The split in run() is positional, so the order must be chronological.
        # sorted() is stable: an already ordered result is left unchanged.
        try:
            usable = sorted(usable, key=lambda m: m.match_date)
        except TypeError:
            # Dates that cannot be compared (e.g. a missing match_date):
            # keep the repository order instead of failing the whole run.
            logger.warning(
                "Could not order matches by match_date; using repository order for the split"
            )

        rows = []
        targets = []
        for match in usable:
            rows.append(self.feature_builder.build_feature_vector(feature_by_match[match.id]))
            targets.append(prepare_target(match))
        return rows, targets

    def _register_model(self, version: str, artifact_path: str, metrics, trained_on: int):
        """Store the new model record as active and deactivate all others.

        Args:
            version: Version label of the new model.
            artifact_path: Path the predictor was saved to.
            metrics: Mapping returned by ``MatchPredictor.train``; absent
                keys are stored as ``None``.
            trained_on: Total number of samples (training plus validation).

        Returns:
            Whatever ``ModelRepository.create`` returns for the new record.
        """
        model = Model(
            version=version,
            model_type=settings.model_type,
            parameters={"rolling_window": settings.rolling_window_matches},
            train_accuracy=metrics.get("train_accuracy"),
            test_accuracy=metrics.get("test_accuracy"),
            train_logloss=metrics.get("train_logloss"),
            test_logloss=metrics.get("test_logloss"),
            train_brier=metrics.get("train_brier"),
            test_brier=metrics.get("test_brier"),
            artifact_path=artifact_path,
            trained_on_matches=trained_on,
            is_active=1
        )

        # Deactivate existing models before adding the new one, so only the
        # new record ends up active.
        for old_model in self.model_repo.get_all():
            old_model.is_active = 0

        model = self.model_repo.create(model)
        self.db.commit()
        return model

    def _rollback_quietly(self) -> None:
        """Roll back the session, logging instead of raising if that fails."""
        try:
            self.db.rollback()
        except Exception:
            logger.error("Rollback after failed training also failed", exc_info=True)

