o
    \i                    @   sb  d dl Z d dlmZ d dlZd dlZd dlZd dlmZ	 d dl
mZmZmZ d dlmZmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZmZ d d
lmZ d dlm Z m!Z!m"Z"m#Z# d dl$m%Z% d dl&m'Z' d dl(m)Z)m*Z*m+Z+m,Z, dd Z-G dd dej.Z/G dd dej0Z1G dd dej2Z3dd Z.dd Z0dd Z2dd Z4dd  Z5d!d" Z6e7d#d$gd$d$gd$d#gd%d%gd%d&gd&d%ggZ8g d'Z9e7d$d$gd&d&gd(d&ggZ:g d)Z;e7d$d%gd*d+gd,d-gd%d%gd.d+gd-d-gd$d$gd d/gd%d$gg	Z<d0gd( d1gd(  d2gd(  Z=e7d,d+gd%d&gd d#ggZ>g d3Z?e7g d4g d4g d5g d5g d6g d6g d7g d7gZ@e7g d8ZAe7g d9g d:g d;g d<g d=g d>g d?g d@gZBe7g d8ZCeD ZEe7d#d$gd$d$gd$d#gd%d%gd%d&gd&d%ggZFg d'ZGg dAZHddCdDZIdEdF ZJejKLdGe.e4e0e5gejKLdHg dIdJdK ZMejKLdGe.e4e0e5gdLdM ZNejKLdGe.e4e0e5gdNdO ZOejKLdGe.e4e0e5e2e6gdPdQ ZPejKLdGe.e4e0e5e2e6gdRdS ZQejKLdGe.e4e0e5gdTdU ZRejKLdGe.e4e0e5gdVdW ZSejKLdGe.e4e0e5gdXdY ZTejKLdGe.e4e0e5gdZd[ ZUejKLdGe.e4e0e5gd\d] ZVejKLdGe.e4e0e5gd^d_ ZWejKLd`e.e0gejKLdag dbdcdd ZXejKLd`e.e4e0e5gdedf ZYejKLdGe.e4gdgdh ZZejKLdGe.e4e2e6gdidj Z[ejKLdke.dle\dmife4dle\dmife2dne\dmife6dne\dmifgdodp Z]ejKLdGe.e4e0e5gdqdr Z^ejKLdke.dld ife4dld ife2dnd ife6dnd ifgdsdt Z_ejKLdGe.e4gdudv Z`ejKLdGe.e4gdwdx ZaejKLdGe.e4gdydz ZbejKLdGe.e4gd{d| ZcejKLdGe.e4gd}d~ ZdejKLdGe.e4gdd ZeejKLdGe.e4gdd ZfejKLdGe.e4gdd ZgejKLdGe.e4gdd ZhejKLdGe.e4gdd ZiejKLdGe.e4gdd ZjejKLdGe.e4gdd ZkejKLdGe.e4gdd ZlejKLdGe.e4gdd ZmejKLdGe.e4gdd ZnejKLdGe.e4gdd ZoejKLdGe.e4gdd ZpejKLdGe.e4gdd ZqejKLdGe.e4e2e6gdd ZrejKLdGe.e4gdd ZsejKLdGe.e4gdd ZtejKLdGe.e4gdd ZuejKLdGe.e4gdd ZvejKLdGe.e4gdd ZwejKLdGe.e4gejKLdHg dIdd ZxejKLdGe.e4gdd ZyejKLdGe.e4gdd ZzejKLdGe.e4gdd Z{ejKLdGe0e5gdd Z|ejKLdGe0e5gdd Z}ejKLdGe0e5gdd Z~ejKLdGe0e5gdd ZejKLdGe0e5gdd ZejKLdGe0e5gdd ZejKLdGe0e5gdd ZejKLdGe0e5gdd ZejKLdGe0e5gdd ZejKLdGe0e5gejKLdHg dIdd ZejKLdGe0e5gdd ZdddĄZejKLdGe2e6gddƄ ZejKLdGe2e6gejKLdHg dIddȄ ZejKLdGe2e6gddʄ ZejKLdGe2e6gdd̄ ZejKLdGe2e6gejKLdHg dIdd΄ ZejKLdGe2e6gddЄ ZejKLdGe2e6gdd҄ ZejKLdGe2e6gddԄ ZejKLdGe2e6gddք Zdd؄ Zddڄ Zdd܄ Zddބ Zdd ZejKLdg ddd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd Zdd ZejKLdg ddd ZejKLd`ej.ej0gdd Zdd Zdd  ZejKLd`e.e0gdd ZejKLde.e4e0e5e2e6gejKLdejejfdd ZejKLde.e4e0e5e2e6gdd Zd	d
 ZdS (      N)Mock)datasetslinear_modelmetrics)cloneis_classifier)ConvergenceWarning)Nystroem)	_sgd_fast)_stochastic_gradient)RandomizedSearchCVShuffleSplitStratifiedShuffleSplit)make_pipeline)LabelEncoderMinMaxScalerStandardScalerscale)OneClassSVM)get_tags)assert_allcloseassert_almost_equalassert_array_almost_equalassert_array_equalc                 C   s8   d| vrd| d< d| vrd | d< d| vrd| d< d S d S )Nrandom_state*   tolmax_iter    kwargsr   r   /var/www/www-root/data/www/176.119.141.140/sports-predictor/venv/lib/python3.10/site-packages/sklearn/linear_model/tests/test_sgd.py_update_kwargs    s   r#   c                       s@   e Zd Z fddZ fddZ fddZ fddZ  ZS )	_SparseSGDClassifierc                    &   t |}t j||g|R i |S N)sp
csr_matrixsuperfitselfXyargskw	__class__r   r"   r*   +      
z_SparseSGDClassifier.fitc                    r%   r&   )r'   r(   r)   partial_fitr+   r1   r   r"   r4   /   r3   z _SparseSGDClassifier.partial_fitc                       t |}t |S r&   )r'   r(   r)   decision_functionr,   r-   r1   r   r"   r6   3      
z&_SparseSGDClassifier.decision_functionc                    r5   r&   )r'   r(   r)   predict_probar7   r1   r   r"   r9   7   r8   z"_SparseSGDClassifier.predict_proba)__name__
__module____qualname__r*   r4   r6   r9   __classcell__r   r   r1   r"   r$   *   s
    r$   c                   @   $   e Zd Zdd Zdd Zdd ZdS )_SparseSGDRegressorc                 O   (   t |}tjj| ||g|R i |S r&   )r'   r(   r   SGDRegressorr*   r+   r   r   r"   r*   =      
