上传文件至 /
This commit is contained in:
commit
2c331cf818
35
minst.py
Normal file
35
minst.py
Normal file
@ -0,0 +1,35 @@
|
||||
#加载必要的库
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torchvision import datasets,transforms
|
||||
#定义超参数
|
||||
BATCH_SIZE=128#每批处理的数据
|
||||
DEVICE=torch.device('cuda' if torch.cuda.is_available() else 'cpu')#用cpu还是gpu
|
||||
EPOCHS=50#训练次数
|
||||
#构建pipeline,对图像做处理
|
||||
pipeline=transforms.Compose([
|
||||
transforms.ToTensor(),#将图片转化成tensor
|
||||
transforms.Normalize(0.1307,),(0,3081,)#正则化,降低模型复杂度
|
||||
])
|
||||
|
||||
#下载加载数据
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
#下载数据集
|
||||
train_set=datasets.MNIST('data',train=True,download=True,transform=pipeline)
|
||||
|
||||
test_set=datasets.MNIST('data',train=False,download=False,transform=pipeline)
|
||||
#构建网络模型
|
||||
train_loader=DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True)
|
||||
|
||||
test_loader=DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)
|
||||
|
||||
#定义优化器
|
||||
|
||||
#定义训练方法
|
||||
|
||||
#定义测试方法
|
||||
|
||||
#调用,输出
|
Loading…
x
Reference in New Issue
Block a user