# while/main.py — 122 lines, 4.0 KiB, Python
# (header captured from a repository web view; the timestamp lines are blame metadata)
#
# 2024-10-25 21:16:34 +08:00
#导入模块
import torch
import torchvision
# 2024-10-29 20:05:56 +08:00
import random
from enum import Enum
import matplotlib as plt
import torch.nn.functional as F
# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Per-sample preprocessing: ToTensor turns the image into a float tensor in
# [0, 1], then Normalize computes (x - 0.5) / 0.5 on the single channel.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Directory torchvision uses to store / look up the MNIST files.
path = './MNIST'

# Training hyper-parameters.
EPOCH = 10  # number of passes over the training set
Batch_Size = 64  # samples per mini-batch
# Create the dataset
class DataSet:
    """Unfinished stand-in for a dataset read from raw data files.

    NOTE(review): nothing in this script instantiates this class; the data is
    loaded through ``torchvision.datasets.MNIST`` instead. It only records its
    constructor arguments and defines no ``__len__`` / ``__getitem__``.
    """

    def __init__(self, floder, data_name, label_name, transform=None):
        """Record where the data lives and how samples should be transformed.

        Args:
            floder: Directory holding the data files (spelling kept for
                backward compatibility; presumably means "folder").
            data_name: Name of the file holding the samples (presumably the
                image archive — confirm).
            label_name: Name of the file holding the labels (presumably the
                label archive — confirm).
            transform: Optional callable intended to be applied per sample.
        """
        self.floder = floder
        # These two arguments used to be accepted and then silently dropped.
        self.data_name = data_name
        self.label_name = label_name
        self.transform = transform
# Create the dataloader
class DataLoader:
    """Unfinished stub that resolves the MNIST archive names for a split.

    NOTE(review): despite the name, this is not a data loader — it performs no
    loading, batching or shuffling, and it shadows the familiar
    ``torch.utils.data.DataLoader`` name in this module. It only records the
    requested configuration and the archive file names for the split.
    """

    def __init__(self, type, batch_size, is_shuffle):  # noqa: A002 (public arg name kept)
        """Store the split configuration and its archive file names.

        Args:
            type: ``'train'`` selects the training archives; any other value
                selects the ``t10k`` (test) archives.
            batch_size: Requested batch size (stored only).
            is_shuffle: Whether shuffling was requested (stored only).
        """
        is_train = type == 'train'
        self.type = type
        self.batch_size = batch_size
        self.is_shuffle = is_shuffle
        # Previously computed as locals and discarded; keep them on the instance.
        self.data_name = 'train-images-idx3-ubyte.gz' if is_train else 't10k-images-idx3-ubyte.gz'
        self.label_name = 'train-labels-idx1-ubyte.gz' if is_train else 't10k-labels-idx1-ubyte.gz'
# Download / load the MNIST dataset via torchvision.
# Training split: download=True fetches the files into `path` if they are missing.
trainData = torchvision.datasets.MNIST(path, train=True, transform=transform, download=True)
# Test split: download=True is a no-op when the files already exist, and it keeps
# this line working even if the training split was never fetched.
testData = torchvision.datasets.MNIST(path, train=False, transform=transform, download=True)

