o
    \ib                     @   s  d Z ddlZddlmZ ddlZddlZddlmZ ddl	m
Z
mZmZ ddlmZmZmZ ddlmZ ddlmZmZmZ dd	lmZmZmZ dd
lmZ ddlmZmZ ddl m!Z! ddl"m#Z#m$Z$ ddl%m&Z&m'Z'm(Z( ddl)m*Z*m+Z+ ddl,m-Z- ddl.m/Z/ ddl0m1Z1m2Z2m3Z3 ddl4m5Z5 ddl6m7Z7 ddl8m9Z9 G dd deeZ:dd Z;ej<=de9dd Z>dd Z?dd  Z@d!d" ZAej<=de9d#d$ ZBd%d& ZCd'd( ZDd)d* ZEd+d, ZFd-d. ZGd/d0 ZHd1d2 ZId3d4 ZJej<=d5ed6d6gej<=d7ed8fed9fgd:d; ZKej<=d<d=eLfd>eMfd?d@ eMfgej<=dAeegdBdC ZNej<=dDdd8gdEdF ZOdGdH ZPdIdJ ZQej<=g dKdLd9dLeRg dMgdLd8dLeRg dNgdLd9dOeRg dPgdLd8dOeRg dQgdLd9dReRdLd9ggdLd8dReRg dSgdLd9d9eRdLd9ggdLd8d9eRdLd8ggd9d9dOeRd9ggd9d8dLeRd9d8ggd9d8dOeRd9d8gggdTdU ZSej<=dVeegdWdX ZTej<=dVeegdYdZ ZUej<=dVeegej<=d[eeegd\d] ZVd^d_ ZWej<=d`edafedbfgdcdd ZXdedf ZYdgdh ZZdidj Z[dS )kz'
Testing Recursive feature elimination
    N)
attrgetter)parallel_backend)assert_allcloseassert_array_almost_equalassert_array_equal)BaseEstimatorClassifierMixinis_classifier)TransformedTargetRegressor)CCAPLSCanonicalPLSRegression)	load_irismake_classificationmake_friedman1)RandomForestClassifier)RFERFECV)SimpleImputer)LinearRegressionLogisticRegression)
get_scorermake_scorerzero_one_loss)
GroupKFoldcross_val_score)make_pipeline)StandardScaler)SVCSVR	LinearSVR)check_random_state)ignore_warnings)CSR_CONTAINERSc                       sb   e Zd ZdZdddZdd Zdd ZeZeZeZ	dd
dZ
dddZdd Z fddZ  ZS )MockClassifierz@
    Dummy classifier to test recursive feature elimination
    r   c                 C   s
   || _ d S N	foo_param)selfr'    r)   /var/www/www-root/data/www/176.119.141.140/sports-predictor/venv/lib/python3.10/site-packages/sklearn/feature_selection/tests/test_rfe.py__init__$      
zMockClassifier.__init__c                 C   s>   t |t |ks
J tj|jd tjd| _tt|| _| S )N   )dtype)	lennponesshapefloat64coef_sortedsetclasses_r(   Xyr)   r)   r*   fit'   s   zMockClassifier.fitc                 C   s   t |jd S )Nr   )r0   r1   r2   )r(   Tr)   r)   r*   predict-   s   zMockClassifier.predictNc                 C      dS )Ng        r)   r8   r)   r)   r*   score4      zMockClassifier.scoreTc                 C   s
   d| j iS )Nr'   r&   )r(   deepr)   r)   r*   
get_params7   r,   zMockClassifier.get_paramsc                 K   s   | S r%   r)   )r(   paramsr)   r)   r*   
set_params:   r@   zMockClassifier.set_paramsc                    s   t   }d|j_|S )NT)super__sklearn_tags__
input_tags	allow_nan)r(   tags	__class__r)   r*   rF   =   s   
zMockClassifier.__sklearn_tags__)r   )NN)T)__name__
__module____qualname____doc__r+   r;   r=   predict_probadecision_function	transformr?   rB   rD   rF   __classcell__r)   r)   rJ   r*   r$      s    


r$   c                  C   s   t d} t }tj|j| jt|jdfdf }|j}td| dd}t	|ddd	}|
|| t|j|jd
 ks;J tdd}t	|ddd	}|
