博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
torch helper文件
阅读量:4145 次
发布时间:2019-05-25

本文共 2257 字,大约阅读时间需要 7 分钟。

import matplotlib.pyplot as plt

import numpy as np
from torch import nn, optim
from torch.autograd import Variable

def test_network(net, trainloader):

    criterion = nn.MSELoss()

    optimizer = optim.Adam(net.parameters(), lr=0.001)

    dataiter = iter(trainloader)

    images, labels = dataiter.next()

    # Create Variables for the inputs and targets

    inputs = Variable(images)
    targets = Variable(images)

    # Clear the gradients from all Variables

    optimizer.zero_grad()

    # Forward pass, then backward pass, then update weights

    output = net.forward(inputs)
    loss = criterion(output, targets)
    loss.backward()
    optimizer.step()

    return True

def imshow(image, ax=None, title=None, normalize=True):
    """Imshow for Tensor."""
    if ax is None:
        fig, ax = plt.subplots()
    image = image.numpy().transpose((1, 2, 0))

    if normalize:

        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = std * image + mean
        image = np.clip(image, 0, 1)

    ax.imshow(image)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')

    return ax

def view_recon(img, recon):
    ''' Function for displaying an image (as a PyTorch Tensor) and its
        reconstruction also a PyTorch Tensor
    '''

    fig, axes = plt.subplots(ncols=2, sharex=True, sharey=True)

    axes[0].imshow(img.numpy().squeeze())
    axes[1].imshow(recon.data.numpy().squeeze())
    for ax in axes:
        ax.axis('off')
        ax.set_adjustable('box-forced')

def view_classify(img, ps, version="MNIST"):

    ''' Function for viewing an image and it's predicted classes.
    '''
    ps = ps.data.numpy().squeeze()

    fig, (ax1, ax2) = plt.subplots(figsize=(6,9), ncols=2)

    ax1.imshow(img.resize_(1, 28, 28).numpy().squeeze())
    ax1.axis('off')
    ax2.barh(np.arange(10), ps)
    ax2.set_aspect(0.1)
    ax2.set_yticks(np.arange(10))
    if version == "MNIST":
        ax2.set_yticklabels(np.arange(10))
    elif version == "Fashion":
        ax2.set_yticklabels(['T-shirt/top',
                            'Trouser',
                            'Pullover',
                            'Dress',
                            'Coat',
                            'Sandal',
                            'Shirt',
                            'Sneaker',
                            'Bag',
                            'Ankle Boot'], size='small');
    ax2.set_title('Class Probability')
    ax2.set_xlim(0, 1.1)

    plt.tight_layout()

————————————————

原文链接:https://blog.csdn.net/qq_28368377/article/details/105588250

 

 

你可能感兴趣的文章
C 语言学习 --设置文本框内容及进制转换
查看>>
C 语言 学习---判断文本框取得的数是否是整数
查看>>
C 语言 学习---ComboBox相关、简单计算器
查看>>
C 语言 学习---ComboBox相关、简易“假”管理系统
查看>>
C 语言 学习---回调、时间定时更新程序
查看>>
C 语言 学习---复选框及列表框的使用
查看>>
第十一章 - 直接内存
查看>>
JDBC核心技术 - 上篇
查看>>
一篇搞懂Java反射机制
查看>>
Single Number II --出现一次的数(重)
查看>>
Palindrome Partitioning --回文切割 深搜(重重)
查看>>
对话周鸿袆:从程序员创业谈起
查看>>
Mysql中下划线问题
查看>>
Xcode 11 报错,提示libstdc++.6 缺失,解决方案
查看>>
idea的安装以及简单使用
查看>>
Windows mysql 安装
查看>>
python循环语句与C语言的区别
查看>>
vue 项目中图片选择路径位置static 或 assets区别
查看>>
vue项目打包后无法运行报错空白页面
查看>>
Vue 解决部署到服务器后或者build之后Element UI图标不显示问题(404错误)
查看>>