z_SparseSGDRegressor.fitc                 O   r@   r&   )r'   r(   r   rA   r4   r+   r   r   r"   r4   A   rB   z_SparseSGDRegressor.partial_fitc                 O   &   t |}tjj| |g|R i |S r&   )r'   r(   r   rA   r6   r,   r-   r/   r0   r   r   r"   r6   E   s   
z%_SparseSGDRegressor.decision_functionNr:   r;   r<   r*   r4   r6   r   r   r   r"   r?   <       r?   c                   @   r>   )_SparseSGDOneClassSVMc                 O   rC   r&   )r'   r(   r   SGDOneClassSVMr*   rD   r   r   r"   r*   L   r3   z_SparseSGDOneClassSVM.fitc                 O   rC   r&   )r'   r(   r   rH   r4   rD   r   r   r"   r4   P   r3   z!_SparseSGDOneClassSVM.partial_fitc                 O   rC   r&   )r'   r(   r   rH   r6   rD   r   r   r"   r6   T   r3   z'_SparseSGDOneClassSVM.decision_functionNrE   r   r   r   r"   rG   K   rF   rG   c                  K      t |  tjdi | S Nr   )r#   r   SGDClassifierr    r   r   r"   rK   Y      rK   c                  K   rI   rJ   )r#   r   rA   r    r   r   r"   rA   ^   rL   rA   c                  K   rI   rJ   )r#   r   rH   r    r   r   r"   rH   c   rL   rH   c                  K      t |  tdi | S rJ   )r#   r$   r    r   r   r"   SparseSGDClassifierh      rN   c                  K   rM   rJ   )r#   r?   r    r   r   r"   SparseSGDRegressorm   rO   rP   c                  K   rM   rJ   )r#   rG   r    r   r   r"   SparseSGDOneClassSVMr   rO   rQ         )rT   rT   rT   rU   rU   rU      )rT   rU   rU   g            ?g      g      ?g      ?      onetwothree)rY   rZ   r[   )rT   rT   r   r   r   r   )r   r   rT   r   r   r   )r   r   r   r   rT   rT   )r   r   r   rT   r   r   )rT   rT   rT   rT   rU   rU   rU   rU   )rT   ?皙?r   r   r   )rT   zG?g\(\?r   r   r   )rT   Q?g)\(?r   r   r   )rT   Q?Gz?r   r   r   )r   r   r   g{Gz?r`   rT   )r   r   r   gHzG?r^   rT   )r   r   r   r`   gffffff?rT   )r   r   r   g(\?rT   rT   )r   rT   rT           c                 C   s   |d u rt |jd }n|}t |jd }|}	d}
d}| ttfv r%d}t|D ]J\}}t ||}||	7 }|||  }|d||  9 }||| |  7 }|	||  | 7 }	||9 }||7 }||d  }|
|9 }
|
|	7 }
|
|d  }
q)||
fS )NrT   rb         ?{Gz?)npzerosshaperN   rP   	enumeratedot)klassr-   r.   etaalphaweight_initintercept_initweightsaverage_weights	interceptaverage_interceptdecayientrypgradientr   r   r"   asgd   s.   rx   c                 C   s   | ddd|d}| || | ddd|d}|j |||j |j d | dddd|d}| || |j|jks<J t|j|j |jdd | || |j|jksWJ t|j|j d S )	Nrd   F)rl   eta0shufflelearning_rateMbP?	coef_initrn   T)rl   ry   rz   
warm_startr{   rl   )r*   coef_copy
intercept_t_r   
set_params)rj   r-   Ylrclfclf2clf3r   r   r"   _test_warm_start   s   
r   rj   r   )constantoptimal
invscalingadaptivec                 C   s   t | tt| d S r&   )r   r-   r   rj   r   r   r   r"   test_warm_start   s   r   c                 C   sz   | ddd}| tt ttd d tjf }tj||f }tt	 | t| W d    d S 1 s6w   Y  d S )Nrd   Frl   rz   )
r*   r-   r   re   arraynewaxisc_pytestraises
ValueError)rj   r   Y_r   r   r"   test_input_format   s   "r   c                 C   sV   | ddd}t |}|jdd |tt | ddd}|tt t|j|j d S )Nrd   l1)rl   penaltyl2)r   )r   r   r*   r-   r   r   r   rj   r   r   r   r   r"   
test_clone  s   r   c                 C   s   | ddd}| tt t|dsJ t|dsJ t|ds!J t|ds(J |  }| tt t|dr8J t|dr?J t|drFJ t|drMJ d S )NTrd   )averagery   _average_coef_average_intercept_standard_intercept_standard_coef)r*   r-   r   hasattrrj   r   r   r   r"   test_plain_has_no_average_attr  s   r   c                 C   s   | dd}|  }t dD ])}t|r)|jttttd |jttttd q|tt |tt qt|j|jdd | t	t
ttfv rRt|j|jdd d S | ttfv rat|j|j d S d S )NiX  r   d   classes   decimal)ranger   r4   r-   r   re   uniquer   r   rK   rN   rA   rP   r   r   rH   rQ   r   offset_)rj   clf1r   _r   r   r"   %test_late_onset_averaging_not_reached:  s   
r   c              	   C   s   d}d}t t}d||dk< d||dk< | ddd	||dd
d}| d
dd	||dd
d}|t| |t| t| t||||j |jd\}}t	|j | dd t
|j|dd d S )Nr|   -C6?      rT   rc   rU      r   squared_errorF)r   r{   lossry   rl   r   rz   )rm   rn   r   r   )re   r   r   r*   r-   rx   r   ravelr   r   r   )rj   ry   rl   Y_encoder   r   rp   rr   r   r   r"   !test_late_onset_averaging_reachedW  sH   
	


