在MATPool矩池云完成Pytorch训练MNIST数据集


本文为矩池云入门手册的补充:Pytorch训练MNIST数据集代码运行过程。

案例代码和对应数据集,以及在矩池云上的详细操作可以在矩池云入门手册中查看,本文基于矩池云入门手册,默认用户已经完成了机器租用,上传解压好了数据、代码,并使用jupyter lab进行代码运行。

在MATPool矩池云完成Pytorch训练MNIST数据集

1. 安装自己需要的第三方包

以tqdm包为例子,如果在运行代码过程出现了ModuleNotFoundError: No module named 'tqdm',说明我们选择的系统镜像中没有预装这个包,我们只需要再JupyterLab的Terminal输入pip install tqdm即可安装相关包。

在MATPool矩池云完成Pytorch训练MNIST数据集

其他自己需要的第三方包安装方法也类似。

2. 在JupyterLab中运行代码

JupyterLab目录里面,我们依次点击mnt->MyMNIST进入到项目文件夹,在项目文件夹下双击pytorch_mnist.ipynb文件,即可打开代码文件。

在MATPool矩池云完成Pytorch训练MNIST数据集

打开代码文件后,我们就可以直接运行了,截图中给大家说明了几个常用的JupyteLab 按钮功能。

在MATPool矩池云完成Pytorch训练MNIST数据集

接下来我们开始运行代码~

2.1 导入需要的Python包

首先运行下面代码导入需要的模块,如:

  • pytorch相关:torch、torchvision
  • 训练输出进度条可视化显示:tqdm
  • 训练结果图表可视化显示:matplotlib.pyplot
# 导入相关包
# 测试环境 K80 pytorch1.10
import torch
import torchvision 
from tqdm import tqdm
import matplotlib.pyplot as plt

测试下机器中的pytorch版本和GPU是否可用。

# 查看pytorch版本和gpu是否可用
print(torch.__version__)
print(torch.cuda.is_available())

'''
输出:
1.10.0+cu113
True 
'''

上面输出表示pytorch版本为1.10.0,机器GPU可用。

2.2 数据预处理
设置device、BATCH_SIZE和EPOCHS

# 如果网络能在GPU中训练,就使用GPU;否则使用CPU进行训练
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# 这个函数包括了两个操作:将图片转换为张量,以及将图片进行归一化处理
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                torchvision.transforms.Normalize(mean = [0.5],std = [0.5])])
                                
# 设置了每个包中的图片数据个数
BATCH_SIZE = 64
EPOCHS = 10

加载构建训练和测试数据集

# 从项目文件中加载训练数据和测试数据
train_dataset = torchvision.datasets.MNIST('/mnt/MyMNIST/',train = True,transform = transform)
test_dataset = torchvision.datasets.MNIST('/mnt/MyMNIST/',train = False,transform = transform)

# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=True)

2.3 构建数据训练模型并创建实例
构建数据训练模型

# 一个简单的卷积神经网络
class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.model = torch.nn.Sequential(
            #The size of the picture is 28x28
            torch.nn.Conv2d(in_channels = 1,out_channels = 16,kernel_size = 3,stride = 1,padding = 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size = 2,stride = 2),
            
            #The size of the picture is 14x14
            torch.nn.Conv2d(in_channels = 16,out_channels = 32,kernel_size = 3,stride = 1,padding = 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size = 2,stride = 2),
            
            #The size of the picture is 7x7
            torch.nn.Conv2d(in_channels = 32,out_channels = 64,kernel_size = 3,stride = 1,padding = 1),
            torch.nn.ReLU(),
            
            torch.nn.Flatten(),
            torch.nn.Linear(in_features = 7 * 7 * 64,out_features = 128),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features = 128,out_features = 10),
            torch.nn.Softmax(dim=1)
        )
        
    def forward(self,input):
        output = self.model(input)
        return output

构建模型实例

# 构建模型实例
net = Net()
# 将模型转换到device中,并将其结构显示出来
print(net.to(device))

在MATPool矩池云完成Pytorch训练MNIST数据集

2.4 构建迭代器与损失函数

# 交叉熵损失来作为损失函数
# Adam迭代器
loss_fun = torch.nn.CrossEntropyLoss() 
optimizer = torch.optim.Adam(net.parameters())

2.5 构建并运行训练循环

history = {'Test Loss':[],'Test Accuracy':[]}
for epoch in range(1,EPOCHS + 1):
    process_bar = tqdm(train_loader,unit = 'step')
    net.train(True)
    for step,(train_imgs,labels) in enumerate(process_bar):
        train_imgs = train_imgs.to(device)
        labels = labels.to(device)

        net.zero_grad()
        outputs = net(train_imgs)
        loss = loss_fun(outputs,labels)
        predictions = torch.argmax(outputs, dim = 1)
        accuracy = torch.true_divide(torch.sum(predictions == labels), labels.shape[0])
        loss.backward()

        optimizer.step()
        process_bar.set_description("[%d/%d] Loss: %.4f, Acc: %.4f" % 
                                   (epoch,EPOCHS,loss.item(),accuracy.item()))

        if step == len(process_bar)-1:
            correct,total_loss = 0,0
            net.train(False)
            with torch.no_grad():
                for test_imgs,labels in test_loader:
                    test_imgs = test_imgs.to(device)
                    labels = labels.to(device)
                    outputs = net(test_imgs)
                    loss = loss_fun(outputs,labels)
                    predictions = torch.argmax(outputs,dim = 1)

                    total_loss += loss
                    correct += torch.sum(predictions == labels)

                test_accuracy = torch.true_divide(correct, (BATCH_SIZE * len(test_loader)))
                test_loss = torch.true_divide(total_loss, len(test_loader))
                history['Test Loss'].append(test_loss.item())
                history['Test Accuracy'].append(test_accuracy.item())

            process_bar.set_description("[%d/%d] Loss: %.4f, Acc: %.4f, Test Loss: %.4f, Test Acc: %.4f" % 
                                   (epoch,EPOCHS,loss.item(),accuracy.item(),test_loss.item(),test_accuracy.item()))
    process_bar.close()

在MATPool矩池云完成Pytorch训练MNIST数据集

2.6 训练结果可视化

#对测试Loss进行可视化
plt.plot(history['Test Loss'],label = 'Test Loss')
plt.legend(loc='best')
plt.grid(True)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

#对测试准确率进行可视化
plt.plot(history['Test Accuracy'],color = 'red',label = 'Test Accuracy')
plt.legend(loc='best')
plt.grid(True)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.show()

在MATPool矩池云完成Pytorch训练MNIST数据集

2.7 保存模型

# 保存训练好的模型
torch.save(net,'/mnt/MyMNIST/torch_mnist_model.pth')

保存成功后,JupyterLab 中对应文件夹会出现该文件,在矩池云网盘对应目录下也会存在。

在MATPool矩池云完成Pytorch训练MNIST数据集
在MATPool矩池云完成Pytorch训练MNIST数据集

参考文章

原创文章,作者:kepupublish,如若转载,请注明出处:https://blog.ytso.com/tech/pnotes/245590.html

(0)
上一篇 2022年4月18日 14:48
下一篇 2022年4月18日 14:52

相关推荐

发表回复

登录后才能评论