上传文件至 /

This commit is contained in:
kayoko 2024-10-18 16:18:11 +00:00
commit 2c331cf818

35
minst.py Normal file
View File

@ -0,0 +1,35 @@
#加载必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
#定义超参数
BATCH_SIZE=128#每批处理的数据
DEVICE=torch.device('cuda' if torch.cuda.is_available() else 'cpu')#用cpu还是gpu
EPOCHS=50#训练次数
#构建pipeline对图像做处理
pipeline=transforms.Compose([
transforms.ToTensor(),#将图片转化成tensor
transforms.Normalize(0.1307,),(0,3081,)#正则化,降低模型复杂度
])
#下载加载数据
from torch.utils.data import DataLoader
#下载数据集
train_set=datasets.MNIST('data',train=True,download=True,transform=pipeline)
test_set=datasets.MNIST('data',train=False,download=False,transform=pipeline)
#构建网络模型
train_loader=DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True)
test_loader=DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)
#定义优化器
#定义训练方法
#定义测试方法
#调用,输出