r   c                 C   sV   t jt jdk }t jt jdk }dD ]}d}| |d|d||}|j|k s(J qd S )Nr   TF  r|   )early_stoppingr   r   )irisdatatargetr*   n_iter_)rj   r-   r   r   r   r   r   r   r"   test_early_stopping  s   r   c                 C   sT   | ddddd}| tjtj | ddddd}| tjtj |j|jks(J d S )Nr   rd   r|   r   )r{   ry   r   r   r   )r*   r   r   r   r   )rj   r   r   r   r   r"   "test_adaptive_longer_than_constant  s
   r   c              
   C   s   t jt j}}d}d}d}d}| dtj||ddd ||d}||| |j|ks,J | dtj|ddd ||d	}t|rFt	||d
}	nt
||d
}	t|	||\}
}t|
}
|||
 ||
  |j|kslJ t|j|j d S )N皙?r   F
   Tr   rd   )r   r   validation_fractionr{   ry   r   r   rz   )r   r   r{   ry   r   r   rz   )	test_sizer   )r   r   r   re   randomRandomStater*   r   r   r   r   nextsplitsortr   r   )rj   r-   r   r   seedrz   r   r   r   cv	idx_trainidx_valr   r   r"   )test_validation_set_not_used_for_training  sD   




r   c                    sB   t jt j dD ] fdddD }t|t| q	d S )Nr   c                    s&   g | ]}|d dd  jqS )r   r   )r   n_iter_no_changer   r   )r*   r   ).0r   r-   r   r   rj   r   r"   
<listcomp>  s    	z)test_n_iter_no_change.<locals>.<listcomp>)rU   rV   r   )r   r   r   r   sorted)rj   n_iter_listr   r   r"   test_n_iter_no_change  s   	r   c                 C   sH   | ddd}t t |tt W d    d S 1 sw   Y  d S )NTra   )r   r   )r   r   r   r*   X3Y3r   r   r   r"   )test_not_enough_sample_for_early_stopping  s   "r   	Estimatorl1_ratio)r   gffffff?rT   c                 C   s>   | dddd tt}| d|dd tt}t|j|j dS )z@Check that l1_ratio is not used when penalty is not 'elasticnet'r   Nr   )r   r   r   )r*   r-   r   r   r   )r   r   r   r   r   r   r"   test_sgd_l1_ratio_not_used  s   r   c                 C   sL   | dd d}t jtdd |tt W d    d S 1 sw   Y  d S )N
elasticnet)r   r   z1l1_ratio must be set when penalty is 'elasticnet'matchr   r   r   r*   r-   r   )r   r   r   r   r"   #test_sgd_failing_penalty_validation  s   "r   c              	   C   s>   dD ]}| ddd|ddd}| tt t|tt qd S )N)hingesquared_hingelog_lossmodified_huberr   rd   Tr   )r   rl   fit_interceptr   r   rz   )r*   r-   r   r   predictTtrue_result)rj   r   r   r   r   r"   test_sgd_clf  s   r   c                 C   sL   t jtdd |  jtttdd W d   dS 1 sw   Y  dS )z1Check that the shape of `coef_init` is validated.z)Provided coef_init does not match datasetr   rV   r~   N)r   r   r   r*   r-   r   re   rf   rj   r   r   r"   test_provide_coef  s   "r   zklass, fit_paramsrn   r   offset_initc                 C   sN   |  }t jtdd |jttfi | W d   dS 1 s w   Y  dS )z:Check that `intercept_init` or `offset_init` is validated.zdoes not match datasetr   Nr   )rj   
fit_paramssgd_estimatorr   r   r"   test_set_intercept_offset  s   "r   c                 C   sJ   d}t jt|d | ddtt W d   dS 1 sw   Y  dS )zSCheck that we raise an error for `early_stopping` used with
    `partial_fit`.
    z/early_stopping should be False with partial_fitr   T)r   N)r   r   r   r4   r-   r   )rj   err_msgr   r   r"   (test_sgd_early_stopping_with_partial_fit-  s   "r   c                 C   s   |  j ttfi | dS )zdCheck that we can pass a scaler with binary classification to
    `intercept_init` or `offset_init`.N)r*   X5Y5)rj   r   r   r   r"    test_set_intercept_offset_binary9  s   r   c              
   C   s   d}d}d}d}t jd}|j||fd}|j|d}| dd||d	d
d	dd}t ||}	t |	}	|||	 t| ||	||\}
}|
d
d}
t	|j
|
dd t|j|dd d S )N皙?       @   r   r   sizer   r   TrT   Fr   r{   ry   rl   r   r   r   rz   rS      r   )re   r   r   normalri   signr*   rx   reshaper   r   r   r   )rj   rk   rl   	n_samples
n_featuresrngr-   wr   r.   rp   rr   r   r   r"   &test_average_binary_computed_correctlyH  s0   
r
  c                 C   sH   |   tt}|  j tt|jd |   tt}|  j tt|jd d S )Nrn   )r*   r   r   r   r-   r   r   r   r   r"   test_set_intercept_to_interceptj  s   r  c                 C   sN   | ddd}t t |ttd W d    d S 1 s w   Y  d S )Nrd   r   rl   r   	   )r   r   r   r*   X2re   onesr   r   r   r"   test_sgd_at_least_two_labelst  s   "r  c                 C   sT   d}t jt|d | ddjttttd W d    d S 1 s#w   Y  d S )Na`  class_weight 'balanced' is not supported for partial_fit\. In order to use 'balanced' weights, use compute_class_weight\('balanced', classes=classes, y=y\). In place of y you can use a large enough sample of the full training set target to properly estimate the class frequency distributions\. Pass the resulting weights as the class_weight parameter\.r   balanced)class_weightr   )r   r   r   r4   r-   r   re   r   )rj   regexr   r   r"   &test_partial_fit_weight_class_balanced|  s
   
"r  c                 C   sf   | ddd tt}|jjdksJ |jjdksJ |ddggjdks'J |t}t	|t
 d S )Nrd   r   r  rV   rU   r   r   rT   rV   r*   r  Y2r   rg   r   r6   r   T2r   true_result2rj   r   predr   r   r"   test_sgd_multiclass  s   
