a
    yfw                     @   sv  d Z ddlZddlZddlZddlZddlZddlZddlmZmZ ddl	m	Z	m
Z
 ddlmZ ddlZddlZddlmZ ddlmZmZ ddlmZmZ dd	lmZmZ dd
lmZmZ ddlmZmZm Z m!Z!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z'm(Z( ddl)m*Z* ddl+m,Z,m-Z-m.Z.m/Z/m0Z0 ddl1m2Z2m3Z3 ddl4m5Z5 ddl6m7Z7m8Z8m9Z9m:Z:m;Z;m<Z<m=Z=m>Z>m?Z?m@Z@ G dd dZAdS )zz
Train a model on a dataset.

Usage:
    $ yolo mode=train model=yolov8n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
    N)copydeepcopy)datetime	timedelta)Path)distributed)nnoptim)get_cfgget_save_dir)check_cls_datasetcheck_det_dataset)attempt_load_one_weightattempt_load_weights)DEFAULT_CFG
LOCAL_RANKLOGGERRANKTQDM__version__	callbacks	clean_urlcolorstremojis	yaml_save)check_train_batch_size)	check_amp
check_filecheck_imgszcheck_model_file_from_stem
print_args)ddp_cleanupgenerate_ddp_command)get_latest_run)
	TORCH_2_4EarlyStoppingModelEMAautocast$convert_optimizer_state_dict_to_fp16
init_seeds	one_cycleselect_devicestrip_optimizertorch_distributed_zero_firstc                   @   sX  e Zd ZdZeddfddZedddZeddd	Zedd
dZ	dd Z
dd Zdd Zdd ZdVddZdd Zdd Zdd Zdd Zdd  Zd!d" Zd#d$ Zd%d& Zd'd( ZdWd*d+Zd,d- ZdXd1d2ZdYd3d4ZdZd5d6Zd7d8 Zd9d: Zd;d< Zd=d> Z d?d@ Z!dAdB Z"dCdD Z#d[dEdFZ$dGdH Z%dIdJ Z&dKdL Z'dMdN Z(d\dTdUZ)dS )]BaseTraineraA  
    A base class for creating trainers.

    Attributes:
        args (SimpleNamespace): Configuration for the trainer.
        validator (BaseValidator): Validator instance.
        model (nn.Module): Model instance.
        callbacks (defaultdict): Dictionary of callbacks.
        save_dir (Path): Directory to save results.
        wdir (Path): Directory to save weights.
        last (Path): Path to the last checkpoint.
        best (Path): Path to the best checkpoint.
        save_period (int): Save checkpoint every x epochs (disabled if < 1).
        batch_size (int): Batch size for training.
        epochs (int): Number of epochs to train for.
        start_epoch (int): Starting epoch for training.
        device (torch.device): Device to use for training.
        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
        scaler (amp.GradScaler): Gradient scaler for AMP.
        data (str): Path to data.
        trainset (torch.utils.data.Dataset): Training dataset.
        testset (torch.utils.data.Dataset): Testing dataset.
        ema (nn.Module): EMA (Exponential Moving Average) of the model.
        resume (bool): Resume training from a checkpoint.
        lf (nn.Module): Loss function.
        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
        best_fitness (float): The best fitness value achieved.
        fitness (float): Current fitness value.
        loss (float): Current loss value.
        tloss (float): Total loss value.
        loss_names (list): List of loss names.
        csv (Path): Path to results CSV file.
    Nc                 C   s  t ||| _| | t| jj| jj| _d| _d| _i | _t	| jj
d t | jjd t| j| _| jj| j_| jd | _tdv r| jjddd t| j| j_t| jd t| j | jd	 | jd
  | _| _| jj| _| jj| _| jj| _d| _tdkrtt| j | jjdv r$d| j_t| jj| _t t!  | " \| _#| _$W d   n1 sb0    Y  d| _%d| _&d| _'d| _(d| _)d| _*d| _+dg| _,| jd | _-g d| _.d| _/|pt01 | _0tdv rt02|  dS )z
        Initializes the BaseTrainer class.

        Args:
            cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
            overrides (dict, optional): Configuration overrides. Defaults to None.
        N   )deterministicweights   r   T)parentsexist_okz	args.yamlzlast.ptzbest.ptr   r3   >   cpumpsZLosszresults.csv)r   r/      )3r
   argscheck_resumer+   devicebatch	validatormetricsplotsr)   seedr   r0   r   save_dirnamewdirmkdirstrr   varslastbestsave_period