|| t| |  d S )Nr      size      )n_estimatorsrandom_state	max_depth   皙?	estimatorn_features_to_selectstepr-   linearkernel)r!   r   r0   c_datanormalr/   targetr   r   r;   ranking_r2   r   r   get_support)	generatoririsr9   r:   clfrfeclf_svcrfe_svcr)   r)   r*   test_rfe_features_importanceC   s   "
rq   csr_containerc                 C   s6  t d}t }tj|j|jt|jdfdf }| |}|j}tdd}t	|ddd}|
|| ||}|
|| t|j|jd	 ksHJ tdd}	t	|	ddd}
|

|| |
|}|j|jjkshJ t|d d
 |jd d
  t||||j |||||j|jksJ t||  d S )Nr   rT   rU   rb   rc   r\   r]   r^   r-   
   )r!   r   r0   re   rf   rg   r/   rh   r   r   r;   rR   ri   r2   r   r=   r?   toarray)rr   rk   rl   r9   X_sparser:   rm   rn   X_r
clf_sparse
rfe_sparse
X_r_sparser)   r)   r*   test_rfeX   s(   "



 rz   c                  C   s   G dd dt t} tdd\}}tjtdd t|  d|| W d    n1 s,w   Y  tjtdd t|  dj||d	d
|| W d    n1 sSw   Y  t|  dj||d	d
j||d	d
 d S )Nc                   @   s    e Zd ZdddZdddZdS )z0test_RFE_fit_score_params.<locals>.TestEstimatorNc                 S   s2   |d u rt dtdd||| _| jj| _| S )Nfit: prop cannot be Nonerb   rc   )
ValueErrorr   r;   svc_r4   r(   r9   r:   propr)   r)   r*   r;   |   s
   
z4test_RFE_fit_score_params.<locals>.TestEstimator.fitc                 S   s   |d u rt d| j||S )Nscore: prop cannot be None)r|   r}   r?   r~   r)   r)   r*   r?      s   z6test_RFE_fit_score_params.<locals>.TestEstimator.scorer%   )rL   rM   rN   r;   r?   r)   r)   r)   r*   TestEstimator{   s    
r   T
return_X_yr{   matchr_   r   foo)r   )	r   r   r   pytestraisesr|   r   r;   r?   )r   r9   r:   r)   r)   r*   test_RFE_fit_score_paramsx   s   "(r   c                  C   s   t d} t }tj|j| jt|jdfdf }|j}tdd}t	|ddd}|
|| t	|d	dd}|
|| t|j|j t|j|j d S )
Nr   rT   rU   rb   rc   r\   r]   r^   g?)r!   r   r0   re   rf   rg   r/   rh   r   r   r;   r   ri   support_)rk   rl   r9   r:   rm   rfe_numrfe_percr)   r)   r*   test_rfe_percent_n_features   s   "
r   c                  C   s   t d} t }tj|j| jt|jdfdf }|j}t }t	|ddd}|
|| ||}|
|| t|j|jd ksBJ |j|jjksKJ d S )Nr   rT   rU   r\   r]   r^   r-   )r!   r   r0   re   rf   rg   r/   rh   r$   r   r;   rR   ri   r2   )rk   rl   r9   r:   rm   rn   rv   r)   r)   r*   test_rfe_mockclassifier   s   "
r   c                 C   s  t d}t }tj|j|jt|jdfdf }t|j}t	t
dddd}||| |j D ]}t|j| |jd ksAJ q1t|j|jd ksNJ ||}t||j t	t
dddd}| |}	||	| ||	}
t|
 |j ttdd	}t	t
ddd|d
}t|j|| ||}t||j td}t	t
ddd|d
}||| ||}t||j dd }t	t
ddd|d
}||| |jdksJ t	t
dddd}||| |j D ]}t|j| dksJ qt|j|jd ksJ ||}t||j t	t
dddd}| |}	||	| ||	}
t|
 |j t	t
dddd}| |}	||	| ||	}
t|
 |j d S )Nr   rT   rU   rb   rc   r-   r_   ra   F)greater_is_better)r_   ra   scoringaccuracyc                 S   r>   )Ng      ?r)   )r_   r9   r:   r)   r)   r*   test_scorer   r@   ztest_rfecv.<locals>.test_scorerrX   皙?)r!   r   r0   re   rf   rg   r/   listrh   r   r   r;   cv_results_keysr2   ri   rR   r   rt   r   r   r"   r   n_features_)rr   rk   rl   r9   r:   rfecvkeyrv   rfecv_sparseru   ry   r   scorerr   r)   r)   r*   
