查看模型中间层输出
创建时间:2020-01-19 00:00
字数:265
阅读:
使用代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 import torch import torch.nn as nn import torch.nn.functional as F from torchvision import models, transforms from PIL import Image import numpy as np import matplotlib.pyplot as plt from torch.autograd import Variable # 图片测试 img = 'data//images//dog.jpg' imgdata = Image.open(img).convert('RGB') out = imgdata.resize((84,84),Image.ANTIALIAS) imgnum = np.array(out) imgnum = imgnum.transpose((2,0,1)) imgtor = torch.from_numpy(imgnum) imgtor = imgtor.squeeze(1) imgtor = torch.unsqueeze(imgtor, 0) imgtor = imgtor.float() imgnum.shape plt.imshow(imgnum[2]) # 中间层特征提取 class FeatureExtractor(nn.Module): def __init__(self, submodule, extracted_layers): super(FeatureExtractor, self).__init__() self.submodule = submodule self.extracted_layers = extracted_layers # 自己修改forward函数 def forward(self, x): outputs = [] for name, module in self.submodule._modules.items(): if name is "fc": x = x.view(x.size(0), -1) x = module(x) if name in self.extracted_layers: outputs.append(x) return outputs extract_list = ["conv1", "maxpool", "layer1", "avgpool", "fc"] model = models.resnet18(pretrained=True) extract_result = FeatureExtractor(model, extract_list) show = extract_result(imgtor)[0].detach().numpy() plt.imshow(show[0][7], cmap='gray') plt.show()
微信:宏沉一笑
公众号:漫步之行
签名:Smile every day 名字:宏沉一笑 邮箱:whghcyx@outlook.com 个人网站:https://whg555.github.io
转载请注明来源,欢迎对文章中的引用来源进行考证,欢迎指出任何有错误或不够清晰的表达。可以在下面评论区评论,也可以邮件至 whghcyx@outlook.com