1 PyTorch 获取 MNIST 数据
import torch
import numpy as np
import matplotlib.pyplot as plt  # type: ignore
from torchvision import datasets, transforms


def mnist_get(root='./data'):
    """Download MNIST via torchvision and return both splits as NumPy arrays.

    Args:
        root: Directory torchvision downloads to / reads from.
            Defaults to ``'./data'``, the location previously hard-coded.

    Returns:
        A tuple ``(train_image, train_label, test_image, test_label)`` of
        NumPy arrays converted from each dataset's ``.data`` and
        ``.targets`` tensors.
    """
    print(torch.__version__)

    # Define the per-sample transform: convert to a tensor, then normalize.
    # torchvision applies this only when a sample is fetched by indexing the
    # dataset; the raw ``.data`` tensors read below are not transformed.
    transform = transforms.Compose([
        transforms.ToTensor(),                 # convert the image to a tensor
        transforms.Normalize((0.5,), (0.5,)),  # normalize the image data
    ])

    # Fetch the data. ``train=True`` selects the training split; using
    # ``train=False`` for both would load the test split twice.
    train_data = datasets.MNIST(root=root, train=True, download=True,
                                transform=transform)
    test_data = datasets.MNIST(root=root, train=False, download=True,
                               transform=transform)

    # Training data
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # Test data
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()

    return train_image, train_label, test_image, test_label
2 PyTorch 保存 MNIST 数据
import torch
import numpy as np
import matplotlib.pyplot as plt  # type: ignore
from torchvision import datasets, transforms


def mnist_save(mnist_path, root='./data'):
    """Download MNIST and save the train/test splits to an ``.npz`` archive.

    Args:
        mnist_path: Destination file passed to ``np.savez``.
        root: Directory torchvision downloads to / reads from
            (default ``'./data'``).

    The archive is written with the keys ``train_data``, ``train_label``,
    ``test_data`` and ``test_label``.
    """
    print(torch.__version__)

    # Define the per-sample transform: convert to a tensor, then normalize.
    # It is only applied when samples are fetched by indexing the dataset;
    # the raw ``.data`` tensors saved below are not transformed.
    transform = transforms.Compose([
        transforms.ToTensor(),                 # convert the image to a tensor
        transforms.Normalize((0.5,), (0.5,)),  # normalize the image data
    ])

    # Fetch the data. ``train=True`` selects the training split, so the
    # "train_*" keys in the archive really hold training samples.
    train_data = datasets.MNIST(root=root, train=True, download=True,
                                transform=transform)
    test_data = datasets.MNIST(root=root, train=False, download=True,
                               transform=transform)

    # Training data
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # Test data
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()

    np.savez(mnist_path,
             train_data=train_image, train_label=train_label,
             test_data=test_image, test_label=test_label)


if __name__ == '__main__':
    mnist_path = 'c:\\users\\hyacinth\\desktop\\mnist.npz'
    mnist_save(mnist_path)
3 PyTorch 显示 MNIST 数据
import torch
import numpy as np
import matplotlib.pyplot as plt  # type: ignore
from torchvision import datasets, transforms


def mnist_show(mnist_path):
    """Plot up to the first 100 training images of an ``.npz`` archive.

    Args:
        mnist_path: Path to an ``.npz`` archive containing the keys
            ``train_data`` and ``train_label`` (e.g. one written by
            ``mnist_save``).

    Prints ``"<index>, <label>"`` for each image drawn, lays the images out
    in a 10x10 grid, and calls ``plt.show()``.
    """
    # The context manager closes the NpzFile handle. Indexing it loads each
    # array into memory, so the arrays remain usable after the block exits.
    with np.load(mnist_path) as data:
        image = data['train_data'][0:100]
        label = data['train_label'].reshape(-1, )

    plt.figure(figsize=(10, 10))
    # zip stops at the shorter input, so archives with fewer than 100
    # images no longer raise IndexError.
    for i, (img, lab) in enumerate(zip(image, label)):
        # Index and label are integers; '%f' printed them as "0.000000".
        print('%d, %d' % (i, lab))
        plt.subplot(10, 10, i + 1)
        # MNIST digits are single-channel; render them in grayscale rather
        # than matplotlib's default colormap.
        plt.imshow(img, cmap='gray')
        plt.axis('off')
    plt.show()


if __name__ == '__main__':
    mnist_path = 'c:\\users\\hyacinth\\desktop\\mnist.npz'
    mnist_show(mnist_path)
到此这篇关于 Python PyTorch 获取 MNIST 数据的文章就介绍到这了，更多相关 Python PyTorch 获取 MNIST 数据的内容请搜索代码网以前的文章或继续浏览下面的相关文章，希望大家以后多多支持代码网！
发表评论