test_rfecv   s^   "







r   c                  C   s   t d} t }tj|j| jt|jdfdf }t|j}t	t
 dd}||| |j D ]}t|j| |jd ks?J q/t|j|jd ksLJ d S )Nr   rT   rU   r-   r   )r!   r   r0   re   rf   rg   r/   r   rh   r   r$   r;   r   r   r2   ri   )rk   rl   r9   r:   r   r   r)   r)   r*   test_rfecv_mockclassifier	  s   "
r   c                  C   s   dd l } ddlm} | | _td}t }tj|j|j	t
|jdfdf }t|j}ttddddd}||| | j}|d t
| dksMJ d S )	Nr   )StringIOrT   rU   rb   rc   r-   )r_   ra   verbose)sysior   stdoutr!   r   r0   re   rf   rg   r/   r   rh   r   r   r;   seekreadline)r   r   rk   rl   r9   r:   r   verbose_outputr)   r)   r*   test_rfecv_verbose_output  s   "

r   c           
      C   s   t | }t }tj|j|jt|jdfdf }t|j}ddgddgddgfD ]F\}}t	t
 ||d}||| t|jd | | d }|j D ]}	t|j|	 |ksZJ qMt|j|jd ksgJ |j|ksnJ q(d S )NrT   rU   rX   r-      r_   ra   min_features_to_select)r!   r   r0   re   rf   rg   r/   r   rh   r   r$   r;   ceilr2   r   r   ri   r   )
global_random_seedrk   rl   r9   r:   ra   r   r   	score_lenr   r)   r)   r*   test_rfecv_cv_results_size.  s"   "
r   c                  C   sD   t tdd} t| sJ t }t| |j|j}| dks J d S )Nrb   rc   gffffff?)r   r   r	   r   r   rf   rh   min)rn   rl   r?   r)   r)   r*   test_rfe_estimator_tagsG  s
   r   c                 C   s   d}t d|| d\}}|j\}}tdd}t|dd}|||}|j |d ks,J t|d	d}|||}|j |d ksCJ t|d
d}|||}|j |d ksZJ d S )Nrs   2   	n_samples
n_featuresrZ   rb   rc   g{Gz?ra   rX   r      )r   r2   r   r   r;   r   sum)r   r   r9   r:   r   r_   selectorselr)   r)   r*   test_rfe_min_stepP  s   