r  c              
   C   s   d}d}| dd||ddddd}t t}|t| t |}t|D ]0\}}t |jd	 }d
|||k< t	| t|||\}	}
t
|	|j| dd t|
|j| dd q$d S )Nr|   rd   r   r   TrT   Fr  r   rS   r   r   )re   r   r  r*   r  r   rh   r  rg   rx   r   r   r   r   )rj   rk   rl   r   np_Y2r   rt   cly_iaverage_coefrr   r   r   r"   test_sgd_multiclass_average  s,   

r#  c                 C   sb   | ddd}|j tttdtdd |jjdksJ |jjs%J d|t	}t
|t d S )Nrd   r   r  r  rV   r}   r   )r*   r  r  re   rf   r   rg   r   r   r  r   r  r  r   r   r"   "test_sgd_multiclass_with_init_coef  s   
r$  c                 C   sh   | dddd tt}|jjdksJ |jjdksJ |ddggjdks(J |t}t	|t
 d S )	Nrd   r   rU   )rl   r   n_jobsr  r   r   r  r  r  r   r   r"   test_sgd_multiclass_njobs  s   
r&  c                 C   s   |  }t t |jtttdd W d    n1 sw   Y  |  jtttdd}|  }t t |jtttdd W d    n1 sMw   Y  |  jtttdd}d S )N)rU   rU   r   r  rT   r  r   )r   r   r   r*   r  r  re   rf   r   r   r   r"   test_set_coef_multiclass  s   r(  c              	   C   s  t jjD ]}t|d}|dv rt|dsJ t|dsJ qd|}t|dr*J t|dr1J tjtdd}|j W d    n1 sFw   Y  t	|j
jtsTJ |t|j
jv s^J tjtdd}|j W d    n1 ssw   Y  t	|j
jtsJ |t|j
jv sJ qd S )	N)r   r   r   r9   predict_log_probaz5probability estimates are not available for loss={!r}z has no attribute 'predict_proba'r   z$has no attribute 'predict_log_proba')r   rK   loss_functionsr   formatr   r   AttributeErrorr9   
isinstancevalue	__cause__strr*  )rj   r   r   	inner_msg	exec_infor   r   r"   $test_sgd_predict_proba_method_access  s6   
r4  c              	   C   s  t dddd dtt}t|drJ t|drJ dD ]i}| |ddd}|tt |d	d
gg}|d dks;J |ddgg}|d dk sKJ tjdd, |d	d
gg}|d |d ksdJ |ddgg}|d |d k svJ W d    n1 sw   Y  q| ddddt	t
}|ddgddgg}|ddgddgg}ttj|ddtj|dd t|d  d t|d dksJ |ddgg}|ddgg}tt|d t|d  |d	d
gg}|d	d
gg}tt|| |ddgg}|ddgg}tt|| | dddd}|t	t
 |d	d
gg}|d	d
gg}| tkrMtj|ddtj|ddksLJ ntj|ddtj|ddks^J tjdd}||g}t|dk r||g}t|d dgd	  d S d S )Nr   rd   r   )r   rl   r   r   r9   r*  r)  )r   rl   r   rV   rU   r   rT   rW   rS   ignore)divide)r   r   r   r   皙333333?皙?rT   )axisr   r   gUUUUUU?)rK   r*   r-   r   r   r9   re   errstater*  r  r  r6   r   argmaxr   sumallargsortr   logrN   argminmean)rj   r   r   rv   dlpxr   r   r"   test_sgd_proba  sZ   
$"rG  c                 C   s   t t}tjd}t|}|| t|d d f }t| }| ddddd dd}||| t	|j
ddd	f td
 ||}t	|| |  t|j
sUJ ||}t	|| tt|}t|j
soJ ||}t	|| d S )N   r   r:  F  )r   rl   r   r   r   rz   r   rT   rS   )   )lenX4re   r   r   arangerz   Y4r*   r   r   rf   r   sparsifyr'   issparsepickleloadsdumps)rj   nr  idxr-   r   r   r  r   r   r"   test_sgd_l1I  s4   






rV  c                 C   s   t ddgddgddgddgddgg}g d}| ddd	d d
}||| t|ddggt dg | ddd	ddid
}||| t|ddggt dg d S )Nr   r   皙rc   rb   rT   rT   rT   rS   rS   r   r   F)rl   r   r   r  r:  rT   r|   rS   re   r   r*   r   r   rj   r-   r.   r   r   r   r"   test_class_weightsn  s   ("r[  c                 C   s   ddgddgddgddgg}g d}| ddd d}| || ddgddgg}ddg}| dddddd}| || t|j|jd	d
 d S )NrT   r   )r   r   rT   rT   r   r   rl   r   r  rW   r5  rU   r   )r*   r   r   )rj   r-   r.   r   clf_weightedr   r   r"   test_equal_class_weight  s   r^  c                 C   sN   | ddddid}t t |tt W d    d S 1 s w   Y  d S )Nr   r   r   rW   r\  r   r   r   r   r"   test_wrong_class_weight_label  s   "r_  c                 C   s   ddd}t jd}|tjd }t |}|tdk  |d 9  < |tdk  |d 9  < | dd|d	}| ddd
}|jtt|d |jtt|d t	|j
|j
 d S )Ng333333?r9  )rT   rU   r   rT   rU   r   r   r\  r  sample_weight)re   r   r   random_samplerN  rg   r   r*   rL  r   r   )rj   class_weightsr  sample_weightsmultiplied_togetherr   r   r   r   r"   test_weights_multiplied  s   

rf  c                 C   s  t jt j}}t|}t|jd }tjd}|	| || }|| }| ddd dd
||}tj|||dd}t|d	d
d | ddddd
||}tj|||dd}t|d	d
d t|j|jd ||dkd d f }||dk }	t|g|gd  }
t|g|	gd  }| dd dd}|
|
| ||}tj||ddd	k sJ | dddd}|
|
| ||}tj||ddd	ksJ d S )Nr      r   r   F)rl   r   r  rz   weightedr   r_   rT   r   r  r   )r   r  rz   )r   r   r   r   re   rM  rg   r   r   rz   r*   r   f1_scorer   r   r   r   vstackconcatenate)rj   r-   r.   rU  r  r   f1clf_balancedX_0y_0X_imbalancedy_imbalancedy_predr   r   r"   test_balanced_weight  s<   


