bug一堆
Some checks failed
/ job1 (push) Failing after 5m47s

This commit is contained in:
kemna 2024-10-29 20:18:54 +08:00
parent d3e281700e
commit 8dd08b8fbf

87
main.py
View File

@ -1,38 +1,86 @@
#导入模块
import torch
import torchvision
import os
from torchvision import transforms
import random
import gzip
import numpy as np
from enum import Enum
import matplotlib as plt
from matplotlib import pyplot as plt
import torch.nn.functional as F
from PIL import Image
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#使用torchvisiom.transform将图片转换为张量
transform =torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=[0.5],std=[0.5])])
transform =transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.5])])
#构建数据集使用MINIST
path ='./MNIST'
path ='./MNIST/raw'
EPOCH =10
Batch_Size = 64
#创建dataset
class DataSet:
def __init__(self,floder,data_name,label_name,transform=None):
self.floder = floder
self.transform = transform
class Dataset:
def __init__(self, data_path, train=True):
self.data_path = data_path
self.train = train
self.images = []
self.labels = []
self.load_data()
def load_data(self):
if self.train:
file_name = 'train-images-idx3-ubyte.gz'
label_file_name = 'train-labels-idx1-ubyte.gz'
else:
file_name = 't10k-images-idx3-ubyte.gz'
label_file_name = 't10k-labels-idx1-ubyte.gz'
with gzip.open(os.path.join(self.data_path, file_name), 'rb') as f:
self.images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
with gzip.open(os.path.join(self.data_path, label_file_name), 'rb') as f:
self.labels = np.frombuffer(f.read(), np.uint8, offset=8)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return self.images[idx], self.labels[idx]
#创建dataloader
class DataLoader:
def __init__(self,type,batch_size,is_shuffle):
data_name = 'train-images-idx3-ubyte.gz' if type=='train' else 't10k-images-idx3-ubyte.gz'
label_name = 'train-labels-idx1-ubyte.gz' if type=='train' else 't10k-labels-idx1-ubyte.gz'
def __init__(self,dataset,batch_size,is_shuffle):
self.dataset = dataset
self.batch_size = batch_size
self.is_shuffle = is_shuffle
self.data_indices = list(range(len(dataset)))
#打乱数据集
def shuffle_data(self):
if self.is_shuffle:
random.shuffle(self.data_indices)
#迭代数据集
def __iter__(self):
self.shuffle_data()
for i in range(0, len(self.dataset), self.batch_size):
batch_indices = self.data_indices[i:i + self.batch_size]
batch_images = []
batch_labels = []
for idx in batch_indices:
image, label = self.dataset[idx]
batch_images.append(image)
batch_labels.append(label)
yield (batch_images, batch_labels)
#下载数据集
#下载训练集
trainData=torchvision.datasets.MNIST(path,train=True,transform=transform,download=True)
trainData=Dataset(path,train=True)
#下载测试集
testData=torchvision.datasets.MNIST(path,train=False,transform=transform,download=False)
testData=Dataset(path,train=False)
#shuffle=true是用来打乱数据集的
train_Dataloader = DataLoader(train,batch_size = Batch_Size,shuffle = True)
test_DataLoader = DataLoader(test,batch_size=Batch_Size,shuffle=False)
train_Dataloader = DataLoader(trainData,batch_size=Batch_Size,is_shuffle=True)
test_DataLoader = DataLoader(testData,batch_size=Batch_Size,is_shuffle=False)
class Net(torch.nn.Module):
#构造函数
def __init__(self):
@ -60,7 +108,7 @@ class Net(torch.nn.Module):
x = self.fc(x)
return x
EPOCH = 100
model = Net().to(device)
#使用交叉墒损失做损失函数
sunshi = torch.nn.CrossEntropyLoss()
@ -72,10 +120,9 @@ def train(epoch):
running_loss = 0.0
running_total=0
running_correct = 0
for batch_idx,data in enumerate(train_Dataloader,0):
inputs, target = data
inputs = inputs.to(device)
target = target.to(device)
for batch_idx, (inputs, target) in enumerate(train_Dataloader, 0):
inputs = torch.stack([torch.tensor(img, dtype=torch.float) for img in inputs]).to(device)
target = torch.tensor(target, dtype=torch.long).to(device)
#梯度归零
optimizer.zero_grad()