batch_sizeepochsstart_epochr    typeworkersr   modelr-   r   get_datasettrainsettestsetemalf	schedulerbest_fitnessfitnesslosstlossZ
loss_namescsvplot_idxZhub_sessionr   Zget_default_callbacksZadd_integration_callbacks)selfcfg	overrides
_callbacks r`   V/var/www/html/django/DPS/env/lib/python3.9/site-packages/ultralytics/engine/trainer.py__init__]   sP    





0

zBaseTrainer.__init__)eventc                 C   s   | j | | dS )zAppends the given callback.N)r   appendr\   rc   callbackr`   r`   ra   add_callback   s    zBaseTrainer.add_callbackc                 C   s   |g| j |< dS )z9Overrides the existing callbacks with the given callback.N)r   re   r`   r`   ra   set_callback   s    zBaseTrainer.set_callbackc                 C   s    | j |g D ]}||  qdS )z>Run all existing callbacks associated with a particular event.N)r   getre   r`   r`   ra   run_callbacks   s    zBaseTrainer.run_callbacksc              
   C   s`  t | jjtr.t| jjr.t| jjd}nFt | jjttfrNt| jj}n&| jjdv r`d}ntj	
 rpd}nd}|dkrRdtjvrR| jjrtd d| j_| jjdk rtd	 d
| j_t|| \}}znz0ttd dd|  tj|dd W n* ty, } z|W Y d}~n
d}~0 0 W t| t| nt| t| 0 n
| | dS )zIAllow device='', device=None on Multi-GPU systems to default to device=0.,>   r6   r7   r   r/   r   uX   WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'F      ?uj   WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting default 'batch=16'   zDDP:z debug command  T)checkN)
isinstancer9   r;   rE   lensplittuplelisttorchcudaZis_availableosenvironrectr   warningr<   r"   infor   join
subprocessrun	Exceptionr!   	_do_train)r\   
world_sizecmdfileer`   r`   ra   train   s4    

"zBaseTrainer.trainc                    sF    j jrtd j j j _n fdd _tjj j	 jd _
dS )z,Initialize training learning rate scheduler.r/   c                    s(   t d|  j  dd jj   jj S )Nr/   r   rl   )maxrK   r9   lrf)xr\   r`   ra   <lambda>       z.BaseTrainer._setup_scheduler.<locals>.<lambda>)Z	lr_lambdaN)r9   Zcos_lrr*   r   rK   rT   r	   Zlr_schedulerZLambdaLR	optimizerrU   r   r`   r   ra   _setup_scheduler   s    zBaseTrainer._setup_schedulerc                 C   sL   t jt t dt| _dtjd< tjt	 r4dndt
ddt|d d	S )
zIInitializes and sets the DistributedDataParallel parameters for training.rv   1ZTORCH_NCCL_BLOCKING_WAITZncclZglooi0*  )seconds)backendtimeoutrankr   N)ru   rv   Z
set_devicer   r;   rw   rx   distZinit_process_groupZis_nccl_availabler   )r\   r   r`   r`   ra   
_setup_ddp   s    
zBaseTrainer._setup_ddpc                    s  |  d |  }| j| j| _|   t| jjt	r@| jjnt| jjt
rZt| jjng }dg}dd |D | }| j D ]^\ }t fdd|D rtd  d d	|_q|js|jjrtd
  d d|_qt| jj| j| _| jr2tdv r2tj }tjt| j| jd| _|t_tdkrV|dkrVtj| jdd t| j| _trztjjd| jdntj jj| jd| _!|dkrt"j#j$| jtgdd| _t%t
t&| jdr| jj'% ndd}t(| jj)||dd| j_)|| _'| j*dk r0tdkr0t+| j| jj)| j| j*d | j_,| _*| j*t%|d }	| j-| j.|	t/dd| _0tdv r| j-| j1| jj2dkrz|	n|	d ddd| _3| 4 | _5| j5j6j7| j8dd }
t9t:|
dgt;|
 | _6t<| j| _=| jj>r| ?  t%t@| jjA| j* d| _B| jjC| j* | jB | jjA }tDEt;| j0jFt%| j*| jjA | jG }| jH| j| jjI| jjJ| jjK||d | _I| L  tM| jjNd!d	 | _O| _P| Q| | jRd | jS_T|  d" d#S )$z9Builds dataloaders and optimizer on correct rank process.Zon_pretrain_routine_startz.dflc                 S   s   g | ]}d | dqS )zmodel..r`   .0r   r`   r`   ra   
