查看模型中间层输出

  1. 使用代码

使用代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import models, transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable

# 图片测试
img = 'data//images//dog.jpg'

imgdata = Image.open(img).convert('RGB')
out = imgdata.resize((84,84),Image.ANTIALIAS)
imgnum = np.array(out)
imgnum = imgnum.transpose((2,0,1))
imgtor = torch.from_numpy(imgnum)
imgtor = imgtor.squeeze(1)
imgtor = torch.unsqueeze(imgtor, 0)
imgtor = imgtor.float()
imgnum.shape

plt.imshow(imgnum[2])

# 中间层特征提取
class FeatureExtractor(nn.Module):
def __init__(self, submodule, extracted_layers):
super(FeatureExtractor, self).__init__()
self.submodule = submodule
self.extracted_layers = extracted_layers

# 自己修改forward函数
def forward(self, x):
outputs = []
for name, module in self.submodule._modules.items():
if name is "fc": x = x.view(x.size(0), -1)
x = module(x)
if name in self.extracted_layers:
outputs.append(x)
return outputs

extract_list = ["conv1", "maxpool", "layer1", "avgpool", "fc"]
model = models.resnet18(pretrained=True)
extract_result = FeatureExtractor(model, extract_list)

show = extract_result(imgtor)[0].detach().numpy()

plt.imshow(show[0][7], cmap='gray')
plt.show()

微信:宏沉一笑
公众号:漫步之行

签名:Smile every day
名字:宏沉一笑
邮箱:whghcyx@outlook.com
个人网站:https://whg555.github.io



转载请注明来源,欢迎对文章中的引用来源进行考证,欢迎指出任何有错误或不够清晰的表达。可以在下面评论区评论,也可以邮件至 whghcyx@outlook.com

文章标题:查看模型中间层输出

文章字数:265

本文作者:宏沉一笑

发布时间:2020-01-19, 00:00:00

最后更新:2024-03-21, 12:53:34

原始链接:https://whghcyx.gitee.io/2020/01/19/AI-2020-1-19-%E6%9F%A5%E7%9C%8B%E6%A8%A1%E5%9E%8B%E4%B8%AD%E9%97%B4%E5%B1%82%E8%BE%93%E5%87%BA/

版权声明: "署名-非商用-相同方式共享 4.0" 转载请保留原文链接及作者。

目录
×

喜欢就点赞,疼爱就打赏