a
    PSicS                     @   s`  d dl Z d dlZd dlZd dlZd dlZd dlZd dlZd dlmZm	Z	m
Z
mZmZmZ d dlmZ d dlZd dlZd dlmZ ddlmZmZmZmZ ddlmZ G dd	 d	eZG d
d deZG dd deZG dd deZG dd deZ e!e"dddZ#ej$ej%ej&ej'ej(ej)dZ*de+e,ej-dddZ.e+ej-dddZ/e+ej-dddZ0dS )     N)AnyCallableDictListOptionalTuple)URLError)Image   )download_and_extract_archiveextract_archiveverify_str_argcheck_integrity)VisionDatasetc                       s0  e Zd ZdZddgZg dZdZdZg dZe	dd	 Z
e	d
d Ze	dd Ze	dd Zd.eeee ee edd fddZdd Zdd Zdd Zeeeef dddZedd d!Ze	edd"d#Ze	edd$d%Ze	eeef dd&d'Zedd(d)Zddd*d+Z edd,d-Z!  Z"S )/MNISTa]  `MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
            and  ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    z!http://yann.lecun.com/exdb/mnist/z.https://ossci-datasets.s3.amazonaws.com/mnist/))train-images-idx3-ubyte.gzZ f68b3c2dcbeaaa9fbdd348bbdeb94873)train-labels-idx1-ubyte.gzZ d53e105ee54ea40749a09fcbcd1e9432)t10k-images-idx3-ubyte.gzZ 9fb629c4189551a2d022fa330f9573f3)t10k-labels-idx1-ubyte.gzZ ec29112dd5afa0611ce80d1b7f02629cztraining.ptztest.pt
z0 - zeroz1 - onez2 - twoz	3 - threez4 - fourz5 - fivez6 - sixz	7 - sevenz	8 - eightz9 - ninec                 C   s   t d | jS )Nz%train_labels has been renamed targetswarningswarntargetsself r   V/var/www/html/django/DPS/env/lib/python3.9/site-packages/torchvision/datasets/mnist.pytrain_labels?   s    
zMNIST.train_labelsc                 C   s   t d | jS )Nz$test_labels has been renamed targetsr   r   r   r   r   test_labelsD   s    
zMNIST.test_labelsc                 C   s   t d | jS )Nz train_data has been renamed datar   r   datar   r   r   r   
train_dataI   s    
zMNIST.train_datac                 C   s   t d | jS )Nztest_data has been renamed datar    r   r   r   r   	test_dataN   s    
zMNIST.test_dataTNF)roottrain	transformtarget_transformdownloadreturnc                    sd   t  j|||d || _|  r4|  \| _| _d S |r@|   |  sPt	d| 
 \| _| _d S )N)r&   r'   z;Dataset not found. You can use download=True to download it)super__init__r%   _check_legacy_exist_load_legacy_datar!   r   r(   _check_existsRuntimeError
_load_data)r   r$   r%   r&   r'   r(   	__class__r   r   r+   S   s    zMNIST.__init__c                    s4   t j j}|sdS t fdd j jfD S )NFc                 3   s"   | ]}t tj j|V  qd S N)r   ospathjoinprocessed_folder.0filer   r   r   	<genexpr>o   s   z,MNIST._check_legacy_exist.<locals>.<genexpr>)r4   r5   existsr7   alltraining_file	test_file)r   Zprocessed_folder_existsr   r   r   r,   j   s    
zMNIST._check_legacy_existc                 C   s(   | j r| jn| j}ttj| j|S r3   )	r%   r>   r?   torchloadr4   r5   r6   r7   )r   	data_filer   r   r   r-   s   s    zMNIST._load_legacy_datac                 C   sX   | j r
dnd d}ttj| j|}| j r2dnd d}ttj| j|}||fS )Nr%   Zt10k-images-idx3-ubyte-labels-idx1-ubyte)r%   read_image_filer4   r5   r6   
raw_folderread_label_file)r   
image_filer!   Z
label_filer   r   r   r   r0   y   s
    zMNIST._load_dataindexr)   c                 C   s\   | j | t| j|  }}tj| dd}| jdur@| |}| jdurT| |}||fS )z
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        LmodeN)r!   intr   r	   	fromarraynumpyr&   r'   r   rJ   imgtargetr   r   r   __getitem__   s    