rs  c                 C   s   t ddgddgddgddgddgg}g d}| ddd	d
}||| t|ddggt dg |j||dgd dgd  d t|ddggt dg d S )Nr   r   rW  rc   rb   rX  r   r   Frl   r   r   r:  rT   r|   rV   rU   r`  rS   rY  rZ  r   r   r"   test_sample_weights  s   ( "ru  c                 C   s|   | t tfv r| dddd}n| ttfv r| dddd}tt |jtt	t
dd W d    d S 1 s7w   Y  d S )Nr   r   Frt  )nur   r   r   r`  )rK   rN   rH   rQ   r   r   r   r*   r-   r   re   rM  r   r   r   r"   test_wrong_sample_weights  s   "rw  c                 C   sF   | dd}t t |tt W d    d S 1 sw   Y  d S )Nrd   r   )r   r   r   r4   r   r   r   r   r   r"   test_partial_fit_exception  s   
"rx  c                 C   s   t jd d }| dd}tt}|jt d | td | |d |jjdt jd fks.J |jjdks6J |ddggjdksCJ t	|jj
}|t |d  t|d   t	|jj
}|scJ ||t}t|t d S )Nr   rV   rd   r   r   rT   r'  )r-   rg   re   r   r   r4   r   r   r6   idr   r   r   r   r   )rj   thirdr   r   id1id2rr  r   r   r"   test_partial_fit_binary  s   

 
r}  c                 C   s   t jd d }| dd}tt}|jt d | td | |d |jjdt jd fks.J |jjdks6J |ddggjdksCJ t	|jj
}|t |d  t|d   t	|jj
}|scJ |d S )	Nr   rV   rd   r   r   rT   r   r  )r  rg   re   r   r  r4   r   r   r6   ry  r   )rj   rz  r   r   r{  r|  r   r   r"   test_partial_fit_multiclass  s   

 r~  c                 C   s   t jd d }| dt jd d}tt}|jt d | td | |d |jjdt jd fks2J |jjdks:J |t |d  t|d   |jjdt jd fksUJ |jjdks]J d S )Nr   rV   rd   )rl   r   r   rT   r   )r  rg   re   r   r  r4   r   r   )rj   rz  r   r   r   r   r"   #test_partial_fit_multiclass_average+  s   
 r  c                 C   s"   |  }| tt |tt d S r&   )r*   r  r  r4   r   r   r   r"   test_fit_then_partial_fit:  s   r  c                 C   s   t ttftttffD ]K\}}}| ddd|dd}||| ||}|j}t	
|}| dd|dd}tdD ]
}	|j|||d q7||}
|j|ksNJ t||
dd q
d S )Nrd   rU   F)rl   ry   r   r{   rz   rl   ry   r{   rz   r   r   )r-   r   r   r  r  r  r*   r6   r   re   r   r   r4   r   )rj   r   X_r   T_r   rr  tr   rt   y_pred2r   r   r"   "test_partial_fit_equal_fit_classifD  s   


r  c                 C   s   t jd}| dddd|d}|tt dt |ttkks#J | dddd|d}|tt dt |ttkks@J | dd	|d
}|tt dt |ttkks[J | dddd|d}|tt dt |ttkksxJ d S )NrT   rd   r   r   epsilon_insensitive)rl   r{   ry   r   r   rc   squared_epsilon_insensitivehuber)rl   r   r   r   )re   r   r   r*   r-   r   rC  r   )rj   r   r   r   r   r"   test_regression_lossesW  s>    r  c                 C   s   t | ttd d S )Nr   )r   r  r  r   r   r   r"   test_warm_start_multiclass}  s   r  c                 C   s\   | ddd}| tt t|dsJ dd t tD }| td d d df | d S )Nrd   Fr   r   c                 S   s   g | ]}d dg| qS )hamspamr   )r   rt   r   r   r"   r     s    z%test_multiple_fit.<locals>.<listcomp>rS   )r*   r-   r   r   r   fit_transform)rj   r   r.   r   r   r"   test_multiple_fit  s
    r  c                 C   sL   | dddd}| ddgddgddggg d |jd |jd ks$J d S )Nr   rU   Frt  r   rT   )r   rT   rU   )r*   r   r   r   r   r"   test_sgd_reg  s   "r  c              
   C   s   d}d}d}d}t jd}|j||fd}|j|d}t ||}| dd||d	d
d	dd}	|	|| t| ||||\}
}t|	j|
dd t	|	j
|dd d S )Nr|   rd   r   r   r   r   r   r   TrT   Fr  r   r   )re   r   r   r  ri   r*   rx   r   r   r   r   rj   rk   rl   r  r  r  r-   r	  r.   r   rp   rr   r   r   r"   $test_sgd_averaged_computed_correctly  s,   r  c              
   C   s   d}d}d}d}t jd}|j||fd}|j|d}t ||}| dd||d	d
d	dd}	|	|d t|d  d d  |d t|d   |	|t|d d  d d  |t|d d   t| ||||\}
}t|	j	|
dd t
|	jd |dd d S )Nr|   rd   r   r   r   r   r   r   TrT   Fr  rU   r   r   )re   r   r   r  ri   r4   intrx   r   r   r   r   r  r   r   r"   test_sgd_averaged_partial_fit  s.   44r  c              
   C   s   d}d}| dd||ddddd}t jd	 }|td t|d
  d d  t d t|d
   |tt|d
 d  d d  t t|d
 d   t| tt ||\}}t|j|dd t|j	|dd d S )Nr|   rd   r   r   TrT   Fr  r   rU   r   r   )
