2024-10-26 11:34:07 +08:00

102 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#这些注释都是原来的错误代码。最后是找了AI 解决的(🤓)
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
#import matplotlib.pyplot as plt #可有可无
# Device used for the model and every batch: the first CUDA GPU when one is
# available, otherwise the CPU. The original hard-coded 'cuda:0', which made
# every `.to(device)` call fail on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
class Net(torch.nn.Module):
    """Fully connected MNIST classifier: 784 -> 64 -> 64 -> 64 -> 10.

    Inputs must already be flattened to 28 * 28 = 784 features per row.
    The three hidden layers use ReLU. The output holds log-probabilities over
    the 10 classes (log_softmax), which pairs with
    ``torch.nn.functional.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        hidden = 64
        # Layers are created in this exact order and under these exact names:
        # the names become the state_dict keys, and the order fixes how the
        # random initialisation is drawn.
        self.fc1 = torch.nn.Linear(28 * 28, hidden)
        self.fc2 = torch.nn.Linear(hidden, hidden)
        self.fc3 = torch.nn.Linear(hidden, hidden)
        self.fc4 = torch.nn.Linear(hidden, 10)

    def forward(self, x):
        """Map a (batch, 784) tensor to (batch, 10) log-probabilities."""
        act = x
        for layer in (self.fc1, self.fc2, self.fc3):
            act = torch.nn.functional.relu(layer(act))
        logits = self.fc4(act)
        return torch.nn.functional.log_softmax(logits, dim=1)
def get_data_loader(is_train, batch_size=100, root='', shuffle=True):
    """Return a DataLoader over the MNIST training or test split.

    The dataset is downloaded on first use (``download=True``), and each
    image is converted to a float tensor with ``transforms.ToTensor``.

    Args:
        is_train: ``True`` for the training split, ``False`` for the test split.
        batch_size: samples per batch (default 100, same as before).
        root: directory the dataset is stored under. The default ``''`` is a
            relative path, so files land under the current working directory.
        shuffle: whether the loader reshuffles the samples every epoch
            (default ``True``, same as before).

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    # ToTensor alone does the same job as the original single-element Compose.
    to_tensor = transforms.ToTensor()
    data_set = MNIST(root, train=is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)
def evaluate(test_data, net):
    """Return the classification accuracy of ``net`` on ``test_data``.

    Args:
        test_data: iterable of ``(x, y)`` batches, such as a DataLoader from
            ``get_data_loader``. Each ``x`` is reshaped to ``(-1, 28 * 28)``
            and ``y`` holds the integer class labels.
        net: a ``torch.nn.Module`` mapping flattened images to per-class
            scores. The highest-scoring class is taken as the prediction.

    Returns:
        The fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ZeroDivisionError: if ``test_data`` yields no samples.
    """
    # Send batches to the device the model's weights live on, so this works
    # wherever the caller placed ``net``. Parameter-less models fall back to
    # the module-level ``device``.
    try:
        target = next(net.parameters()).device
    except StopIteration:
        target = device
    # Use inference mode (this matters for dropout/batch-norm layers) and
    # restore the caller's train/eval mode afterwards so training can resume.
    was_training = net.training
    net.eval()
    n_correct = 0
    n_total = 0
    try:
        with torch.no_grad():
            for x, y in test_data:
                x = x.view(-1, 28 * 28).to(target)
                y = y.to(target)
                predicted = net(x).argmax(dim=1)
                n_correct += (predicted == y).sum().item()
                n_total += y.size(0)
    finally:
        net.train(was_training)
    return n_correct / n_total
# def evaluate(test_data, net):
# n_correct = 0
# n_total = 0
# with torch.no_grad():
# for (x, y) in test_data:
# output = net.forward(x.view(-1, 28 * 28))
# for i, output in enumerate(output):
# if torch.argmax(output) == y[i]:
# n_correct += 1
# n_total += 1
# return n_correct / n_total
def main(epochs=100, save_path='./model.pth'):
    """Train ``Net`` on MNIST, print test accuracy after each epoch, then save it.

    Args:
        epochs: number of passes over the training set (default 100, same as
            before).
        save_path: file the trained model is written to with ``torch.save``
            (default ``'./model.pth'``, same as before).
    """
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net().to(device)
    print('initial accuracy', evaluate(test_data, net))
    optimizer = torch.optim.Adam(net.parameters())
    for epoch in range(epochs):
        # Make sure the model is in training mode at the start of every epoch.
        net.train()
        for x, y in train_data:
            # Flatten the images and move the batch to the model's device.
            x = x.view(-1, 28 * 28).to(device)
            y = y.to(device)
            optimizer.zero_grad()  # clear gradients left from the previous step
            output = net(x)
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()  # without this the weights never change and accuracy stays flat
            optimizer.step()
        print('epoch', epoch, 'accuracy', evaluate(test_data, net))
    # Saves the whole pickled module, not just a state_dict, so loading this
    # file later requires the ``Net`` class to be importable.
    torch.save(net, save_path)
# Earlier (buggy) version of main(), kept for reference only. It moved the
# model to `device` but never moved the batches, so it failed whenever the
# model was on the GPU. It also plotted sample predictions with matplotlib,
# whose import is commented out at the top of the file.
# def main():
# train_data = get_data_loader(is_train=True)
# test_data = get_data_loader(is_train=False)
# net = Net().to(device)
# print('initial accuracy', evaluate(test_data, net))
# optimizer = torch.optim.Adam(net.parameters())
# for epoch in range(3):
# for (x, y) in train_data:
# net.zero_grad()
# output = net.forward(x.view(-1, 28 * 28))
# loss = torch.nn.functional.nll_loss(output, y)
# loss.backward()
# optimizer.step()
# print('epoch', epoch, 'accuracy', evaluate(test_data, net))
# for (n, (x, _)) in enumerate(test_data):
# if n > 100:
# break
# predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
# plt.figure(n)
# plt.imshow(x[0].view(28, 28))
# plt.title('prediction:' + str(int(predict)))
# plt.show()

# Single script entry point. The file previously had this guard twice, so
# running it trained (and saved) the model twice, from scratch each time.
if __name__ == "__main__":
    main()