zMNIST.__getitem__r)   c                 C   s
   t | jS r3   )lenr!   r   r   r   r   __len__   s    zMNIST.__len__c                 C   s   t j| j| jjdS )Nrawr4   r5   r6   r$   r2   __name__r   r   r   r   rF      s    zMNIST.raw_folderc                 C   s   t j| j| jjdS )N	processedrY   r   r   r   r   r7      s    zMNIST.processed_folderc                 C   s   dd t | jD S )Nc                 S   s   i | ]\}}||qS r   r   )r9   i_classr   r   r   
<dictcomp>       z&MNIST.class_to_idx.<locals>.<dictcomp>)	enumerateclassesr   r   r   r   class_to_idx   s    zMNIST.class_to_idxc                    s   t  fdd jD S )Nc              
   3   s:   | ]2\}}t tj jtjtj|d  V  qdS )r   N)r   r4   r5   r6   rF   splitextbasename)r9   url_r   r   r   r;      s   z&MNIST._check_exists.<locals>.<genexpr>)r=   	resourcesr   r   r   r   r.      s    zMNIST._check_existsc                 C   s   |   rdS tj| jdd | jD ]\}}| jD ]}| | }zvz$td|  t|| j||d W nF ty } z.td|  W Y d}~W t  q0W Y d}~n
d}~0 0 W t  nt  0  q"q0t	d| q"dS )z4Download the MNIST data if it doesn't exist already.NTexist_okzDownloading )download_rootfilenamemd5z"Failed to download (trying next):
zError downloading )
r.   r4   makedirsrF   rg   mirrorsprintr   r   r/   )r   rk   rl   mirrorre   errorr   r   r   r(      s"    
zMNIST.downloadc                 C   s   | j du rdnd}d| S )NTTrainTestSplit: )r%   )r   splitr   r   r   
extra_repr   s    zMNIST.extra_repr)TNNF)#rZ   
__module____qualname____doc__rn   rg   r>   r?   ra   propertyr   r   r"   r#   strboolr   r   r+   r,   r-   r0   rN   r   r   rT   rW   rF   r7   r   rb   r.   r(   rv   __classcell__r   r   r1   r   r      sT   



    		r   c                   @   s&   e Zd ZdZdgZg dZg dZdS )FashionMNISTa  `Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte``
            and  ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    z;http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/))r   Z 8d4fb7e6c68d591d4c3dfef9ec88bf0d)r   Z 25c81989df183df01b3e8a0aad5dffbe)r   Z bef4ecab320f06d8554ea6380940ec79)r   Z bb300cfdad3c16e7a12a480ee83cd310)
zT-shirt/topZTrouserZPulloverZDressZCoatZSandalZShirtZSneakerZBagz
Ankle bootNrZ   rw   rx   ry   rn   rg   ra   r   r   r   r   r~      s   r~   c                   @   s&   e Zd ZdZdgZg dZg dZdS )KMNISTak  `Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte``
            and  ``KMNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    z-http://codh.rois.ac.jp/kmnist/dataset/kmnist/))r   Z bdb82020997e1d708af4cf47b453dcf7)r   Z e144d726b3acfaa3e44228e80efcd344)r   Z 5c965bf0a639b31b8f53240b1b52f4d7)r   Z 7320c461ea6c1c855c0b718fb2a4b134)
