while/main.py

# Import modules
import torch
import os
from torchvision import transforms
import random
import gzip
import numpy as np
from enum import Enum
from matplotlib import pyplot as plt
import torch.nn.functional as F
from PIL import Image
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Use torchvision.transforms to convert each image to a tensor and normalize it
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
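# What the transform produces (assuming 28x28 grayscale MNIST images): ToTensor() yields a
# float tensor of shape (1, 28, 28) with values in [0, 1], and Normalize(mean=[0.5], std=[0.5])
# then maps each pixel x to (x - 0.5) / 0.5, i.e. into [-1, 1].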
# Build the dataset (MNIST)
path = './MNIST/raw'
EPOCH = 100
Batch_Size = 10
# Create the dataset
class Dataset:
    def __init__(self, data_path, train=True):
        self.data_path = data_path
        self.train = train
        self.images = []
        self.labels = []
        self.load_data()

    def load_data(self):
        # Read the raw gzip-compressed MNIST files
        if self.train:
            file_name = 'train-images-idx3-ubyte.gz'
            label_file_name = 'train-labels-idx1-ubyte.gz'
        else:
            file_name = 't10k-images-idx3-ubyte.gz'
            label_file_name = 't10k-labels-idx1-ubyte.gz'
        with gzip.open(os.path.join(self.data_path, file_name), 'rb') as f:
            self.images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
        with gzip.open(os.path.join(self.data_path, label_file_name), 'rb') as f:
            self.labels = np.frombuffer(f.read(), np.uint8, offset=8)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Convert the raw uint8 array to a PIL image, then to a normalized tensor
        image = transform(Image.fromarray(self.images[idx]))
        label = int(self.labels[idx])
        return image, label
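# Background on the offsets in load_data (standard MNIST IDX layout): the image files begin with
# a 16-byte header (magic number, image count, rows, cols; 4 bytes each) and the label files with
# an 8-byte header (magic number, label count), hence offset=16 and offset=8.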
# Create the dataloader
class DataLoader:
    def __init__(self, dataset, batch_size, is_shuffle):
        self.dataset = dataset
        self.batch_size = batch_size
        self.is_shuffle = is_shuffle
        self.data_indices = list(range(len(dataset)))

    # Shuffle the dataset
    def shuffle_data(self):
        if self.is_shuffle:
            random.shuffle(self.data_indices)

    # Iterate over the dataset in batches
    def __iter__(self):
        self.shuffle_data()
        for i in range(0, len(self.dataset), self.batch_size):
            batch_indices = self.data_indices[i:i + self.batch_size]
            batch_images = []
            batch_labels = []
            for idx in batch_indices:
                image, label = self.dataset[idx]
                batch_images.append(image)
                batch_labels.append(label)
            # Stack images into one tensor and convert the labels to a LongTensor
            yield torch.stack(batch_images), torch.tensor(batch_labels)
# Load the datasets
# Training set
trainData = Dataset(path, train=True)
# Test set
testData = Dataset(path, train=False)

# is_shuffle=True shuffles the training data each epoch
train_Dataloader = DataLoader(trainData, batch_size=Batch_Size, is_shuffle=True)
test_DataLoader = DataLoader(testData, batch_size=Batch_Size, is_shuffle=False)
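# A quick usage sketch of the hand-rolled DataLoader (illustrative only; the training loop below
# does the real iteration): each batch is a float image tensor and a LongTensor of labels, and the
# last batch may be smaller if the dataset size is not a multiple of Batch_Size.
#     images, labels = next(iter(train_Dataloader))   # shapes: (Batch_Size, 1, 28, 28) and (Batch_Size,)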
class Net(torch.nn.Module):
    # Constructor
    def __init__(self):
        # Initialize the parent class
        super(Net, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 10, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(10, 20, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(320, 10),
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv1(x)  # convolution, activation, pooling
        x = self.conv2(x)  # and once more
        x = x.view(batch_size, -1)  # flatten for the fully connected layer: (batch, 20, 4, 4) ==> (batch, 320); -1 works out to 320
        x = self.fc(x)
        return x
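# Shape check for the flatten step above (5x5 convolutions without padding, 2x2 max pooling):
#     (1, 28, 28) --conv1--> (10, 24, 24) --pool--> (10, 12, 12)
#                 --conv2--> (20,  8,  8) --pool--> (20,  4,  4)
# so the flattened feature vector has 20 * 4 * 4 = 320 elements, matching Linear(320, 10).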
model = Net().to(device)

# Use cross-entropy as the loss function
sunshi = torch.nn.CrossEntropyLoss()
# Optimizer: stochastic gradient descent
# lr = learning rate, momentum = momentum factor
optimizer = torch.optim.SGD(model.parameters(), lr=0.25, momentum=0.25)
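# Note: torch.nn.CrossEntropyLoss applies log-softmax internally, which is why the network's
# forward pass returns raw logits without a final softmax layer.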
# Training
def train(epoch):
    running_loss = 0.0
    running_total = 0
    running_correct = 0
    for batch_idx, (inputs, target) in enumerate(train_Dataloader, 0):
        inputs = inputs.to(device)
        target = target.to(device)
        # Zero the gradients
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = sunshi(outputs, target)
        # Backpropagation
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Accuracy
        _, predicted = torch.max(outputs, dim=1)
        running_total += target.size(0)
        running_correct += (predicted == target).sum().item()
    print('[%d, %5d]: loss: %.3f, acc: %.2f'
          % (epoch + 1, batch_idx + 1, running_loss / (batch_idx + 1), running_correct / running_total))
# Testing
def test(epoch):
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_DataLoader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    print('[%d / %d] Accuracy on test set: %.1f %%' % (epoch + 1, EPOCH, 100 * accuracy))
    return accuracy
if __name__ == '__main__':
    acc_list_test = []
    for epoch in range(EPOCH):
        train(epoch)
        # Test once every 10 training epochs
        if epoch % 10 == 9:
            acc_test = test(epoch)
            acc_list_test.append(acc_test)

    # Plot test accuracy (one point per test, i.e. every 10 epochs)
    plt.plot(acc_list_test)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy On TestSet')
    plt.show()