"""ML model training and prediction"""
from typing import List, Tuple, Optional, Dict
import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import log_loss, brier_score_loss, accuracy_score
import joblib
import os
from datetime import datetime
import logging

logger = logging.getLogger(__name__)


class MatchPredictor:
    """ML model for match prediction.

    Pairs a scikit-learn classifier with the ``StandardScaler`` used to
    standardize its input features, so both are fitted, applied and persisted
    together.
    """

    def __init__(self, model_type: str = "gradient_boosting"):
        """Create an unfitted predictor.

        Args:
            model_type: ``"gradient_boosting"`` or ``"logistic_regression"``.
                The name is only validated when :meth:`train` builds the model.
        """
        self.model_type = model_type
        self.model = None
        self.scaler = StandardScaler()
        self.is_fitted = False

    def _create_model(self):
        """Return a new, unfitted estimator for ``self.model_type``.

        Raises:
            ValueError: if ``model_type`` is not a supported name.
        """
        if self.model_type == "gradient_boosting":
            return GradientBoostingClassifier(
                n_estimators=100,
                learning_rate=0.1,
                max_depth=5,
                random_state=42,
            )
        if self.model_type == "logistic_regression":
            # ``multi_class`` is deprecated since scikit-learn 1.5; with the
            # default lbfgs solver a multiclass target already uses the
            # multinomial (softmax) loss, so it is not passed explicitly.
            return LogisticRegression(max_iter=1000, random_state=42)
        raise ValueError(f"Unknown model type: {self.model_type}")

    def _check_fitted(self) -> None:
        """Raise ``ValueError`` unless the predictor has been trained or loaded."""
        if not self.is_fitted:
            raise ValueError("Model not fitted")

    @staticmethod
    def _align_proba(
        proba: np.ndarray, classes: np.ndarray, y_true: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Expand ``proba`` to one column per label in ``classes`` ∪ ``y_true``.

        Args:
            proba: ``(n_samples, len(classes))`` probabilities whose columns
                follow the order of ``classes``.
            classes: labels the model was trained on.
            y_true: true labels that the probabilities will be scored against.

        Returns:
            ``(full_proba, labels)``. ``labels`` is the sorted union of both
            label sets. ``full_proba`` has its columns in that order, with
            zero probability for labels the model never saw.
        """
        classes = np.asarray(classes)
        labels = np.union1d(classes, np.asarray(y_true))
        full = np.zeros((proba.shape[0], labels.size), dtype=float)
        # Every class is present in ``labels``, so searchsorted gives its exact column.
        full[:, np.searchsorted(labels, classes)] = proba
        return full, labels

    @staticmethod
    def _class_proba(proba: np.ndarray, classes: np.ndarray, label) -> np.ndarray:
        """Return the per-row probability of ``label``.

        Returns zeros when ``label`` is not one of ``classes``.
        """
        idx = np.flatnonzero(np.asarray(classes) == label)
        if idx.size == 0:
            return np.zeros(proba.shape[0], dtype=float)
        return proba[:, idx[0]]

    def _evaluate(
        self, X_scaled: np.ndarray, y_true: np.ndarray, prefix: str
    ) -> Dict[str, float]:
        """Score the fitted model on already-scaled features.

        Returns:
            ``{prefix}_accuracy``, ``{prefix}_logloss`` and ``{prefix}_brier``.
            ``{prefix}_brier`` is the Brier score of the predicted probability
            of label 0 against the indicator ``y_true == 0``.
        """
        y_true = np.asarray(y_true)
        y_pred = self.model.predict(X_scaled)
        proba = self.model.predict_proba(X_scaled)
        classes = self.model.classes_
        # log_loss needs exactly one probability column per (sorted) label.
        # Passing the labels explicitly keeps it working when y_true holds a
        # single class. Padding covers labels that training never saw.
        full_proba, labels = self._align_proba(proba, classes, y_true)
        return {
            f"{prefix}_accuracy": float(accuracy_score(y_true, y_pred)),
            f"{prefix}_logloss": float(log_loss(y_true, full_proba, labels=labels)),
            f"{prefix}_brier": float(brier_score_loss(
                (y_true == 0).astype(int), self._class_proba(proba, classes, 0)
            )),
        }

    def train(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: Optional[np.ndarray] = None,
        y_val: Optional[np.ndarray] = None
    ) -> Dict[str, float]:
        """Fit the scaler and a fresh model, then report metrics.

        Args:
            X_train: training features. The scaler is fitted on these.
            y_train: training labels.
            X_val: optional validation features. They are scaled with the
                training statistics and never used to refit the scaler.
            y_val: optional validation labels. Validation metrics are only
                computed when both ``X_val`` and ``y_val`` are given.

        Returns:
            ``train_accuracy``, ``train_logloss`` and ``train_brier``. When
            validation data is given, also ``test_accuracy``,
            ``test_logloss`` and ``test_brier``.

        Raises:
            ValueError: for an unknown ``model_type``.
        """
        model = self._create_model()
        # Clear the flag before refitting so a failure part-way through cannot
        # leave an unfitted model marked as fitted.
        self.is_fitted = False
        self.model = model

        X_train_scaled = self.scaler.fit_transform(X_train)
        self.model.fit(X_train_scaled, y_train)
        self.is_fitted = True

        metrics = self._evaluate(X_train_scaled, y_train, "train")
        if X_val is not None and y_val is not None:
            X_val_scaled = self.scaler.transform(X_val)
            metrics.update(self._evaluate(X_val_scaled, y_val, "test"))
        return metrics

    def predict(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Predict outcomes and probabilities.

        Returns:
            ``(predictions, probabilities)``. The probability columns follow
            the order of ``self.model.classes_``.

        Raises:
            ValueError: if the predictor is not fitted.
        """
        self._check_fitted()
        X_scaled = self.scaler.transform(X)
        predictions = self.model.predict(X_scaled)
        probabilities = self.model.predict_proba(X_scaled)
        return predictions, probabilities

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Predict class probabilities only.

        Raises:
            ValueError: if the predictor is not fitted.
        """
        self._check_fitted()
        X_scaled = self.scaler.transform(X)
        return self.model.predict_proba(X_scaled)

    def save(self, filepath: str):
        """Persist model, scaler and model type to ``filepath`` with joblib.

        Creates the parent directory if it does not already exist.

        Raises:
            ValueError: if the predictor is not fitted.
        """
        self._check_fitted()
        directory = os.path.dirname(filepath)
        # os.makedirs("") raises, so skip it for bare filenames like "model.pkl".
        if directory:
            os.makedirs(directory, exist_ok=True)
        joblib.dump({
            "model": self.model,
            "scaler": self.scaler,
            "model_type": self.model_type
        }, filepath)
        logger.info("Model saved to %s", filepath)

    @staticmethod
    def load(filepath: str) -> "MatchPredictor":
        """Restore a predictor previously written by :meth:`save`.

        SECURITY: joblib.load unpickles the file and can execute arbitrary
        code. Only load files from trusted sources.
        """
        data = joblib.load(filepath)
        predictor = MatchPredictor(model_type=data["model_type"])
        predictor.model = data["model"]
        predictor.scaler = data["scaler"]
        predictor.is_fitted = True
        return predictor


def prepare_target(match) -> int:
    """Encode a finished match's result as a class label.

    Args:
        match: object exposing comparable ``home_score`` and ``away_score``
            attributes.

    Returns:
        0 for a home win, 1 for a draw, 2 for an away win.
    """
    home, away = match.home_score, match.away_score
    if home == away:
        return 1
    return 0 if home > away else 2