<listcomp>   r   z,BaseTrainer._setup_train.<locals>.<listcomp>c                 3   s   | ]}| v V  qd S )Nr`   r   kr`   ra   	<genexpr>   r   z+BaseTrainer._setup_train.<locals>.<genexpr>zFreezing layer ''Fu>   WARNING ⚠️ setting 'requires_grad=True' for frozen layer 'zE'. See ultralytics.engine.trainer for customization of frozen layers.Tr2   )r;   r3   r/   r   )srcrv   )enabled)Z
device_idsZfind_unused_parametersstride    )r   floorZmax_dim)rO   imgszampr<   r   )rJ   r   modeobbr8   val)prefix)rO   rB   lrmomentumdecay
iterations)patienceZon_pretrain_routine_endN)Urj   setup_modelrO   tor;   set_model_attributesrp   r9   freezert   intrangenamed_parametersanyr   r{   Zrequires_gradZdtypeZis_floating_pointru   Ztensorr   r   r   Zdefault_callbacksr   r   r   	broadcastboolr$   Z
GradScalerrv   scalerr   parallelZDistributedDataParallelr   hasattrr   r   r   rJ   r   r<   get_dataloaderrQ   r   train_loaderrR   taskZtest_loaderget_validatorr=   r>   keyslabel_loss_itemsdictziprq   r&   rS   r?   plot_training_labelsroundnbs
accumulateweight_decaymathceildatasetrK   build_optimizerr   lr0r   r   r%   r   stopperstopresume_trainingrL   rU   
last_epoch)r\   r   ckptZfreeze_listZalways_freeze_namesZfreeze_layer_namesvZcallbacks_backupgsrJ   Zmetric_keysr   r   r`   r   ra   _setup_train   s    



(
(
 

(	
zBaseTrainer._setup_trainr/   c                 C   s0  |dkr|  | | | t| j}| jjdkrHtt| jj| dnd}d}d| _t		 | _
t		 | _| d td| jj d| jj d	| jj|pd  d
td| j d	| jj	r| jj	 dn
| j d  | jjr| j| jj | }| j||d |d g | j}| j  || _| d t $ td | j  W d   n1 s`0    Y  | j !  t"dkr| jj#$| t%| j}|| j| jj kr| &  | j'  t"dv rt| (  t)t%| j|d}d| _*|D ]\}}	| d |||  }
|
|krd|g}tdt+t,-|
|d| jj.| j/ g | _0t%| jj1D ]h\}}t,-|
||dkr|| jj2nd|d | 3| g|d< d|v rZt,-|
|| jj4| jj5g|d< qZt6| j7p | 8|	}	|  |	\| _9| _:t"dkr|  j9|9  _9| j*dur(| j*| | j: |d  n| j:| _*W d   n1 sF0    Y  | j;<| j9=  |
| | j0kr| >  |
}| jj	rt		 | j | jj	d k| _?t"dkrt"dkr| j?ndg}t@A|d |d | _?| j?r qt"dv rt| j*jBr| j*jBd nd}|Cddd|   |d  d| j | D ddg|dkrT| j*ntEF| j*d|	d jBd |	d  jBd R   | d! | jjGr|
| jv r| H|	|
 | d" qd#d$ t%| jj1D | _I| d% t"dv r|d | jk}| jJjK| j g d&d' | jjLs2|s2| jMjNs2| j?rB| O \| _P| _Q| jRi | S| j*| jP| jId( |  j?| M|d | jQp|O  _?| jj	r|  j?t		 | j | jj	d kO  _?| jjTs|r| U  | d) t		 }|| j
 | _|| _