okiZsuZtsunahamaZyarewoNr   r   r   r   r   r      s   r   c                       s  e Zd ZdZdZdZdZh dZee	j
e	j Zeeeeeee eeee dgee	j ee	j
ee	j
dZeeedd fd	d
ZeedddZeedddZeedddZeedddZeedddZdd ZedddZddddZ  ZS )EMNISTa8  `EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte``
            and  ``EMNIST/raw/t10k-images-idx3-ubyte`` exist.
        split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
            ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
            which one to use.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    z:https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zipZ 58c8d27c78d21e728a6bc7b3cc06412e)ZbyclassZbymergeZbalancedlettersdigitsmnist>   mr   wzxr\   plcsvyjkuzN/AN)r$   ru   kwargsr)   c                    sN   t |d| j| _| || _| || _t j|fi | | j	| j | _
d S )Nru   )r   splitsru   _training_filer>   
_test_filer?   r*   r+   classes_split_dictra   )r   r$   ru   r   r1   r   r   r+   %  s
    zEMNIST.__init__rU   c                 C   s   d|  dS )NZ	training_.ptr   ru   r   r   r   r   ,  s    zEMNIST._training_filec                 C   s   d|  dS )Ntest_r   r   r   r   r   r   r   0  s    zEMNIST._test_filec                 C   s   d| j  d| jrdnd S )Nzemnist--r%   test)ru   r%   r   r   r   r   _file_prefix4  s    zEMNIST._file_prefixc                 C   s   t j| j| j dS )NrC   r4   r5   r6   rF   r   r   r   r   r   images_file8  s    zEMNIST.images_filec                 C   s   t j| j| j dS )NrD   r   r   r   r   r   labels_file<  s    zEMNIST.labels_filec                 C   s   t | jt| jfS r3   )rE   r   rG   r   r   r   r   r   r0   @  s    zEMNIST._load_datac                 C   s   t dd | j| jfD S )Nc                 s   s   | ]}t |V  qd S r3   r   r8   r   r   r   r;   D  r_   z'EMNIST._check_exists.<locals>.<genexpr>r=   r   r   r   r   r   r   r.   C  s    zEMNIST._check_existsc                 C   s~   |   rdS tj| jdd t| j| j| jd tj| jd}t	|D ]$}|
drJttj||| j qJt| dS )z5Download the EMNIST data if it doesn't exist already.NTrh   )rj   rl   gzipz.gz)r.   r4   rm   rF   r   re   rl   r5   r6   listdirendswithr   shutilrmtree)r   Zgzip_folderZ	gzip_filer   r   r   r(   F  s    
zEMNIST.download) rZ   rw   rx   ry   re   rl   r   Z_merged_classessetstringr   ascii_lettersZ_all_classessortedlistascii_lowercaser   r{   r   r+   staticmethodr   r   rz   r   r   r   r0   r|   r.   r(   r}   r   r   r1   r   r     s6   
	r   c                       s   e Zd ZU dZddddddZddgdd	gd
dgdZeeee	eef  f e
d< g dZd$eee eeedd fddZeedddZeedddZedddZdd ZddddZee	eef dd d!Zedd"d#Z  ZS )%QMNISTaP  `QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.

    Args:
        root (string): Root directory of dataset whose ``raw``
            subdir contains binary files of the datasets.
        what (string,optional): Can be 'train', 'test', 'test10k',
            'test50k', or 'nist' for respectively the mnist compatible
            training set, the 60k qmnist testing set, the 10k qmnist
            examples that match the mnist testing set, the 50k
            remaining qmnist testing examples, or all the nist
            digits. The default is to select 'train' or 'test'
            according to the compatibility argument 'train'.
        compat (bool,optional): A boolean that says whether the target
            for each example is class number (for compatibility with
            the MNIST dataloader) or a torch vector containing the
            full qmnist information. Default=True.
        download (bool, optional): If True, downloads the dataset from
            the internet and puts it in root directory. If dataset is
            already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that
            takes in an PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform
            that takes in the target and transforms it.
        train (bool,optional,compatibility): When argument 'what' is
            not specified, this boolean decides whether to load the
            training set ot the testing set.  Default: True.
    r%   r   nist)r%   r   test10ktest50kr   )zbhttps://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gzZ ed72d4157d28c017586c42bc6afe6370)z`https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gzZ 0058f8dd561b90ffdd0f734c6a30e5e4)zahttps://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gzZ 1394631089c404de565df7b7aeaf9412)z_https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gzZ 5b5b05890a5e13444e108efe57b788aa)z[https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xzZ 7f124b3b8ab81486c9d8c2749c17f834)zYhttps://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xzZ 5ed0e788978e45d4a8bd4b7caec3d79d)r%   r   r   rg   r   NT)r$   whatcompatr%   r   r)   c                    sf   |d u r|rdnd}t |dt| j | _|| _|d | _| j| _| j| _t	 j
