Skip to content

Commit

Permalink
Save activation images for each layer
Browse files Browse the repository at this point in the history
  • Loading branch information
hvy committed Jul 13, 2016
1 parent 16a5ce4 commit 89d6a60
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 227 deletions.
244 changes: 83 additions & 161 deletions VGGVisualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import chainer.links as L
import chainer.functions as F
from chainer import Variable
# from lib.chainer.functions.pooling.max_pooling_2d import max_pooling_2d
from max_pooling_2d import max_pooling_2d
from unpooling_2d import unpooling_2d

Expand All @@ -12,205 +11,128 @@
# Spatial size (two values) associated with the network input.
# NOTE(review): fc6 in this file takes 25088 inputs (= 512 * 7 * 7). Five
# stride-2 poolings of a 244 pixel side only give 7 if the pooling rounds
# down; a 224 pixel side gives 7 exactly. Confirm 244 is intended and is not
# a typo for 224.
input_dimensions = (244, 244)


def from_VGG(vgg):
    """Build a ``VGGVisualizer`` that reuses the layers of ``vgg``.

    Every ``convX_Y`` link of ``vgg`` is assigned onto a freshly created
    visualizer under the same attribute name, and the weight ``W`` of each
    ``deconvX_Y`` link of the visualizer is pointed at the ``W`` of the
    matching convolution. Biases are not transferred. Before returning, a
    few weight shapes and values are printed for inspection.

    Args:
        vgg: Object exposing the attributes ``conv1_1`` ... ``conv5_3``,
            each with a ``W`` parameter.

    Returns:
        The populated ``VGGVisualizer`` instance.
    """
    # Convolution names grouped per block, in forward order.
    conv_blocks = (
        ('conv1_1', 'conv1_2'),
        ('conv2_1', 'conv2_2'),
        ('conv3_1', 'conv3_2', 'conv3_3'),
        ('conv4_1', 'conv4_2', 'conv4_3'),
        ('conv5_1', 'conv5_2', 'conv5_3'),
    )

    visualizer = VGGVisualizer()

    for block in conv_blocks:
        # Share the convolution links themselves.
        for conv_name in block:
            setattr(visualizer, conv_name, getattr(vgg, conv_name))

        # Tie the deconvolution weights to the convolution weights, walking
        # the block backwards ('de' + 'conv1_2' == 'deconv1_2'). Only W is
        # assigned; no bias is copied.
        for conv_name in reversed(block):
            getattr(visualizer, 'de' + conv_name).W = getattr(vgg, conv_name).W

    # Debug output: the tied weights should agree in shape and value.
    print('==================================')
    print(visualizer.deconv1_1.W.data.shape)
    print(vgg.conv1_1.W.data.shape)
    print('==================================')
    print('==================================')
    print('------------------------------------------------------------')
    print(vgg.conv1_1.W.data[0][0][0][0])
    print(visualizer.conv1_1.W.data[0][0][0][0])
    print(visualizer.deconv1_1.W.data[0][0][0][0])
    print('------------------------------------------------------------')

    return visualizer