| jj	rZ|| j || j d  }tVW| jj	d |  | _| j_| X  | j| j_Y|  j?|| jkO  _?| d* | Z  t"dkrt"dkr| j?ndg}t@A|d |d | _?| j?rq|d7 }qt"dv rtd+|| j d  d,t		 | j d d-d. | [  | jjGr| \  | d/ | Z  | d0 dS )1z=Train completed, evaluate and plot if specified by arguments.r/   r   d   r3   NZon_train_startzImage sizes z train, z val
Using z' dataloader workers
Logging results to boldz
Starting training for z	 hours...z
 epochs...r8   Zon_train_epoch_startignorer2   )totalZon_train_batch_start        Z
initial_lrr   r   i  z%11s%11sz%11.4g/z.3gGclsZimgZon_batch_endZon_train_batch_endc                 S   s    i | ]\}}d | |d qS )zlr/pgr   r`   )r   Zirr   r`   r`   ra   
<dictcomp>  r   z)BaseTrainer._do_train.<locals>.<dictcomp>Zon_train_epoch_end)yamlncr9   namesr   Zclass_weights)include)r>   Zon_model_saveon_fit_epoch_end
z epochs completed in z.3fz hours.Zon_train_endZteardown)]r   r   rq   r   r9   Zwarmup_epochsr   r   Z
epoch_timetimeZepoch_time_startZtrain_time_startrj   r   r{   r   Znum_workersr   rA   rK   close_mosaicr[   extendrL   r   	zero_gradepochwarningscatch_warningssimplefilterrU   steprO   r   r   ZsamplerZ	set_epoch	enumerate_close_dataloader_mosaicresetprogress_stringr   rY   r   npZinterpr   rJ   r   Zparam_groupswarmup_bias_lrrT   Zwarmup_momentumr   r'   r   preprocess_batchrX   
loss_itemsr   scaleZbackwardoptimizer_stepr   r   Zbroadcast_object_listshapeset_description_get_memoryru   Z	unsqueezer?   plot_training_samplesr   rS   Zupdate_attrr   r   Zpossible_stopvalidater>   rW   save_metricsr   save
save_modelr   r   r   r   _clear_memory
final_evalplot_metrics)r\   r   nbnwZlast_opt_stepZbase_idxr   Zpbarir<   nixijr   Zbroadcast_listZloss_lengthZfinal_epochtZmean_epoch_timer`   r`   ra   r   C  s   


&



 




*






,*
"

($







"$"
$

 






zBaseTrainer._do_trainc                 C   s<   | j jdkrtj }n| j jdkr*d}n
tj }|d S )z)Get accelerator memory utilization in GB.r7   r6   r   g    eA)r;   rM   ru   r7   Zdriver_allocated_memoryrv   Zmemory_reserved)r\   Zmemoryr`   r`   ra   r     s    
zBaseTrainer._get_memoryc                 C   s>   t   | jjdkr tj  n| jjdkr0dS tj  dS )z0Clear accelerator memory on different platforms.r7   r6   N)gcZcollectr;   rM   ru   r7   Zempty_cacherv   r   r`   r`   ra   r     s    zBaseTrainer._clear_memoryc                 C   s*   ddl }dd || jjdd D S )z*Read results.csv into a dict using pandas.r   Nc                 S   s   i | ]\}}|  |qS r`   )stripr   r   r   r`   r`   ra   r     r   z0BaseTrainer.read_results_csv.<locals>.<dictcomp>rt   )Zorient)ZpandasZread_csvrZ   to_dictitems)r\   pdr`   r`   ra   read_results_csv  s    zBaseTrainer.read_results_csvc                 C   s   ddl }| }t| j| jdt| jj | jj	t
t| j t| ji | jd| ji|  t  tddd| | }| j| | j| jkr| j| | jdkr| j| j dkr| jd| j d | dS )	z9Save model training checkpoints with additional metadata.r   NrW   z*AGPL-3.0 (https://ultralytics.com/license)zhttps://docs.ultralytics.com)r   rV   rO   rS   updatesr   Z
train_argsZtrain_metricstrain_resultsdateversionlicensedocsr   .pt)ioBytesIOru   r   r   rV   r   rS   Zhalfr  r(   r   
state_dictrF   r9   r>   rW   r  r   now	isoformatr   getvaluerG   write_bytesrH   rI   rC   )r\   r  bufferZserialized_ckptr`   r`   ra   r     s2    
zBaseTrainer.save_modelc              
   C   s   z`| j jdkrt| j j}nB| j jdd dv s>| j jdv r^t| j j}d|v r^|d | j _W nF ty } z.ttdt	| j j d| |W Y d	}~n
d	}~0 0 || _|d
 |
dp|
dfS )zz
        Get train, val path from data dict if it exists.

        Returns None if data format is not recognized.
        Zclassifyr   r3   >   Zymlr   >   segmentr   detectZposeZ	yaml_filez	Dataset 'u   ' error ❌ Nr   r   test)r9   r   r   datarr   r   r   RuntimeErrorr   r   ri   )r\   r!  r   r`   r`   ra   rP      s    "8zBaseTrainer.get_datasetc                 C   s   t | jtjjrdS | jd }}d}t| jdrJt| j\}}|j}n"t | j	j
ttfrlt| j	j
\}}| j||tdkd| _|S )z(Load/create/download model for any task.Nr  r3   )r]   r1   verbose)rp   rO   ru   r   ModulerE   endswithr   r   r9   Z
pretrainedr   	get_modelr   )r\   r]   r1   r   _r`   r`   ra   r   7  s    zBaseTrainer.setup_modelc                 C   s`   | j | j tjjj| j dd | j 	| j | j 
  | j  | jr\| j
