要跑很多次
Some checks failed
/ job1 (push) Failing after 9m11s

This commit is contained in:
kemna 2024-10-26 10:00:14 +08:00
parent 7fdb346762
commit 1c7cbe2987

View File

@ -28,7 +28,7 @@ trainData=torchvision.datasets.MNIST(path,train=True,transform=transform,downloa
testData=torchvision.datasets.MNIST(path,train=False,transform=transform,download=False)
#使用dataloader方法开始训练
#设定batch大小
BATCH_SIZE=2048
BATCH_SIZE=1000
#构建dataloader
TrainDataLoader = torch.utils.data.DataLoader(dataset = trainData,batch_size=BATCH_SIZE)
TestDataLoader = torch.utils.data.DataLoader(dataset=testData,batch_size=BATCH_SIZE)
@ -40,7 +40,7 @@ class Net(torch.nn.Module):
super(Net,self).__init__()
self.model = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,stride=1,padding=1),
torch.nn.ReLU(),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size = 2,stride = 2),
#The size of the picture is 14x14
@ -104,5 +104,4 @@ for epochs in range(0,Epochs):
history['Test Accuracy'].append(testAccuracy.item())
processBar.set_description("[%d/%d] Loss: %.4f, Acc: %.4f, Test Loss: %.4f, Test Acc: %.4f" %
(epochs,Epochs,loss.item(),accuracy.item(),testLoss.item(),testAccuracy.item()))
processBar.close()
torch.save(net,'./model.pth')
processBar.close()