class VGG(chainer.Chain):
    """VGG network (13 convolutions, 3 fully connected layers) with support
    for projecting feature-map activations back towards the input.

    The convolutions are organised in five blocks, each followed by a 2x2
    max pooling with stride 2. During the forward pass the pooling switches
    and the pre-pooling spatial sizes are recorded so that ``activations``
    can undo the pooling and run the matching deconvolutions in reverse.

    Attributes:
        convs: Names of the convolution links, grouped per block in forward
            order.
        train: When true, ``__call__`` returns the loss and dropout is
            active; otherwise it returns softmax predictions.
        switches: Pooling switches of the most recent forward pass, one
            entry per executed block.
        unpooling_outsizes: Pre-pooling spatial sizes of the most recent
            forward pass, one entry per executed block.
        added_deconv: Whether ``add_deconv_layers`` has already run.
    """

    def __init__(self):
        super(VGG, self).__init__(
            conv1_1=L.Convolution2D(3, 64, 3, stride=1, pad=1),
            conv1_2=L.Convolution2D(64, 64, 3, stride=1, pad=1),

            conv2_1=L.Convolution2D(64, 128, 3, stride=1, pad=1),
            conv2_2=L.Convolution2D(128, 128, 3, stride=1, pad=1),

            conv3_1=L.Convolution2D(128, 256, 3, stride=1, pad=1),
            conv3_2=L.Convolution2D(256, 256, 3, stride=1, pad=1),
            conv3_3=L.Convolution2D(256, 256, 3, stride=1, pad=1),

            conv4_1=L.Convolution2D(256, 512, 3, stride=1, pad=1),
            conv4_2=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv4_3=L.Convolution2D(512, 512, 3, stride=1, pad=1),

            conv5_1=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv5_2=L.Convolution2D(512, 512, 3, stride=1, pad=1),
            conv5_3=L.Convolution2D(512, 512, 3, stride=1, pad=1),

            fc6=L.Linear(25088, 4096),
            fc7=L.Linear(4096, 4096),
            fc8=L.Linear(4096, 1000)
        )
        self.convs = [
            ['conv1_1', 'conv1_2'],
            ['conv2_1', 'conv2_2'],
            ['conv3_1', 'conv3_2', 'conv3_3'],
            ['conv4_1', 'conv4_2', 'conv4_3'],
            ['conv5_1', 'conv5_2', 'conv5_3']]

        self.train = False
        self.switches = []
        self.unpooling_outsizes = []
        self.added_deconv = False

    def __call__(self, x, t=None, stop_layer=None):
        """Run the forward pass, optionally stopping after a pooling layer.

        Args:
            x: Input variable fed to ``conv1_1``.
            t: Targets; only used when ``self.train`` is true.
            stop_layer: 1-based index of the convolution block after whose
                pooling the pass should stop. ``None`` runs the full
                network.

        Returns:
            The pooled output of block ``stop_layer`` if that block is
            reached; otherwise the softmax cross entropy loss when
            ``self.train`` is true, or the softmax predictions when it is
            false.
        """
        # Start from a clean record so indices line up with this pass only.
        self.switches = []
        self.unpooling_outsizes = []

        # Forward pass through convolutional layers with ReLU and pooling
        h = x
        for i, layer in enumerate(self.convs):
            for conv in layer:
                h = F.relu(getattr(self, conv)(h))

            # Remember the spatial size so unpooling can restore it exactly.
            prepooling_size = h.data.shape[2:]
            self.unpooling_outsizes.append(prepooling_size)

            # NOTE(review): this relies on F.max_pooling_2d returning an
            # (output, switches) pair, presumably the project's own
            # max_pooling_2d bound onto F outside this view - confirm.
            h, switches = F.max_pooling_2d(h, 2, stride=2)
            self.switches.append(switches)

            if stop_layer == i + 1:
                return h

        h = F.dropout(F.relu(self.fc6(h)), train=self.train, ratio=0.5)
        h = F.dropout(F.relu(self.fc7(h)), train=self.train, ratio=0.5)
        h = self.fc8(h)

        if self.train:
            self.loss = F.softmax_cross_entropy(h, t)
            self.acc = F.accuracy(h, t)
            return self.loss
        else:
            self.pred = F.softmax(h)
            return self.pred

    def activations(self, x, layer):
        """Project each feature map of a block's output back to the input.

        The network is run up to and including the pooling of block
        ``layer``. Then, for every feature map of that output, all other
        maps are zeroed and the result is passed backwards through
        unpooling, ReLU and deconvolution for blocks ``layer`` down to 1.

        Args:
            x: Input variable holding exactly one image.
            layer: 1-based index of the convolution block to visualize.

        Returns:
            A ``Variable`` stacking one reconstruction per feature map, with
            the batch axis first and the feature-map axis second.

        Raises:
            TypeError: If ``x`` holds more than one image.
            ValueError: If ``layer`` is not a valid block index.
        """
        if x.data.shape[0] != 1:
            # Adjacent literals instead of a backslash continuation, which
            # would embed the source indentation in the message.
            raise TypeError('Visualization is only supported for a single '
                            'image at a time')

        # Without this check an invalid index would run the whole network
        # and fail later with an unrelated error.
        if not 1 <= layer <= len(self.convs):
            raise ValueError('layer must be between 1 and {}, got {}'.format(
                len(self.convs), layer))

        self.add_deconv_layers()

        # Forward pass
        h = self(x, stop_layer=layer)

        # Compute the activations for each feature map
        h_data = h.data.copy()
        xp = chainer.cuda.get_array_module(h_data)
        zeros = xp.zeros_like(h_data)
        convs = self.convs[:layer]
        deconvs = [['de{}'.format(c) for c in conv] for conv in convs]

        feat_maps = []

        for fm in range(h_data.shape[1]):  # For each feature map

            print('Feature map {}'.format(fm))

            condition = zeros.copy()
            condition[0][fm] = 1  # Keep one feature map and zero all other
            h = Variable(xp.where(condition, h_data, zeros))

            # Walk the blocks backwards: undo the pooling, then apply the
            # block's deconvolutions in reverse order.
            for i, deconv in enumerate(reversed(deconvs)):
                block = layer - i - 1
                h = F.unpooling_2d(h, self.switches[block], 2, stride=2,
                                   outsize=self.unpooling_outsizes[block])
                for d in reversed(deconv):
                    h = getattr(self, d)(F.relu(h))

            feat_maps.append(h.data)

        feat_maps = xp.array(feat_maps)
        feat_maps = xp.rollaxis(feat_maps, 0, 2)  # Batch to first axis

        return Variable(feat_maps)

    def add_deconv_layers(self, nobias=True):
        """Add a deconvolutional layer for each convolutional layer already
        defined in the network.

        Each deconvolution is registered as ``'de' + <conv name>`` (for
        example ``deconv1_1``), mirrors the stride and padding of its
        convolution and is initialised with the convolution's weights.
        Calling this method more than once has no further effect.

        Args:
            nobias: Passed to ``L.Deconvolution2D``; when true the
                deconvolutions have no bias term.
        """
        if self.added_deconv:
            return

        # Snapshot the children first: add_link registers new children and
        # must not run while the collection is being iterated.
        for layer in list(self.children()):
            # Convolution2D is a link, so check it through the links module.
            if isinstance(layer, L.Convolution2D):
                out_channels, in_channels, kh, kw = layer.W.data.shape
                # The deconvolution maps the convolution's outputs back to
                # its inputs, hence the swapped channel arguments.
                deconv = L.Deconvolution2D(out_channels, in_channels,
                                           (kh, kw), stride=layer.stride,
                                           pad=layer.pad,
                                           initialW=layer.W.data,
                                           nobias=nobias)
                self.add_link('de{}'.format(layer.name), deconv)

        self.added_deconv = True
