mnist/minst.py
2024-10-18 16:18:11 +00:00

35 lines
988 B
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#加载必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
# --- Hyperparameters ---------------------------------------------------------
EPOCHS = 50  # number of training epochs
BATCH_SIZE = 128  # number of samples processed per batch
# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build the preprocessing pipeline applied to every MNIST image.
pipeline = transforms.Compose([
    # Convert the image to a float tensor (torchvision scales 8-bit pixels
    # into [0, 1]).
    transforms.ToTensor(),
    # Standardize pixels with the values commonly quoted as the MNIST
    # training-set mean (0.1307) and std (0.3081). This is input
    # normalization, not regularization: it rescales the inputs and does not
    # change model complexity. Normalize needs BOTH mean and std as
    # per-channel sequences; MNIST is single-channel, hence the 1-tuples.
    transforms.Normalize((0.1307,), (0.3081,)),
])
# Download (if needed) and load the MNIST datasets into ./data.
from torch.utils.data import DataLoader

# download=True skips the download when the files are already present, so
# loading the test split no longer depends on the train split having been
# fetched first.
train_set = datasets.MNIST('data', train=True, download=True, transform=pipeline)
test_set = datasets.MNIST('data', train=False, download=True, transform=pipeline)

# Wrap the datasets in mini-batch loaders of BATCH_SIZE samples.
# Only the training data is shuffled. Evaluation order does not affect test
# metrics, and a fixed order keeps test runs reproducible.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

# Build the network model (TODO: not implemented yet)
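# TODO: define the optimizer
# TODO: define the training routine
# TODO: define the evaluation (test) routine
# TODO: run training and evaluation, and report the results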