63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
import torch
|
|
import numpy
|
|
from torch.utils.data import DataLoader
|
|
from torchvision import transforms,datasets
|
|
from torchvision.datasets import MNIST
|
|
import matplotlib.pyplot as plt
|
|
data_tf=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])
|
|
class Net(torch.nn.Module):
|
|
def __init__(self):
|
|
super(Net,self).__init__()
|
|
self.fc1=torch.nn.Linear(28*28,64)
|
|
self.fc2=torch.nn.Linear(64,64)
|
|
self.fc3=torch.nn.Linear(64,64)
|
|
self.fc4=torch.nn.Linear(64,10)
|
|
def forward(self,x):
|
|
x=torch.nn.functional.relu(self.fc1(x))
|
|
x=torch.nn.functional.relu(self.fc2(x))
|
|
x=torch.nn.functional.relu(self.fc3(x))
|
|
x=torch.nn.functional.log_softmax(self.fc4(x),dim=1)
|
|
return x
|
|
|
|
def get_data_loader(is_train):
|
|
to_tensor = transforms.Compose([transforms.ToTensor()])
|
|
data_set = MNIST("", is_train, transform=to_tensor, download=True)
|
|
return DataLoader(data_set, batch_size=15, shuffle=True)
|
|
|
|
def evaluate(test_data,net):
|
|
n_correct=0
|
|
n_total=1
|
|
with torch.no_grad():
|
|
for (x,y) in test_data:
|
|
outputs=net.forward(x.view(-1,28*28))
|
|
for i,output in enumerate(outputs):
|
|
if torch.argmax(output)== y[i]:
|
|
n_correct+=1
|
|
n_total+=1
|
|
return n_correct/n_total
|
|
|
|
def main():
|
|
train_data =get_data_loader(is_train=True)
|
|
|
|
test_data =get_data_loader(is_train=False)
|
|
net =Net()
|
|
print("initial caauracy",evaluate(test_data,net))
|
|
optimizer = torch.optim.Adam(net.parameters(),lr=0.001)
|
|
for epoch in range(2):
|
|
for (x,y) in train_data:
|
|
net.zero_grad()
|
|
output=net.forward(x.view(-1,28*28))
|
|
loss=torch.nn.functional.nll_loss(output,y)
|
|
loss.backward()
|
|
optimizer.step()
|
|
print("epoch",epoch,"accuracy:",evaluate(test_data,net))
|
|
for (n,(x,_)) in enumerate(test_data):
|
|
if n>3:
|
|
break
|
|
predict=torch.argmax(net.forward(x[0].view(-1,28*28)))
|
|
plt.figure(n)
|
|
plt.imshow(x[0].view(28,28))
|
|
plt.title("prediction:"+str(int(predict)))
|
|
plt.show()
|
|
if __name__=="__main__":
|
|
main() |