3 changes: 3 additions & 0 deletions activations/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Activations

Feature map activations are stored in this directory.
45 changes: 45 additions & 0 deletions imgutil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
import math
import numpy as np
import cv2 as cv
import matplotlib
matplotlib.use('Agg') # Workaround to save images when running over ssh sessions
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


def tile_ims(filename, directory):
    """Tile every image found in ``directory`` into a single image file.

    The directory entries are read in sorted name order with
    ``matplotlib.image.imread``, collected into one array and handed to
    ``save_ims``, which writes the result to ``filename``.
    """
    entry_names = sorted(os.listdir(directory))

    loaded = []
    for entry in entry_names:
        loaded.append(mpimg.imread(os.path.join(directory, entry)))

    save_ims(filename, np.array(loaded))


def save_im(filename, im):
    """Write a single image to ``filename`` with OpenCV.

    Args:
        filename: Destination path; OpenCV chooses the format from the
            extension.
        im: Image array to write (presumably laid out as height, width,
            channels - confirm against callers).

    Raises:
        IOError: If OpenCV reports that the image could not be written, for
            example because the target directory does not exist.
    """
    # cv.imwrite signals failure through its return value rather than an
    # exception; ignoring it would silently drop the image.
    if not cv.imwrite(filename, im):
        raise IOError('Could not write image to {}'.format(filename))


def save_ims(filename, ims, figsize=(64, 64)):
    """Plot a stack of images on a near-square grid and save it as one file.

    Args:
        filename: Destination path of the tiled image.
        ims: Array whose first axis indexes the images; each ``ims[i]`` is
            drawn with ``imshow``.
        figsize: Figure size in inches, forwarded to ``plt.subplots``.
            Defaults to the previously hard-coded 64 x 64.

    Raises:
        ValueError: If ``ims`` contains no images.
    """
    n = ims.shape[0]
    if n == 0:
        raise ValueError('save_ims requires at least one image')

    # Plot the images on a grid. ceil(sqrt(n)) rows and round(sqrt(n))
    # columns always give at least n cells.
    rows = int(math.ceil(math.sqrt(n)))
    cols = int(round(math.sqrt(n)))

    # TODO: Consider proper heights and widths for the subplots

    # squeeze=False keeps ``axes`` a 2-D array even for a 1 x 1 grid, which
    # would otherwise be a bare Axes object without ``.flat``.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    try:
        fig.subplots_adjust(hspace=0, wspace=0)

        for i, ax in enumerate(axes.flat):
            ax.axis('off')  # Hide x, y axes completely
            if i < n:
                ax.imshow(ims[i])

        fig.savefig(filename, bbox_inches='tight')
    finally:
        # Close rather than clear: a cleared figure stays registered with
        # pyplot, so repeated calls would accumulate open figures.
        plt.close(fig)
7 changes: 5 additions & 2 deletions max_pooling_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ def forward_cpu(self, x):
return y, self.indexes

def forward_gpu(self, x):
if cuda.cudnn_enabled and self.use_cudnn:
return super(MaxPooling2D, self).forward_gpu(x)
"""
Commented away since we need the indexes for the unpooling process.
if cuda.cudnn_enabled and self.use_cudnn:
return super(MaxPooling2D, self).forward_gpu(x), self.indexes
"""
n, c, h, w = x[0].shape
y_h = conv.get_conv_outsize(
h, self.kh, self.sy, self.ph, self.cover_all)
Expand Down
Loading

0 comments on commit 89d6a60

Please sign in to comment.