Created
May 5, 2020 02:39
-
-
Save xuhdev/58c494ccfb6ed3f8236b85fc1e4964b3 to your computer and use it in GitHub Desktop.
Revisions
-
xuhdev created this gist
May 5, 2020 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,31 @@ for epoch in range(1): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(train_loader, 0): # get the inputs; data is a list of [inputs, labels] inputs, labels = data # zero the parameter gradients optimizer.zero_grad() # forward outputs = net(inputs) loss = criterion(outputs, labels) # backward (differentiate) loss.backward() # optimize (update) optimizer.step() # print statistics running_loss += loss.item() if i % 3000 == 2999: # print every 3000 mini-batches print(f'Epoch: {epoch + 1}, Iteration: {i + 1}, loss: {running_loss / 3000}') running_loss = 0.0 # Test accuracy correct = 0 total = 0 with torch.no_grad(): for data in test_loader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data, 1) # The label with the maximum probability is predicted total += labels.size(0) correct += (predicted == labels).sum().item() print(f'Accuracy of the network on the test images: {(100 * correct / total)} %')