# Batch the datasets with PyTorch's DataLoader, referenced by its full path: the
# bare name `DataLoader` in this module is a local stub class with a different
# constructor (no `shuffle` argument) that does no batching. The old calls also
# passed the not-yet-defined names `train` / `test` instead of the datasets.
# shuffle=True reshuffles the training set every epoch; the test set keeps its order.
train_Dataloader = torch.utils.data.DataLoader(trainData, batch_size=Batch_Size, shuffle=True)
test_DataLoader = torch.utils.data.DataLoader(testData, batch_size=Batch_Size, shuffle=False)
class Net(torch.nn.Module):
    """Small CNN digit classifier.

    Layout:
        conv(1 -> 10, 5x5) -> ReLU -> 2x2 max-pool
        conv(10 -> 20, 5x5) -> ReLU -> 2x2 max-pool
        flatten (20 * 4 * 4 = 320) -> Linear(320, 50) -> ReLU -> Linear(50, 10)

    The 320-wide flatten implies a (batch, 1, 28, 28) input. The output is
    raw scores for 10 classes (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Convolution, then activation, then pooling.
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 10, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(10, 20, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        # The ReLU between the two Linear layers matters: without it the pair
        # composes into a single linear map, so the 50-unit hidden layer added
        # parameters but no modelling capacity.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(320, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, 10),
        )

    def forward(self, x):
        """Return class scores of shape (batch, 10) for ``x`` of shape (batch, 1, 28, 28)."""
        batch_size = x.size(0)
        x = self.conv1(x)  # -> (batch, 10, 12, 12)
        x = self.conv2(x)  # -> (batch, 20, 4, 4)
        x = x.view(batch_size, -1)  # flatten for the fully connected head: (batch, 320)
        return self.fc(x)
# Build the network and move its parameters onto the selected device.
model = Net()
model = model.to(device)

# Loss function: cross-entropy on the network's raw class scores.
# ("sunshi" is pinyin for 损失, i.e. "loss".)
sunshi = torch.nn.CrossEntropyLoss()

# Optimizer: stochastic gradient descent with momentum.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.25,  # learning rate
    momentum=0.25,  # momentum factor
)
# Training
def train(epoch):
    """Run one training epoch over ``train_Dataloader`` and print a summary.

    Uses the module-level ``model``, ``optimizer``, loss ``sunshi`` and
    ``device``.

    Args:
        epoch: Zero-based epoch index; printed as ``epoch + 1``.

    Returns:
        ``(mean_loss, accuracy)`` for the epoch: the average per-batch loss and
        the fraction of correctly classified training samples. Both are 0.0 if
        the loader yields no batches.
    """
    model.train()  # ensure training-mode behaviour for every layer
    running_loss = 0.0
    running_total = 0
    running_correct = 0
    num_batches = 0
    for inputs, target in train_Dataloader:
        inputs = inputs.to(device)
        target = target.to(device)

        # Clear the gradients accumulated by the previous step.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = sunshi(outputs, target)
        # Backpropagate and update the weights.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        # Accuracy: the predicted class is the index of the highest score.
        predicted = outputs.argmax(dim=1)
        running_total += inputs.shape[0]
        running_correct += (predicted == target).sum().item()

    mean_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = running_correct / running_total if running_total else 0.0
    # The format string must be applied with `%`; passing the values as extra
    # print() arguments printed the raw template instead of the numbers.
    print('[%d,%5d]:loss:%.3f,acc:%.2f' % (epoch + 1, num_batches, mean_loss, accuracy))
    return mean_loss, accuracy
# Evaluation
def test(epoch=None):
    """Evaluate ``model`` on ``test_DataLoader`` and return the accuracy.

    Uses the module-level ``model``, ``test_DataLoader``, ``device`` and
    ``EPOCH``. The model is put in eval mode for the pass and its previous
    train/eval mode is restored afterwards.

    Args:
        epoch: Optional zero-based epoch index, used only in the printed
            message (as ``epoch + 1`` out of ``EPOCH``).

    Returns:
        Fraction of correctly classified test samples in [0, 1]; 0.0 if the
        loader yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in test_DataLoader:
                # Inputs must be on the same device as the model's parameters.
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                # torch.max(..., dim=1) returns a (values, indices) pair; the
                # predicted class is the index, which argmax gives directly.
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    accuracy = correct / total if total else 0.0
    # Accuracy is a fraction, so scale by 100 before printing it as a percentage.
    if epoch is None:
        print('Accuracy: %.1f %%' % (100 * accuracy))
    else:
        print('[%d/%d]Accuracy: %.1f %%' % (epoch + 1, EPOCH, 100 * accuracy))
    return accuracy
if __name__ == '__main__':
    # The file-level `import matplotlib as plt` binds the matplotlib package,
    # which has no plot/xlabel/ylabel/show; those live in matplotlib.pyplot.
    import matplotlib.pyplot as plt

    TEST_EVERY = 10  # evaluate on the test set every TEST_EVERY epochs
    acc_list_test = []
    tested_epochs = []
    for epoch in range(EPOCH):
        train(epoch)
        # Test every TEST_EVERY epochs, and always after the last epoch so a run
        # whose EPOCH is not a multiple of TEST_EVERY still gets evaluated.
        if epoch % TEST_EVERY == TEST_EVERY - 1 or epoch == EPOCH - 1:
            acc_test = test()
            acc_list_test.append(acc_test)
            tested_epochs.append(epoch + 1)

    # Plot against the actual (1-based) epoch numbers; markers keep the curve
    # visible even when only a single evaluation was recorded.
    plt.plot(tested_epochs, acc_list_test, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy On TestSet')
    plt.show()