r   rg   r4   r   r  rx   r   r   r   r   )rj   rk   rl   r   r  rp   rr   r   r   r"   test_average_sparse  s$   
44r  c           	      C   s   d\}}d}t jd}t ||||d}d|  }| dddd	d
}||| |||}|dks7J d|  ||d  }| dddd	d
}||| |||}|dks_J d S )Nr   r   r   rT   rW   r   r   r   F)r   rl   r   r   ra   	re   r   r   linspacer  r   r*   scorerandn	rj   xminxmaxr  r  r-   r.   r   r  r   r   r"   test_sgd_least_squares_fit  s   r  c           	      C   s   d\}}d}t jd}t ||||d}d|  }| dddd	d
d}||| |||}|dks8J d|  ||d  }| dddd	d
d}||| |||}|dksaJ d S )Nr  r   r   rT   rW   r  rd   r   r   Fr   epsilonrl   r   r   ra   r  r  r   r   r"   test_sgd_epsilon_insensitive  s4   r  c           	      C   s   d\}}d}t jd}t ||||d}d|  }| ddddd	d
}||| |||}|dks8J d|  ||d  }| ddddd	d
}||| |||}|dksaJ d S )Nr  r   r   rT   rW   r  r   r   Fr  ra   r  r  r   r   r"   test_sgd_huber_fit3  s   r  c              	   C   s   d\}}t jd}|||}||}t ||}dD ]4}dD ]/}tj||dd}	|	|| | dd||dd	}
|
|| d
||f }t|	j	|
j	d|d q!qd S )N)r   r   r   )rd   r|   )rW   r]   rc   F)rl   r   r   r   2   )r   r   rl   r   r   zNcd and sgd did not converge to comparable results for alpha=%f and l1_ratio=%frU   )r   r   )
re   r   r   r  ri   r   
ElasticNetr*   r   r   )rj   r  r  r  r-   ground_truth_coefr.   rl   r   cdsgdr   r   r   r"   test_elasticnet_convergenceK  s4   
r  c                 C   s   t jd d }| dd}|t d | td |  |jjt jd fks&J |jjdks.J |ddggjdks;J t|jj}|t |d  t|d   t|jj}|s[J |d S )Nr   rV   rd   r   rT   r'  )	r-   rg   r4   r   r   r   r   ry  r   )rj   rz  r   r{  r|  r   r   r"   test_partial_fitm  s   
r  c                 C   s   | ddd|dd}| tt |t}|j}| dd|dd}tdD ]}|tt q#|t}|j|ks8J t||dd d S )Nrd   rU   F)rl   r   ry   r{   rz   r  r   )	r*   r-   r   r   r   r   r   r4   r   )rj   r   r   rr  r  rt   r  r   r   r"   test_partial_fit_equal_fit~  s   

r  c                 C   s0   | dd}|j dd |jd d dksJ d S )Nr\   )r  r   r  rT   )r   r+  r   r   r   r"   test_loss_function_epsilon  s   
r  c                 C   s  |d u rt |jd }n|}t |jd }|}d| }	d}
d}| tkr'd}t|D ]T\}}t ||}||	7 }|dkr@d}nd}|tdd|| d  9 }||| |  7 }|	|||   | 7 }	||9 }||7 }||d  }|
|9 }
|
|	7 }
|
|d  }
q+|d|
 fS )NrT   rb   rc   rd   rS   r   rU   )re   rf   rg   rQ   rh   ri   max)rj   r-   rk   rv  r~   r   coefr"  offsetrq   rr   rs   rt   ru   rv   rw   r   r   r"   asgd_oneclass  s4   r  c                 C   s   | ddd|d}| | | ddd|d}|j ||j |j d | dddd|d}| | |j|jks9J t|j|j |jdd	 | | |j|jksSJ t|j|j d S )
NrW   rd   F)rv  ry   rz   r{   r   r~   r   T)rv  ry   rz   r   r{   rv  )r*   r   r   r   r   r   r   )rj   r-   r   r   r   r   r   r   r"   _test_warm_start_oneclass  s   


r  c                 C   s   t | t| d S r&   )r  r-   r   r   r   r"   test_warm_start_oneclass  s   r  c                 C   sN   | dd}t |}|jdd |t | dd}|t t|j|j d S )NrW   r  r   )r   r   r*   r-   r   r   r   r   r   r"   test_clone_oneclass  s   



r  c                 C   s   t jd d }| dd}|t d |  |jjt jd fks!J |jjdks)J |ddggjdks6J |j}|t |d   |j|u sIJ tt |t d d df  W d    d S 1 sew   Y  d S )Nr   rV   r   r  rT   r'  )	r-   rg   r4   r   r   r   r   r   r   )rj   rz  r   previous_coefsr   r   r"   test_partial_fit_oneclass  s   
"r  c           	      C   s   | ddd|dd}| t |t}|j}|j}|j}| ddd|dd}tdD ]}|t q)|t}|j|ks=J t	|| t	|j| t	|j| d S )N皙?rU   rd   F)rv  r   ry   r{   rz   rT   )rv  ry   r   r{   rz   )
r*   r-   r6   r   r   r   r   r   r4   r   )	rj   r   r   y_scoresr  r  r  r   	y_scores2r   r   r"   #test_partial_fit_equal_fit_oneclass   s   



r  c                 C   s   d}d}| dd||ddd}| dd||ddd}| t | t t| t|||j |jd	\}}t|j |  t|j| d S )
Nr|   r  r   r   rU   F)r   r{   ry   rv  r   rz   rT   r  )r*   r-   r  r   r   r   r   )rj   ry   rv  r   r   r"  average_offsetr   r   r"   *test_late_onset_averaging_reached_oneclass  s(   
	

r  c           
   	   C   sz   d}d}d}d}t jd}|j||fd}| d||dd	dd
d}|| t| |||\}}	t|j| t|j|	 d S )Nr|   r  r   r   r   r   r   TrT   Fr{   ry   rv  r   r   r   rz   )	re   r   r   r  r*   r  r   r   r   
rj   rk   rv  r  r  r  r-   r   r"  r  r   r   r"   -test_sgd_averaged_computed_correctly_oneclass6  s&   

r  c           
   	   C   s   d}d}d}d}t jd}|j||fd}| d||dd	dd