| j dS )zVPerform a single step of the training optimizer with gradient clipping and EMA update.g      $@)Zmax_normN)r   Zunscale_r   ru   r   utilsZclip_grad_norm_rO   
parametersr   updater   rS   r   r`   r`   ra   r   F  s    

zBaseTrainer.optimizer_stepc                 C   s   |S )zRAllows custom preprocessing model inputs and ground truths depending on task type.r`   )r\   r<   r`   r`   ra   r   P  s    zBaseTrainer.preprocess_batchc                 C   sD   |  | }|d| j    }| jr6| j|k r<|| _||fS )z
        Runs validation on test set using self.validator.

        The returned dict is expected to contain "fitness" key.
        rW   )r=   poprX   detachr6   numpyrV   )r\   r>   rW   r`   r`   ra   r   T  s
    
zBaseTrainer.validateTc                 C   s   t ddS )z>Get model and raise NotImplementedError for loading cfg files.z3This task trainer doesn't support loading cfg filesNNotImplementedError)r\   r]   r1   r#  r`   r`   ra   r&  `  s    zBaseTrainer.get_modelc                 C   s   t ddS )zHReturns a NotImplementedError when the get_validator function is called.z1get_validator function not implemented in trainerNr.  r   r`   r`   ra   r   d  s    zBaseTrainer.get_validatorrm   r   r   c                 C   s   t ddS )z6Returns dataloader derived from torch.data.Dataloader.z2get_dataloader function not implemented in trainerNr.  )r\   Zdataset_pathrJ   r   r   r`   r`   ra   r   h  s    zBaseTrainer.get_dataloaderc                 C   s   t ddS )zBuild dataset.z1build_dataset function not implemented in trainerNr.  )r\   Zimg_pathr   r<   r`   r`   ra   build_datasetl  s    zBaseTrainer.build_datasetc                 C   s   |durd|iS dgS )z
        Returns a loss dict with labelled training loss items tensor.

        Note:
            This is not needed for classification but necessary for segmentation & detection
        NrX   r`   )r\   r   r   r`   r`   ra   r   p  s    zBaseTrainer.label_loss_itemsc                 C   s   | j d | j_dS )z2To set or update model parameters before training.r   N)r!  rO   r   r   r`   r`   ra   r   y  s    z BaseTrainer.set_model_attributesc                 C   s   dS )z.Builds target tensors for training YOLO model.Nr`   )r\   predstargetsr`   r`   ra   build_targets}  s    zBaseTrainer.build_targetsc                 C   s   dS )z.Returns a string describing training progress. r`   r   r`   r`   ra   r     s    zBaseTrainer.progress_stringc                 C   s   dS )z,Plots training samples during YOLO training.Nr`   )r\   r<   r  r`   r`   ra   r     s    z!BaseTrainer.plot_training_samplesc                 C   s   dS )z%Plots training labels for YOLO model.Nr`   r   r`   r`   ra   r     s    z BaseTrainer.plot_training_labelsc                 C   s   t | t |  }}t|d }| j r4dnd| tdg|  dd }t| jd@}|	|d| t| j
d g|  d d  W d	   n1 s0    Y  d	S )
z%Saves training metrics to a CSV file.r/   r4  z%23s,r   rk   r   az%23.5g,N)rt   r   valuesrq   rZ   existsrs   rstripopenwriter   )r\   r>   r   valsnsfr`   r`   ra   r     s
    .zBaseTrainer.save_metricsc                 C   s   dS )z"Plot and display metrics visually.Nr`   r   r`   r`   ra   r     s    zBaseTrainer.plot_metricsc                 C   s    t |}|t d| j|< dS )z3Registers plots (e.g. to be consumed in callbacks).)r!  	timestampN)r   r   r?   )r\   rB   r!  pathr`   r`   ra   on_plot  s    zBaseTrainer.on_plotc                 C   s   i }| j | jfD ]}| r|| j u r0t|}q|| ju rd}t|||v rV||| indd td| d | jj| jj_| j|d| _	| j	
