parent
7fdb346762
commit
1c7cbe2987
7
main.py
7
main.py
@ -28,7 +28,7 @@ trainData=torchvision.datasets.MNIST(path,train=True,transform=transform,downloa
|
||||
testData=torchvision.datasets.MNIST(path,train=False,transform=transform,download=False)
|
||||
#使用dataloader方法开始训练
|
||||
#设定batch大小
|
||||
BATCH_SIZE=2048
|
||||
BATCH_SIZE=1000
|
||||
#构建dataloader
|
||||
TrainDataLoader = torch.utils.data.DataLoader(dataset = trainData,batch_size=BATCH_SIZE)
|
||||
TestDataLoader = torch.utils.data.DataLoader(dataset=testData,batch_size=BATCH_SIZE)
|
||||
@ -40,7 +40,7 @@ class Net(torch.nn.Module):
|
||||
super(Net,self).__init__()
|
||||
self.model = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=1,out_channels=16,kernel_size=3,stride=1,padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.MaxPool2d(kernel_size = 2,stride = 2),
|
||||
|
||||
#The size of the picture is 14x14
|
||||
@ -104,5 +104,4 @@ for epochs in range(0,Epochs):
|
||||
history['Test Accuracy'].append(testAccuracy.item())
|
||||
processBar.set_description("[%d/%d] Loss: %.4f, Acc: %.4f, Test Loss: %.4f, Test Acc: %.4f" %
|
||||
(epochs,Epochs,loss.item(),accuracy.item(),testLoss.item(),testAccuracy.item()))
|
||||
processBar.close()
|
||||
torch.save(net,'./model.pth')
|
||||
processBar.close()
|
Loading…
x
Reference in New Issue
Block a user