||fi | d S )Nr%   r   r   r   )r   tuplesubsetskeysr   r   rB   r>   r?   r*   r+   )r   r$   r   r   r%   r   r1   r   r   r+     s    
zQMNIST.__init__rU   c                 C   s>   | j | j| j  \\}}}tj| jtjtj|d S Nr   	rg   r   r   r4   r5   r6   rF   rc   rd   )r   re   rf   r   r   r   r     s    zQMNIST.images_filec                 C   s>   | j | j| j  \}\}}tj| jtjtj|d S r   r   )r   rf   re   r   r   r   r     s    zQMNIST.labels_filec                 C   s   t dd | j| jfD S )Nc                 s   s   | ]}t |V  qd S r3   r   r8   r   r   r   r;     r_   z'QMNIST._check_exists.<locals>.<genexpr>r   r   r   r   r   r.     s    zQMNIST._check_existsc                 C   s   t | j}|jtjkr&td|j | dkr:tdt | j	 }| dkrftd|  | j
dkr|ddd d d d f  }|ddd d f  }n@| j
d	kr|dd d d d d f  }|dd d d f  }||fS )
Nz/data should be of dtype torch.uint8 instead of    z<data should have 3 dimensions instead of {data.ndimension()}   z,targets should have 2 dimensions instead of r   r   i'  r   )read_sn3_pascalvincent_tensorr   dtyper@   uint8	TypeError
ndimension
ValueErrorr   longr   clone)r   r!   r   r   r   r   r0     s    


zQMNIST._load_datac                 C   sP   |   rdS tj| jdd | j| j| j  }|D ]\}}t|| j|d q2dS )zDownload the QMNIST data if it doesn't exist already.
        Note that we only download what has been asked for (argument 'what').
        NTrh   )rl   )r.   r4   rm   rF   rg   r   r   r   )r   ru   re   rl   r   r   r   r(     s    zQMNIST.downloadrI   c                 C   sj   | j | | j|  }}tj| dd}| jd ur<| |}| jrNt|d }| jd urb| |}||fS )NrK   rL   r   )	r!   r   r	   rO   rP   r&   r   rN   r'   rQ   r   r   r   rT     s    



zQMNIST.__getitem__c                 C   s   d| j  S )Nrt   )r   r   r   r   r   rv     s    zQMNIST.extra_repr)NTT)rZ   rw   rx   ry   r   rg   r   r{   r   r   __annotations__ra   r   r|   r   r+   rz   r   r   r.   r0   r(   rN   rT   rv   r}   r   r   r1   r   r   V  s6   
"  r   )br)   c                 C   s   t t| ddS )Nhex   )rN   codecsencode)r   r   r   r   get_int  s    r   )   	               T)r5   strictr)   c                    s  t | d}|  W d   n1 s(0    Y  t dd }|d }|d }d|  krfdksln J d|  krd	ksn J t| } fd
dt|D }t|jd }tj	dko|dk}	tj
t |d|d  d}
|	r|
d}
|
jd t|ks|rJ |
j| S )zRead a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
    Argument may be a filename, compressed filename, or file object.
    rbNr         r
   r   r   r   c                    s,   g | ]$}t  d |d  d |d   qS )r   r
   r   )r   )r9   r\   r!   r   r   
<listcomp>  r_   z1read_sn3_pascalvincent_tensor.<locals>.<listcomp>little)r   offset)openreadr   SN3_PASCALVINCENT_TYPEMAPranger@   iinfobitssys	byteorder
frombuffer	bytearrayflipshapenpprodview)r5   r   fmagicndtyZ
torch_typer   Znum_bytes_per_valueZneeds_byte_reversalparsedr   r   r   r     s     &
 r   )r5   r)   c                 C   sN   t | dd}|jtjkr(td|j | dkrFtd|  | S )NFr   ,x should be of dtype torch.uint8 instead of r
   z%x should have 1 dimension instead of )r   r   r@   r   r   r   r   r   r5   r   r   r   r   rG     s    rG   c                 C   sJ   t | dd}|jtjkr(td|j | dkrFtd|  |S )NFr   r   r   z%x should have 3 dimension instead of )r   r   r@   r   r   r   r   r   r   r   r   rE     s    rE   )T)1r   r4   os.pathr   r   r   r   typingr   r   r   r   r   r   urllib.errorr   rP   r   r@   PILr	   utilsr   r   r   r   visionr   r   r~   r   r   r   bytesrN   r   r   int8int16int32float32float64r   r{   r|   Tensorr   rG   rE   r   r   r   r   <module>   s>     8T 
	