r   c                 C   sz  dd }dd }ddg}ddg}ddg}t |||D ]D\}}}t| }	|	jd|fd	}
|	d }ttd
d||d}||
| t	|j
||||ksPJ t	|j
||||ks^J qd}ddg}ddg}t ||D ]L\}}t| }	|	jd|fd	}
|	d }ttd
d|d}||
| |j D ] }t|j| ||||ksJ t|j| ||||ksJ qqnd S )Nc                 S   s   d| | | d |  S Nr-   r)   r   r`   ra   r)   r)   r*   formula1q  s   z4test_number_of_subsets_of_features.<locals>.formula1c                 S   s   dt | | t|  S r   )r0   r   floatr   r)   r)   r*   formula2t  s   z4test_number_of_subsets_of_features.<locals>.formula2   r   rX   d   rU   rb   rc   r^   r-   rs   r   )zipr!   rg   randroundr   r   r;   r0   maxri   r   r   r   r/   )r   r   r   n_features_listn_features_to_select_list	step_listr   r`   ra   rk   r9   r:   rn   r   r   r)   r)   r*   "test_number_of_subsets_of_featuresh  sJ   	
r   c           	      C   s   t | }t }tj|j|jt|jdfdf }|j}tt	ddd}|
|| |j}|j}|jdd |
|| t|j| | |j ksLJ | D ]}|| t|j| ks`J qPd S )NrT   rU   rb   rc   r   rX   )n_jobs)r!   r   r0   re   rf   rg   r/   rh   r   r   r;   ri   r   rD   r   r   r   approx)	r   rk   rl   r9   r:   r   rfecv_rankingrfecv_cv_results_r   r)   r)   r*   test_rfe_cv_n_jobs  s   "r   c                  C   s   t d} t }d}ttd|t|j}|j}|jdkt	}t
t| dddtddd}|j|||d	 |jdks>J d S )
Nr   r\   rZ   r-   r   rX   )n_splits)r_   ra   r   cv)groups)r!   r   r0   floorlinspacer/   rh   rf   astypeintr   r   r   r;   r   )rk   rl   number_groupsr   r9   r:   
est_groupsr)   r)   r*   test_rfe_cv_groups  s   r   importance_getterzregressor_.coef_zselector, expected_n_featuresr   r\   c                 C   s\   t dddd\}}tdd}t|tjtjd}||| d}|||}|j |ks,J d S )Nr   rs   r   r   r   	regressorfuncinverse_funcr   )	r   r    r
   r0   logexpr;   r   r   )r   r   expected_n_featuresr9   r:   r_   log_estimatorr   r)   r)   r*   test_rfe_wrapped_estimator  s   

r   zimportance_getter, err_typeautorandomc                 C   s   | j S r%   )
importance)xr)   r)   r*   <lambda>  s    r   Selectorc                 C   sr   t dddd\}}t }t|tjtjd}t| ||| d}||| W d    d S 1 s2w   Y  d S )Nr   rs   *   r   r   r   )	r   r    r
   r0   r   r   r   r   r;   )r   err_typer   r9   r:   r_   r   modelr)   r)   r*   %test_rfe_importance_getter_validation  s   

"r   r   c                 C   sn   t  }|j}|j}tj|d d< tj|d d< t }| d ur%t|| d}nt|d}|	|| |
| d S )Nr   r-   )r_   r   r   )r   rf   rh   r0   naninfr$   r   r   r;   rR   )r   rl   r9   r:   rm   rn   r)   r)   r*   test_rfe_allow_nan_inf_in_x  s   
r   c                  C   sR   t t t } tdd\}}t| ddd}||| ||jd dks'J d S )NTr   rX   $named_steps.logisticregression.coef_)r`   r   r-   )r   r   r   r   r   r;   rR   r2   )pipelinerf   r:   sfmr)   r)   r*   test_w_pipeline_2d_coef_  s   r   c           	         s   t | }t }tj|j|jt|jdfdf }|j}tt	ddd  
|| dd  j D }t fdd|D }tj|d	d
}tj|d	d
}t jd | t jd | d S )NrT   rU   rb   rc   r   c                 S   s   g | ]
}t d |r|qS )zsplit\d+_test_score)research.0r   r)   r)   r*   
<listcomp>!  s    
z+test_rfecv_std_and_mean.<locals>.<listcomp>c                    s   g | ]} j | qS r)   )r   r   r   r)   r*   r   &  s    r   axismean_test_scorestd_test_score)r!   r   r0   re   rf   rg   r/   rh   r   r   r;   r   r   asarraymeanstdr   )	r   rk   rl   r9   r:   
split_keys	cv_scoresexpected_meanexpected_stdr)   r   r*   test_rfecv_std_and_mean  s   "r  )r   r   ra   cv_results_n_featuresr-   )r-   rX   r   r\   )r-   rX   r   r\   r   rX   )r-   rX   r\   )r-   r   r   r   )r-   rX   r   c                    sh   t d||dd\}}ttdd|| d  || t jd | t fdd	 j D s2J d S )
NrW   r   )r   r   n_informativen_redundantrb   rc   r   r   c                 3   s&    | ]}t |t  jd  kV  qdS )r   N)r/   r   )r   valuer   r)   r*   	<genexpr>N  s
    
z3test_rfecv_cv_results_n_features.<locals>.<genexpr>)r   r   r   r;   r   r   allvalues)r   r   ra   r	  r9   r:   r)   r   r*    test_rfecv_cv_results_n_features.  s   
r  ClsRFEc                 C   s@   t jjdd}t jjddd}tdd}| |}||| d S )N)rs   r   rU   rX   )rs   rX   r   )rY   )r0   r   rg   randintr   r;   )r  r9   r:   rm   rfe_testr)   r)   r*   test_multioutputT  s
   
