while/main.py
kemna 8dd08b8fbf
Some checks failed
/ job1 (push) Failing after 5m47s
bug一堆
2024-10-29 20:18:54 +08:00

169 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#导入模块
import torch
import os
from torchvision import transforms
import random
import gzip
import numpy as np
from enum import Enum
from matplotlib import pyplot as plt
import torch.nn.functional as F
from PIL import Image
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#使用torchvisiom.transform将图片转换为张量
transform =transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.5])])
#构建数据集使用MINIST
path ='./MNIST/raw'
EPOCH =10
Batch_Size = 64
#创建dataset
class Dataset:
def __init__(self, data_path, train=True):
self.data_path = data_path
self.train = train
self.images = []
self.labels = []
self.load_data()
def load_data(self):
if self.train:
file_name = 'train-images-idx3-ubyte.gz'
label_file_name = 'train-labels-idx1-ubyte.gz'
else:
file_name = 't10k-images-idx3-ubyte.gz'
label_file_name = 't10k-labels-idx1-ubyte.gz'
with gzip.open(os.path.join(self.data_path, file_name), 'rb') as f:
self.images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
with gzip.open(os.path.join(self.data_path, label_file_name), 'rb') as f:
self.labels = np.frombuffer(f.read(), np.uint8, offset=8)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return self.images[idx], self.labels[idx]
#创建dataloader
class DataLoader:
def __init__(self,dataset,batch_size,is_shuffle):
self.dataset = dataset
self.batch_size = batch_size
self.is_shuffle = is_shuffle
self.data_indices = list(range(len(dataset)))
#打乱数据集
def shuffle_data(self):
if self.is_shuffle:
random.shuffle(self.data_indices)
#迭代数据集
def __iter__(self):
self.shuffle_data()
for i in range(0, len(self.dataset), self.batch_size):
batch_indices = self.data_indices[i:i + self.batch_size]
batch_images = []
batch_labels = []
for idx in batch_indices:
image, label = self.dataset[idx]
batch_images.append(image)
batch_labels.append(label)
yield (batch_images, batch_labels)
#下载数据集
#下载训练集
trainData=Dataset(path,train=True)
#下载测试集
testData=Dataset(path,train=False)
#shuffle=true是用来打乱数据集的
train_Dataloader = DataLoader(trainData,batch_size=Batch_Size,is_shuffle=True)
test_DataLoader = DataLoader(testData,batch_size=Batch_Size,is_shuffle=False)
class Net(torch.nn.Module):
#构造函数
def __init__(self):
#继承父类
super(Net,self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(1, 10, kernel_size=5),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2),
)
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(10, 20, kernel_size=5),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2),
)
self.fc = torch.nn.Sequential(
torch.nn.Linear(320, 50),
torch.nn.Linear(50, 10),
)
def forward(self,x):
batch_size = x.size(0)
x = self.conv1(x) # 一层卷积层,一层池化层,一层激活层(图是先卷积后激活再池化,差别不大)
x = self.conv2(x) # 再来一次
x = x.view(batch_size, -1) # flatten 变成全连接网络需要的输入 (batch, 20,4,4) ==> (batch,320), -1 此处自动算出的是320
x = self.fc(x)
return x
EPOCH = 100
model = Net().to(device)
#使用交叉墒损失做损失函数
sunshi = torch.nn.CrossEntropyLoss()
#优化器:随机梯度下降
#lr=学习率momentum = 冲量
optimizer = torch.optim.SGD(model.parameters(),lr=0.25,momentum=0.25)
#训练
def train(epoch):
running_loss = 0.0
running_total=0
running_correct = 0
for batch_idx, (inputs, target) in enumerate(train_Dataloader, 0):
inputs = torch.stack([torch.tensor(img, dtype=torch.float) for img in inputs]).to(device)
target = torch.tensor(target, dtype=torch.long).to(device)
#梯度归零
optimizer.zero_grad()
outputs = model(inputs)
loss = sunshi(outputs,target)
#反向传播
loss.backward()
optimizer.step()
running_loss += loss.item()
#准确率
_,predicted = torch.max(outputs,dim=1)
running_total+=inputs.shape[0]
running_correct += (predicted == target).sum().item()
print('[%d,%5d]:loss:%.3f,acc:%.2f',epoch+1,batch_idx+1,running_loss,running_correct/running_total)
#测试
def test():
correct =0
total = 0
with torch.no_grad():
for data in test_DataLoader:
images,labels = data
outputs = model(images)
predicted = torch.max(outputs.data,dim=1)
total += labels.size(0)
correct +=(predicted == labels).sum().item()
accuracy = correct/total
print('[%d/%d]Accuracy: %.lf %%', epoch+1,EPOCH,accuracy)
return accuracy
if __name__ =='__main__':
acc_list_test =[]
for epoch in range(EPOCH):
train(epoch)
#每训练10轮测试一次
if epoch % 10 ==9:
acc_test = test()
acc_list_test.append(acc_test)
plt.plot(acc_list_test)
plt.xlabel('Epoch')
plt.ylabel('Accuracy On TestSet')
plt.show()