bug一堆
All checks were successful
/ job1 (push) Successful in 32m10s

This commit is contained in:
kemna 2024-10-29 21:33:53 +08:00
parent 8dd08b8fbf
commit 7d746726b9

29
main.py
View File

@ -17,7 +17,7 @@ transform =transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[
path ='./MNIST/raw'
EPOCH =10
Batch_Size = 64
Batch_Size = 10
#创建dataset
class Dataset:
@ -73,6 +73,7 @@ class DataLoader:
batch_images.append(image)
batch_labels.append(label)
yield (batch_images, batch_labels)
torch.tensor(batch_labels)
#下载数据集
#下载训练集
trainData=Dataset(path,train=True)
@ -87,18 +88,17 @@ class Net(torch.nn.Module):
#继承父类
super(Net,self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(1, 10, kernel_size=5),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2),
)
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(10, 20, kernel_size=5),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2),
)
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(20, 40, kernel_size=5),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=2),
)
self.fc = torch.nn.Sequential(
torch.nn.Linear(320, 50),
torch.nn.Linear(50, 10),
torch.nn.Linear(64,10),
)
def forward(self,x):
batch_size = x.size(0)
@ -135,22 +135,23 @@ def train(epoch):
running_loss += loss.item()
#准确率
_,predicted = torch.max(outputs,dim=1)
running_total+=inputs.shape[0]
running_total+=target.size(0)
running_correct += (predicted == target).sum().item()
print('[%d,%5d]:loss:%.3f,acc:%.2f',epoch+1,batch_idx+1,running_loss,running_correct/running_total)
print(f'[%d,%5d]:loss:%.3f,acc:%.2f',epoch+1,batch_idx+1,running_loss,running_correct/running_total)
#测试
def test():
correct =0
total = 0
with torch.no_grad():
for data in test_DataLoader:
images,labels = data
for images, labels in test_DataLoader:
images = torch.stack([torch.tensor(img, dtype=torch.float) for img in images]).to(device)
labels = torch.tensor(labels, dtype=torch.long).to(device)
outputs = model(images)
predicted = torch.max(outputs.data,dim=1)
_,predicted = torch.max(outputs.data,dim=1)
total += labels.size(0)
correct +=(predicted == labels).sum().item()
accuracy = correct/total
print('[%d/%d]Accuracy: %.lf %%', epoch+1,EPOCH,accuracy)
print(f'[%d/%d]Accuracy: %.lf %%', epoch+1,EPOCH,accuracy)
return accuracy