Python PyTorch 获取 MNIST 数据
1. PyTorch 获取 MNIST 数据
2. PyTorch 保存 MNIST 数据
3. PyTorch 显示 MNIST 数据
1 PyTorch 获取 MNIST 数据
import torch
import numpy as np
import matplotlib. pyplot as plt
from torchvision import datasets, transforms
def mnist_get(root='./data'):
    """Download MNIST (if not already cached) and return both splits as NumPy arrays.

    Prints the installed PyTorch version first as a quick environment check.

    Args:
        root: Directory torchvision downloads/caches the MNIST files into.
            Defaults to './data', the path the original code hard-coded.

    Returns:
        A tuple ``(train_image, train_label, test_image, test_label)`` of
        NumPy arrays built from each dataset's ``.data`` / ``.targets``.
    """
    print(torch.__version__)
    # NOTE(review): torchvision applies `transform` inside Dataset.__getitem__,
    # so it is presumably NOT reflected in the `.data` arrays read below (they
    # stay raw, un-normalized pixels) -- confirm against installed torchvision.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # Bug fix: the training split requires train=True. The original passed
    # train=False here too, so "train" silently duplicated the test split.
    train_data = datasets.MNIST(root=root, train=True, download=True,
                                transform=transform)
    test_data = datasets.MNIST(root=root, train=False, download=True,
                               transform=transform)
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    # The original discarded these arrays; returning them makes the function
    # useful (callers that ignored the old implicit None are unaffected).
    return train_image, train_label, test_image, test_label
2 PyTorch 保存 MNIST 数据
import torch
import numpy as np
import matplotlib. pyplot as plt
from torchvision import datasets, transforms
def mnist_save(mnist_path, root='./data'):
    """Download MNIST and write both splits into a single ``.npz`` archive.

    Prints the installed PyTorch version first as a quick environment check.
    The archive stores four arrays under the keys ``train_data``,
    ``train_label``, ``test_data`` and ``test_label``.

    Args:
        mnist_path: Destination passed to ``np.savez``; NumPy appends ``.npz``
            when a string/Path name lacks that extension.
        root: Directory torchvision downloads/caches the MNIST files into.
            Defaults to './data', the path the original code hard-coded.
    """
    print(torch.__version__)
    # NOTE(review): `transform` is applied by torchvision in __getitem__ and is
    # presumably not reflected in the raw `.data` arrays saved below -- confirm.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # Bug fix: the training split requires train=True. The original passed
    # train=False for both, so the archive's train_* arrays were test data.
    train_data = datasets.MNIST(root=root, train=True, download=True,
                                transform=transform)
    test_data = datasets.MNIST(root=root, train=False, download=True,
                               transform=transform)
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    np.savez(mnist_path, train_data=train_image, train_label=train_label,
             test_data=test_image, test_label=test_label)
# Export target for the archive: a user-specific absolute Windows path --
# adjust it for the machine this runs on.
mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_save(mnist_path=mnist_path)
3 PyTorch 显示 MNIST 数据
import torch
import numpy as np
import matplotlib. pyplot as plt
from torchvision import datasets, transforms
def mnist_show(mnist_path):
    """Plot the first (up to) 100 training images of an ``.npz`` archive on a 10x10 grid.

    Reads the ``train_data`` and ``train_label`` arrays from the archive,
    prints ``index, label`` for every image shown, then opens the figure.

    Args:
        mnist_path: Path to an ``.npz`` archive containing ``train_data`` and
            ``train_label`` arrays.
    """
    # The context manager closes the NpzFile handle; indexing it loads each
    # array fully into memory, so the slices remain valid after the block.
    with np.load(mnist_path) as data:
        images = data['train_data'][:100]
        labels = data['train_label'].reshape(-1)[:100]
    # Guard against archives holding fewer than 100 samples (the original
    # raised IndexError in that case).
    count = min(len(images), len(labels))
    plt.figure(figsize=(10, 10))
    for i in range(count):
        # Bug fix: index and label are integers; '%f' printed '0.000000, 7.000000'.
        print('%d, %d' % (i, labels[i]))
        plt.subplot(10, 10, i + 1)
        # Digits are single-channel; without cmap='gray' matplotlib renders
        # them with its default false-colour colormap.
        plt.imshow(images[i], cmap='gray')
    plt.show()
# Same user-specific Windows path the save script writes to -- adjust it for
# the machine this runs on.
mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_show(mnist_path=mnist_path)