dd | d qdS )	zIPerforms final evaluation and validation for object detection YOLO model.r  N)r  z
Validating z...)rO   rW   r   )rG   rH   r7  r,   r   r{   r9   r?   r=   r>   r+  rj   )r\   r   r>  r   r`   r`   ra   r     s    


 zBaseTrainer.final_evalc              
   C   s   | j j}|rzt|ttfo&t| }t|r6t|nt }t|j }t|d  sd| j j	|d< d}t
|| _ t| | j _| j _dD ]}||v rt| j |||  qW n. ty } ztd|W Y d}~n
d}~0 0 || _dS )zCCheck if resume checkpoint exists and update arguments accordingly.r!  T)r   r<   r;   r   zzResume checkpoint not found. Please pass a valid checkpoint to resume from, i.e. 'yolo train resume model=path/to/last.pt'N)r9   resumerp   rE   r   r7  r   r#   r   r!  r
   rO   setattrr   FileNotFoundError)r\   r^   rB  r7  rG   Z	ckpt_argsr   r   r`   r`   ra   r:     s*    

zBaseTrainer.check_resumec              	   C   sF  |du s| j sdS d}|ddd }|dddurN| j|d  |d }| jr|dr| jj|d    |d	 | j_|d
ksJ | jj	 d| j
 d| jj	 dtd| jj	 d|d  d| j
 d | j
|k rt| j	 d|d  d| j
 d |  j
