-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconvis.py
110 lines (87 loc) · 4.02 KB
/
convis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from CaffeLoader import loadCaffemodel
import argparse
# Command-line interface. Arguments are parsed once, at import time, into the
# module-level `params` namespace that main() reads its configuration from.
parser = argparse.ArgumentParser()
parser.add_argument("-input_image",
                    default='examples/inputs/tubingen.jpg',
                    help="Input target image")
parser.add_argument("-image_size",
                    type=int,
                    default=512,
                    help="Maximum height / width of generated image")
parser.add_argument("-model_file",
                    type=str,
                    default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-layer",
                    default='relu2_2',
                    help="layers for examination")
parser.add_argument("-pooling",
                    type=str,
                    default='max',
                    help="max or avg pooling")
parser.add_argument("-output_image", default='out.png')
parser.add_argument("-output_dir", default='output')
parser.add_argument("-disable_check", action='store_true')
params = parser.parse_args()

# Raise Pillow's decompression-bomb limit to one billion pixels so that very
# large (gigapixel) input images can be opened.
Image.MAX_IMAGE_PIXELS = 10 ** 9
def main():
    """Save every channel of one layer's activations as a separate image.

    Configuration comes from the module-level ``params`` namespace:

    * loads the model from ``params.model_file`` via ``loadCaffemodel``;
    * keeps its conv / ReLU / pooling layers up to and including the layer
      named ``params.layer``;
    * runs ``params.input_image`` (resized per ``params.image_size``) through
      that truncated network;
    * writes one image per output channel into ``params.output_dir``. The
      filename is ``<stem>-<layer>-<channel index><ext>``, where stem and
      extension come from ``params.output_image``.

    Raises:
        ValueError: if ``params.layer`` does not match any conv/ReLU/pool
            layer name of the loaded model. Previously the whole network was
            used silently in that case.
    """
    # Build the model definition and setup pooling layers:
    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, 'c', params.disable_check)
    img, image_size = preprocess(params.input_image, params.image_size)
    output_filename, file_extension = os.path.splitext(params.output_image)

    # exist_ok replaces a blanket `except OSError: pass`, which also hid real
    # failures (permission denied, a regular file occupying the path, ...).
    os.makedirs(params.output_dir, exist_ok=True)

    net = _truncate_at_layer(copy.deepcopy(cnn), layerList, params.layer)

    # Inference only: no autograd graph is needed to read the activations.
    with torch.no_grad():
        activations = net(img).squeeze(0)

    # Iterating a tensor walks its first dimension, i.e. one channel at a time.
    for i, channel in enumerate(activations):
        # Copy the single channel into all three planes so deprocess() can
        # treat it as a 3-channel image. Only this channel is copied; the old
        # code cloned the whole activation tensor on every iteration.
        rgb = channel.unsqueeze(0).repeat(3, 1, 1)
        basename = "{}-{}-{}{}".format(output_filename, params.layer, i, file_extension)
        filename = os.path.join(str(params.output_dir), basename)
        deprocess(rgb, image_size, filename)
        print("Saving image: " + filename)


def _truncate_at_layer(cnn, layerList, target):
    """Return an nn.Sequential of cnn's conv/ReLU/pool layers up to ``target``.

    Layer names are taken in order from ``layerList['C']`` (convolutions),
    ``layerList['R']`` (ReLUs) and ``layerList['P']`` (max/avg pooling). The
    layer whose name equals ``target`` is the last one included. Other layer
    types are skipped.

    Raises:
        ValueError: if no conv/ReLU/pool layer is named ``target``.
    """
    net = nn.Sequential()
    counters = {'C': 0, 'R': 0, 'P': 0}
    for layer in cnn:
        if isinstance(layer, nn.Conv2d):
            kind = 'C'
        elif isinstance(layer, nn.ReLU):
            kind = 'R'
        elif isinstance(layer, (nn.MaxPool2d, nn.AvgPool2d)):
            kind = 'P'
        else:
            continue
        net.add_module(str(len(net)), layer)
        name = layerList[kind][counters[kind]]
        counters[kind] += 1
        if name == target:
            return net
    raise ValueError(
        "Layer {!r} was not found among the model's conv/relu/pool layers".format(target))
# Preprocess an image before passing it to a model.
def preprocess(image_name, image_size):
    """Load an image file and convert it into the model's input format.

    Steps:

    1. Load the image as RGB.
    2. Resize it.
    3. Scale pixel values from [0, 1] by 256.
    4. Reorder channels RGB -> BGR.
    5. Subtract the BGR mean pixel (103.939, 116.779, 123.68).

    The BGR, mean-subtracted layout presumably matches what the Caffe-derived
    model weights expect.

    Args:
        image_name: path of the image file to load.
        image_size: either a (height, width) tuple, used as-is, or a number
            giving the target length of the longer side. In the second case
            the other side is scaled to keep the aspect ratio.

    Returns:
        A ``(tensor, image_size)`` pair. ``tensor`` has a leading batch
        dimension of 1 (``unsqueeze(0)``), and ``image_size`` is the
        (height, width) tuple actually used for resizing.
    """
    # Image.open is lazy and keeps its file handle open. convert() returns an
    # independent, fully loaded copy, so the file can be closed immediately.
    with Image.open(image_name) as source:
        image = source.convert('RGB')
    if type(image_size) is not tuple:
        scale = float(image_size) / max(image.size)
        image_size = tuple(int(scale * x) for x in (image.height, image.width))
    loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    # NOTE: values are scaled by 256, not 255. deprocess() divides by the same
    # 256, so the pair stays consistent.
    tensor = loader(image) * 256
    tensor = tensor[torch.LongTensor([2, 1, 0])]  # RGB -> BGR
    normalize = transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])
    return normalize(tensor).unsqueeze(0), image_size
# Undo the preprocessing done by preprocess() and save the tensor as an image.
def deprocess(output_tensor, image_size, output_name):
    """Invert preprocess() on ``output_tensor`` and write the result to disk.

    Steps:

    1. Drop a leading batch dimension of size 1, if present.
    2. Add the BGR mean pixel back.
    3. Reorder channels BGR -> RGB.
    4. Divide by 256 and clamp to [0, 1].
    5. Convert to a PIL image and resize it to ``image_size``.
    6. Save it to ``output_name``.
    """
    chw = output_tensor.squeeze(0)
    # Normalize with negated means and unit std adds the mean pixel back.
    add_mean = transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])
    restored = add_mean(chw)
    rgb = restored[torch.LongTensor([2, 1, 0])] / 256  # BGR -> RGB, back to ~[0, 1]
    rgb.clamp_(0, 1)
    picture = transforms.ToPILImage()(rgb)
    picture = transforms.Resize(image_size)(picture)
    picture.save(str(output_name))
# Run the visualisation only when this file is executed as a script.
# Note: `params` is still parsed at import time, above.
if __name__ == "__main__":
    main()