35 lines
988 B
Python
35 lines
988 B
Python
#加载必要的库
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import torch.optim as optim
|
||
from torchvision import datasets,transforms
|
||
#定义超参数
|
||
BATCH_SIZE=128#每批处理的数据
|
||
DEVICE=torch.device('cuda' if torch.cuda.is_available() else 'cpu')#用cpu还是gpu
|
||
EPOCHS=50#训练次数
|
||
#构建pipeline,对图像做处理
|
||
pipeline=transforms.Compose([
|
||
transforms.ToTensor(),#将图片转化成tensor
|
||
transforms.Normalize(0.1307,),(0,3081,)#正则化,降低模型复杂度
|
||
])
|
||
|
||
#下载加载数据
|
||
from torch.utils.data import DataLoader
|
||
|
||
#下载数据集
|
||
train_set=datasets.MNIST('data',train=True,download=True,transform=pipeline)
|
||
|
||
test_set=datasets.MNIST('data',train=False,download=False,transform=pipeline)
|
||
#构建网络模型
|
||
train_loader=DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True)
|
||
|
||
test_loader=DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)
|
||
|
||
#定义优化器
|
||
|
||
#定义训练方法
|
||
|
||
#定义测试方法
|
||
|
||
#调用,输出 |