d}||d t|d  d d   ||t|d d  d d   t| |||\}}	t|j| t|j	|	 d S )Nr|   r  r   r   r   r   r   TrT   Fr  rU   )
re   r   r   r  r4   r  r  r   r   r   r  r   r   r"   &test_sgd_averaged_partial_fit_oneclassQ  s(   "
"r  c              	   C   s   d}d}| d||ddddd}t jd }|t d t|d	   |t t|d	 d   t| t ||\}}t|j| t|j| d S )
Nr|   rd   r   TrT   Fr  r   rU   )r   rg   r4   r  r  r   r   r   )rj   rk   rv  r   r  r"  r  r   r   r"   test_average_sparse_oneclassm  s"   

r  c                  C   s   t ddgddgddgg} t ddgddgg}tdddddd}||  t|jt d	d
g |jd dks;J ||}t|t ddg |||j }t||| |	|}t
|t ddg d S )NrR   rS   rT   rW   rU   r   F)rv  ry   r{   rz   r   g      g      ?r   rX   g      g      ?)re   r   rH   r*   r   r   r   score_samplesr6   r   r   )X_trainX_testr   scoresdecr  r   r   r"   test_sgd_oneclass  s   



r  c                  C   s*  d} d}d}t j|}d|dd }t j|d |d f }d|dd }t j|d |d f }t|d| d	}|| ||}||	d
d}	d}
t
||d}t| dd|
|d d}t||}|| ||}||	d
d}t ||kdksJ t t |	|fd }|dksJ d S )Nr  r   r   r9    rU   r   rbf)gammakernelrv  rT   rS      )r  r   T)rv  rz   r   r   r   r   ra   r5  r\   )re   r   r   r  r_r   r*   r   r6   r  r	   rH   r   rC  corrcoefrk  )rv  r  r   r  r-   r  r  r   y_pred_ocsvm	dec_ocsvmr   	transformclf_sgdpipe_sgdy_pred_sgdocsvmdec_sgdocsvmr  r   r   r"   test_ocsvm_vs_sgdocsvm  s:   




r  c                  C   s   t jddddd\} }tddd dd	d
d| |}tdddd
d d| |}t|j|j tddd ddd
d| |}tdddd
d d| |}t|j|j d S )Nr   r   r   i  )r  r  n_informativer   r|   r   rg  gA?r   )rl   r   r   r   r   r   r   )rl   r   r   r   r   g|=r   )r   make_classificationrK   r*   r   r   )r-   r.   est_enest_l1est_l2r   r   r"   test_l1_ratio  sF   


r  c            	   	   C   sV  t jdd t jd} d}d}| j||fd}|d d d df  d9  < t | s0J t |}t | s?J | j|d}t 	||d	k
t j}tt |dd
g tdddd}||| t |j srJ d}tjt|d ||| W d    n1 sw   Y  W d    d S W d    d S 1 sw   Y  d S )Nraiser?  r   r   r   r   rU   gu <7~rb   rT   r   r   r  )rl   r   r   zwFloating-point under-/overflow occurred at epoch #.* Scaling input data with StandardScaler or MinMaxScaler might help.r   )re   r<  r   r   r  isfiniter?  r   r  ri   astypeint32r   r   rK   r*   r   r   r   r   )	r  r  r  r-   X_scaledground_truthr.   model	msg_regxpr   r   r"   test_underflow_or_overlow  s.   !"r  c                  C   sn   t ddddddddd d		} tjd
d | tjtj W d    n1 s&w   Y  t| j	 s5J d S )Nr   r   Tr   r9  rd   r|   r   )	r   r   rz   r   r   rl   ry   r   r   r  r  )
rK   re   r<  r*   r   r   r   r  r   r?  )r  r   r   r"   'test_numerical_stability_large_gradient  s   r  r   )r   r   r   c              	   C   sj   t ddd| dd dd}tjdd |tjtj W d    n1 s$w   Y  t|jt	|j d S )	Ng     j@r   r   Frg  )rl   r{   ry   r   rz   r   r   r  r  )
rK   re   r<  r*   r   r   r   r   r   
zeros_like)r   r  r   r   r"   test_large_regularization!  s   	r  c                  C   s  t  tj} tjdk}d}td d|d}|| | ||jks"J d}tdd|d}|| | ||jks8J |jdks?J tdd|d}|| | |j|jksTJ |jdks[J tdd	dd
}d}tj	t
|d || | W d    n1 s|w   Y  |jdksJ d S )NrT   r   r   )r   r   r   rI  r   r   rV   r|   )r   r   r   zhMaximum number of iteration reached before convergence. Consider increasing max_iter to improve the fit.r   )r   r  r   r   r   rK   r*   r   r   warnsr   )r-   r.   r   model_0model_1model_2model_3warning_messager   r   r"   test_tol_parameter3  s,   
r  c                 C   s:   |D ]\}}}}t | ||| t | ||| qd S r&   )r   py_losspy_dloss)loss_functioncasesrv   r.   expected_lossexpected_dlossr   r   r"   _test_loss_commonW  s   r  c                  C   s<   t d} g d}t| | t d} g d}t| | d S )Nrc   ))g?rc   rb   rb          r   rb   rb   )rc   rc   rb   r   )r   r   rb   rc   )rW   rc   rW   r   )r   r         @rc   )rX   r   rW   rc   )rb   rc   rT   r   rb   )rc   rc   rb   rb   )r8  r   rb   rb   )rb   rc   rb   r   )rb   r   rb   rc   )rW   r   rW   rc   )r   r   r   rc   )rX   rc   rW   r   )r   rc   rc   r   )sgd_fastHinger  r   r  r   r   r"   test_loss_hinge_  s   