|d 7  _
|| _|| _|| j
| jj krB|   dS )z7Resume YOLO training from given epoch and best fitness.Nr   r   r3   r/   r   rV   rS   r  r   z training to zf epochs is finished, nothing to resume.
Start a new training without resuming, i.e. 'yolo train model=r   zResuming training z from epoch z to z total epochsz has been trained for z epochs. Fine-tuning for z more epochs.)rB  ri   r   Zload_state_dictrS   floatr  r  r9   rO   rK   r   r{   rV   rL   r   r   )r\   r   rV   rL   r`   r`   ra   r     s2    
(zBaseTrainer.resume_trainingc                 C   sJ   t | jjdrd| jj_t | jjdrFtd | jjjt| jd dS )z5Update dataloaders to stop using mosaic augmentation.mosaicFr   zClosing dataloader mosaic)ZhypN)	r   r   r   rF  r   r{   r   r   r9   r   r`   r`   ra   r     s
    

z$BaseTrainer._close_dataloader_mosaicautoMbP??h㈵>     j@c                 C   s  g g g f}t dd tj D }|dkrttd d| jj d| jj	 d t
|dd	}	td
d|	  d}
|dkr|dnd|
df\}}}d| j_| D ]v\}}|jddD ]`\}}|r| d| n|}d|v r|d | qt||r|d | q|d | qq|dv rBt
t|tj|d ||dfdd}nR|dkrbtj|d ||d}n2|dkrtj|d ||dd }ntd!| d"||d |d# ||d dd# ttd d$t|j d%| d&| d't|d  d(t|d  d)| d*t|d  d+ |S ),a  
        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
        weight decay, and number of iterations.

        Args:
            model (torch.nn.Module): The model for which to build an optimizer.
            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
                based on the number of iterations. Default: 'auto'.
            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
            iterations (float, optional): The number of iterations, which determines the optimizer if
                name is 'auto'. Default: 1e5.

        Returns:
            (torch.optim.Optimizer): The constructed optimizer.
        c                 s   s   | ]\}}d |v r|V  qdS )ZNormNr`   r
  r`   r`   ra   r     r   z.BaseTrainer.build_optimizer.<locals>.<genexpr>rG  z
optimizer:z' 'optimizer=auto' found, ignoring 'lr0=z' and 'momentum=zJ' and determining best 'optimizer', 'lr0' and 'momentum' automatically... r   
   {Gz?      i'  )SGDrM  rI  AdamWrI  r   F)recurser   Zbiasr8   r/   r   >   ZRAdamZAdamaxrQ  AdamZNAdamg+?)r   Zbetasr   ZRMSProp)r   r   rP  T)r   r   ZnesterovzOptimizer 'z' not found in list of available optimizers [Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.)paramsr   rn   z(lr=z, momentum=z) with parameter groups z weight(decay=0.0), z weight(decay=z), z bias(decay=0.0))rs   r   __dict__r  r   r{   r   r9   r   r   getattrr   r   Znamed_modulesr   rd   rp   r	   rS  ZRMSproprP  r/  Zadd_param_grouprM   __name__rq   )r\   rO   rB   r   r   r   r   gZbnr   Zlr_fitmodule_namemodule
param_nameparamfullnamer   r`   r`   ra   r     sZ    


$


"


zBaseTrainer.build_optimizer)r/   )NNT)rm   r   r   )r   N)Nr   )N)rG  rH  rI  rJ  rK  )*rW  
__module____qualname____doc__r   rb   rE   rg   rh   rj   r   r   r   r   r   r   r   r  r   rP   r   r   r   r   r&  r   r   r0  r   r   r3  r   r   r   r   r   rA  r   r:   r   r   r   r`   r`   r`   ra   r.   :   sL   "@']
 !

#




	
 r.   )Br`  r  r   rw   r}   r   r   r   r   r   r   pathlibr   r-  r   ru   r   r   r   r	   Zultralytics.cfgr
   r   Zultralytics.data.utilsr   r   Zultralytics.nn.tasksr   r   Zultralytics.utilsr   r   r   r   r   r   r   r   r   r   r   Zultralytics.utils.autobatchr   Zultralytics.utils.checksr   r   r   r   r    Zultralytics.utils.distr!   r"   Zultralytics.utils.filesr#   Zultralytics.utils.torch_utilsr$   r%   r&   r'   r(   r)   r*   r+   r,   r-   r.   r`   r`   r`   ra   <module>   s.   40