parent
d3e281700e
commit
8dd08b8fbf
87
main.py
87
main.py
@ -1,38 +1,86 @@
|
||||
#导入模块
|
||||
import torch
|
||||
import torchvision
|
||||
import os
|
||||
from torchvision import transforms
|
||||
import random
|
||||
import gzip
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
import matplotlib as plt
|
||||
from matplotlib import pyplot as plt
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
#使用torchvisiom.transform将图片转换为张量
|
||||
|
||||
transform =torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=[0.5],std=[0.5])])
|
||||
transform =transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.5])])
|
||||
#构建数据集(使用MINIST)
|
||||
path ='./MNIST'
|
||||
path ='./MNIST/raw'
|
||||
|
||||
EPOCH =10
|
||||
Batch_Size = 64
|
||||
#创建dataset
|
||||
class DataSet:
|
||||
def __init__(self,floder,data_name,label_name,transform=None):
|
||||
self.floder = floder
|
||||
self.transform = transform
|
||||
|
||||
class Dataset:
|
||||
def __init__(self, data_path, train=True):
|
||||
self.data_path = data_path
|
||||
self.train = train
|
||||
self.images = []
|
||||
self.labels = []
|
||||
self.load_data()
|
||||
|
||||
def load_data(self):
|
||||
if self.train:
|
||||
file_name = 'train-images-idx3-ubyte.gz'
|
||||
label_file_name = 'train-labels-idx1-ubyte.gz'
|
||||
else:
|
||||
file_name = 't10k-images-idx3-ubyte.gz'
|
||||
label_file_name = 't10k-labels-idx1-ubyte.gz'
|
||||
|
||||
with gzip.open(os.path.join(self.data_path, file_name), 'rb') as f:
|
||||
self.images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
|
||||
|
||||
with gzip.open(os.path.join(self.data_path, label_file_name), 'rb') as f:
|
||||
self.labels = np.frombuffer(f.read(), np.uint8, offset=8)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.images[idx], self.labels[idx]
|
||||
|
||||
#创建dataloader
|
||||
class DataLoader:
|
||||
def __init__(self,type,batch_size,is_shuffle):
|
||||
data_name = 'train-images-idx3-ubyte.gz' if type=='train' else 't10k-images-idx3-ubyte.gz'
|
||||
label_name = 'train-labels-idx1-ubyte.gz' if type=='train' else 't10k-labels-idx1-ubyte.gz'
|
||||
|
||||
def __init__(self,dataset,batch_size,is_shuffle):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
self.is_shuffle = is_shuffle
|
||||
|
||||
self.data_indices = list(range(len(dataset)))
|
||||
#打乱数据集
|
||||
def shuffle_data(self):
|
||||
if self.is_shuffle:
|
||||
random.shuffle(self.data_indices)
|
||||
#迭代数据集
|
||||
def __iter__(self):
|
||||
self.shuffle_data()
|
||||
for i in range(0, len(self.dataset), self.batch_size):
|
||||
batch_indices = self.data_indices[i:i + self.batch_size]
|
||||
batch_images = []
|
||||
batch_labels = []
|
||||
for idx in batch_indices:
|
||||
image, label = self.dataset[idx]
|
||||
batch_images.append(image)
|
||||
batch_labels.append(label)
|
||||
yield (batch_images, batch_labels)
|
||||
#下载数据集
|
||||
#下载训练集
|
||||
trainData=torchvision.datasets.MNIST(path,train=True,transform=transform,download=True)
|
||||
trainData=Dataset(path,train=True)
|
||||
#下载测试集
|
||||
testData=torchvision.datasets.MNIST(path,train=False,transform=transform,download=False)
|
||||
testData=Dataset(path,train=False)
|
||||
#shuffle=true是用来打乱数据集的
|
||||
train_Dataloader = DataLoader(train,batch_size = Batch_Size,shuffle = True)
|
||||
test_DataLoader = DataLoader(test,batch_size=Batch_Size,shuffle=False)
|
||||
train_Dataloader = DataLoader(trainData,batch_size=Batch_Size,is_shuffle=True)
|
||||
test_DataLoader = DataLoader(testData,batch_size=Batch_Size,is_shuffle=False)
|
||||
class Net(torch.nn.Module):
|
||||
#构造函数
|
||||
def __init__(self):
|
||||
@ -60,7 +108,7 @@ class Net(torch.nn.Module):
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
EPOCH = 100
|
||||
model = Net().to(device)
|
||||
#使用交叉墒损失做损失函数
|
||||
sunshi = torch.nn.CrossEntropyLoss()
|
||||
@ -72,10 +120,9 @@ def train(epoch):
|
||||
running_loss = 0.0
|
||||
running_total=0
|
||||
running_correct = 0
|
||||
for batch_idx,data in enumerate(train_Dataloader,0):
|
||||
inputs, target = data
|
||||
inputs = inputs.to(device)
|
||||
target = target.to(device)
|
||||
for batch_idx, (inputs, target) in enumerate(train_Dataloader, 0):
|
||||
inputs = torch.stack([torch.tensor(img, dtype=torch.float) for img in inputs]).to(device)
|
||||
target = torch.tensor(target, dtype=torch.long).to(device)
|
||||
#梯度归零
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user