r   c                  C       t d} g d}t| | d S )Nrc   )r  r  )rc   r         @r  r   rc   r        )rW   rc   g      ?r   rW   r   g      @r  )r  SquaredHinger  r  r   r   r"   test_gradient_squared_hinge  s   
	r  c                  C   s   t  } g d}t| | d S )N)r  )r   r   rb   rb   )r   rc   rb   rb   )rb   rc   rc   r  r  r  )r  rc      r  )g      rc      r  )r  ModifiedHuberr  r  r   r   r"   test_loss_modified_huber  s   r  c                  C   r  )Nr   )rb   rb   rb   rb   r   rb   rb   rb   gffffff r  rb   rb   gffffff@r  rb   rb   )皙@r   r   rc   )r   r   333333@rc   )r   r  r   r   )r  rc   r  r   )r  EpsilonInsensitiver  r  r   r   r"   test_loss_epsilon_insensitive     
r  c                  C   r  )Nr   )r  r  r  r  )r  r   rd   r:  )r   r   R @g333333@)r   r  rd   gɿ)r  rc   r  g333333)r  SquaredEpsilonInsensitiver  r  r   r   r"   %test_loss_squared_epsilon_insensitive  r  r  c               	   C   sf   t dddddddd} | tjtj | j| jksJ | j| jd k s%J | tjtjd	ks1J d S )
Nr|   r   Tr   r   rU   )rl   r   r   r   r   r   r%  r   r]   )rK   r*   r   r   r   r   r   r  )r   r   r   r"   0test_multi_thread_multi_class_and_early_stopping  s   	r  c                  C   s\   t dddg dd} tdddd	d
}t|| ddd	d}|tjtj |jdks,J d S )NrJ  r  )r   r   r  )rl   r   rd   r   Tr   )r   r   r   r   r   rU   )n_iterr%  r   r]   )	re   logspacerK   r   r*   r   r   r   best_score_)
param_gridr   searchr   r   r"   -test_multi_core_gridsearch_and_early_stopping  s   r  backend)lokymultiprocessing	threadingc                 C   s   t jd}tjdddd|d}|dd}tdd	dd
}||| tdddd
}tj| d ||| W d    n1 sAw   Y  t	|j
|j
 d S )Nr   r  rI  g{Gz?csr)densityr,  r   r   r   rT   )r   r%  r   rJ  )r   )re   r   r   r'   choicerK   r*   joblibparallel_backendr   r   )r   r   r-   r.   clf_sequentialclf_parallelr   r   r"   'test_SGDClassifier_fit_for_all_backends  s   r+  c                 C   sN  | t jkrtj|d\}}ntj|d\}}| |dd}tt |||j	}|j
dks0J W d    n1 s:w   Y  | |dd}tt |||j	}|j
dksYJ W d    n1 scw   Y  t|| | |d dd}tt |||j	}|j
dksJ W d    n1 sw   Y  t||  dksJ d S )N)r   rT   )r   r   rc   )r   rA   r   make_regressionr  r   r  r   r*   r   r   r   re   absr  )r   global_random_seedr-   r.   estcoef_same_seed_acoef_same_seed_bcoef_other_seedr   r   r"   test_sgd_random_state  s(   

r3  c           	      C   s   t jt j}}|jd }d}tjddd|d}ttjd}| 	td| |
|| |jd d	d
 \}}|jd t|| ksBJ |jd t|| ksOJ dS )ziTest that data passed to validation callback correctly subsets.

    Non-regression test for #23255.
    r   r:  Tr|   r   )r   r   r   r   )side_effect_ValidationScoreCallbackrT   rV   N)r   r   r   rg   r   rK   r   r   r5  setattrr*   	call_argsr  )	monkeypatchr-   r   r  r   r   mockX_valy_valr   r   r"   &test_validation_mask_correctly_subsets-  s   
r<  c                  C   st   t jt j} }t|}d}tjd|dd}d}tjt	|d |j
| ||d W d    d S 1 s3w   Y  d S )Nr   Tr   )r   r   r   z\The sample weights for validation set are all zero, consider using a different random state.r   r`  )r   r   r   re   r  r   rK   r   r   r   r*   )r-   r   ra  r   r   error_messager   r   r"   (test_sgd_error_on_zero_validation_weightE  s   
"r>  c                 C   s   | dd tt dS )z!non-regression test for gh #25249rT   )verboseN)r*   r-   r   )r   r   r   r"   test_sgd_verboseX  s   r@  SGDEstimator	data_typec                 C   s>   t |}tjt|d}|  }||| |jj|ksJ d S )Ndtype)r-   r  re   r   r   r*   r   rD  )rA  rB  _X_Y	sgd_modelr   r   r"   test_sgd_dtype_match^  s
   
rH  c                 C   sz   t jtjd}tjttjd}t jtjd}tjttjd}| dd}||| | dd}||| t|j	|j	 d S )NrC  r   )r   )
r-   r  re   float64r   r   float32r*   r   r   )rA  X_64Y_64X_32Y_32sgd_64sgd_32r   r   r"   test_sgd_numerical_consistencyr  s   

rQ  c                  C   s   t  } t| jdksJ dS )z}Check that SGDOneClassSVM has the correct estimator type.

    Non-regression test for if the mixin was not on the left.
    outlier_detectorN)rH   r   estimator_type)	sgd_ocsvmr   r   r"   %test_sgd_one_class_svm_estimator_type  s   rU  )Nrb   )rQ  unittest.mockr   r'  numpyre   r   scipy.sparsesparser'   sklearnr   r   r   sklearn.baser   r   sklearn.exceptionsr   sklearn.kernel_approximationr	   sklearn.linear_modelr
   r  r   sklearn.model_selectionr   r   r   sklearn.pipeliner   sklearn.preprocessingr   r   r   r   sklearn.svmr   sklearn.utilsr   sklearn.utils._testingr   r   r   r   r#   rK   r$   rA   r?   rH   rG   rN   rP   rQ   r   r-   r   r   r   r  r  r  r  r   r   rL  rN  	load_irisr   r   r   true_result5rx   r   markparametrizer   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rf   r   r   r   r
  r  r  r  r  r#  r$  r&  r(  r4  rG  rV  r[  r^  r_  rf  rs  ru  rw  rx  r}  r~  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r   r  r  r  r  r  r  r+  r3  r<  r>  r@  rJ  rI  rH  rQ  rU  r   r   r   r"   <module>   s   
..	"




+


)





	
	




	
!
	










!
G
$




.






	
%



 
 


#

!

&



 


)#&
$!
#
#