r  c                 C   sF   t dd\}}tj|d< tt t t }| |dd}||| dS )z`Check that RFE works with pipeline that accept nans.

    Non-regression test for gh-21743.
    Tr   )r   r   r   )r_   r   N)r   r0   r   r   r   r   r   r;   )r  r9   r:   pipefsr)   r)   r*   test_pipeline_with_nans]  s   
r  PLSEstimatorc                 C   sH   t dddd\}}|dd}| |dd||}|||dks"J d	S )
zCheck the behaviour of RFE with PLS estimators.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/12410
    r   rs   r   r   r-   )n_componentsr   g      ?N)r   r;   r?   )r  r  r9   r:   r_   r   r)   r)   r*   test_rfe_plss  s   
r  c                  C   s   t  } tt d}d}d}tjt|d}|| j| j	| j W d   n1 s+w   Y  t
|jjts9J |t|jjv sCJ dS )a  Check that we raise the proper AttributeError when the estimator
    does not implement the `decision_function` method, which is decorated with
    `available_if`.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/28108
    r   z/This 'RFE' has no attribute 'decision_function'z>'LinearRegression' object has no attribute 'decision_function'r   N)r   r   r   r   r   AttributeErrorr;   rf   rh   rQ   
isinstancer  	__cause__str)rl   rn   	outer_msg	inner_msg	exec_infor)   r)   r*   "test_rfe_estimator_attribute_error  s   r"  zClsRFE, paramr`   r   c                 C   sn   t ddd\}}tjt| dd | d	dt i|di}||| W d   dS 1 s0w   Y  dS )
zCheck if the correct warning is raised when trying to initialize a RFE
    object with a n_features_to_select attribute larger than the number of
    features present in the X variable that is passed to the fit method
    rW   r   )r   rZ   z=21 > n_features=20r   r_      Nr)   )r   r   warnsUserWarningr   r;   )r  paramr9   r:   clsrfer)   r)   r*   %test_rfe_n_features_to_select_warning  s
   "r(  c                  C   s   t dd\} }| jd }t|}d|d|d < tj| | d|d  gdd}t||d|d  g}tdd}t|dd	}|j| ||d
 t|dd	}||| t|j	|j	 t|dd	}	t|}
|	j| ||
d
 t
|	j	|j	rxJ dS )z4Test that `RFE` works correctly with sample weights.r   r   rX   Nr   rb   rc   r]   r   )sample_weight)r   r2   r0   	ones_likeconcatenater   r   r;   r   ri   array_equal)r9   r:   r   r)  X2y2r_   rfe_swrn   rfe_sw_2sample_weight_2r)   r)   r*   test_rfe_with_sample_weight  s    



r2  c                 C   sv   t | d\}}t }t|dd}||| |j}td ||| W d    n1 s.w   Y  t||j d S )Nr   rX   )r_   r   	threading)r   r   r   r;   ri   r   r   )r   r9   r:   rm   rn   ranking_refr)   r)   r*   &test_rfe_with_joblib_threading_backend  s   
r5  c                 C   s   t | d\}}t }t|ddd}||| t|jd t|jd ks'J t|jd t|jd ks7J t|jd	 t|jd
 ksGJ dS )zx
    Test that the results of RFECV are consistent across the different folds
    in terms of length of the arrays.
    r   rX   r   )r_   r   r   split1_test_scoresplit2_test_scoresplit1_supportsplit2_supportsplit1_rankingsplit2_rankingN)r   r   r   r;   r/   r   )r   r9   r:   rm   r   r)   r)   r*   test_results_per_cv_in_rfecv  s"   

r<  )\rO   r   operatorr   numpyr0   r   joblibr   numpy.testingr   r   r   sklearn.baser   r   r	   sklearn.composer
   sklearn.cross_decompositionr   r   r   sklearn.datasetsr   r   r   sklearn.ensembler   sklearn.feature_selectionr   r   sklearn.imputer   sklearn.linear_modelr   r   sklearn.metricsr   r   r   sklearn.model_selectionr   r   sklearn.pipeliner   sklearn.preprocessingr   sklearn.svmr   r   r    sklearn.utilsr!   sklearn.utils._testingr"   sklearn.utils.fixesr#   r$   rq   markparametrizerz   r   r   r   r   r   r   r   r   r   r   r   r   r   r|   r  r   r   r   r  arrayr  r  r  r  r"  r(  r2  r5  r<  r)   r)   r)   r*   <module>   s    $

Q	A





!