本文共 2257 字,大约阅读时间需要 7 分钟。
import matplotlib.pyplot as plt
import numpy as np from torch import nn, optim from torch.autograd import Variable def test_network(net, trainloader):criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)dataiter = iter(trainloader)
images, labels = dataiter.next()# Create Variables for the inputs and targets
inputs = Variable(images) targets = Variable(images)# Clear the gradients from all Variables
optimizer.zero_grad()# Forward pass, then backward pass, then update weights
output = net.forward(inputs) loss = criterion(output, targets) loss.backward() optimizer.step()return True
def imshow(image, ax=None, title=None, normalize=True): """Imshow for Tensor.""" if ax is None: fig, ax = plt.subplots() image = image.numpy().transpose((1, 2, 0))if normalize:
mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) image = std * image + mean image = np.clip(image, 0, 1)ax.imshow(image)
ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['left'].set_visible(False) ax.spines['bottom'].set_visible(False) ax.tick_params(axis='both', length=0) ax.set_xticklabels('') ax.set_yticklabels('')return ax
def view_recon(img, recon): ''' Function for displaying an image (as a PyTorch Tensor) and its reconstruction also a PyTorch Tensor '''fig, axes = plt.subplots(ncols=2, sharex=True, sharey=True)
axes[0].imshow(img.numpy().squeeze()) axes[1].imshow(recon.data.numpy().squeeze()) for ax in axes: ax.axis('off') ax.set_adjustable('box-forced')def view_classify(img, ps, version="MNIST"):
''' Function for viewing an image and it's predicted classes. ''' ps = ps.data.numpy().squeeze()fig, (ax1, ax2) = plt.subplots(figsize=(6,9), ncols=2)
ax1.imshow(img.resize_(1, 28, 28).numpy().squeeze()) ax1.axis('off') ax2.barh(np.arange(10), ps) ax2.set_aspect(0.1) ax2.set_yticks(np.arange(10)) if version == "MNIST": ax2.set_yticklabels(np.arange(10)) elif version == "Fashion": ax2.set_yticklabels(['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot'], size='small'); ax2.set_title('Class Probability') ax2.set_xlim(0, 1.1)plt.tight_layout()
————————————————
原文链接:https://blog.csdn.net/qq_28368377/article/details/105588250