From 3498dfa87c9d041dc9c08beefd34e3c15d7e3ca2 Mon Sep 17 00:00:00 2001
From: Samet Hicsonmez
Date: Sat, 5 Jan 2019 19:44:46 +0300
Subject: [PATCH] ref pad added

---
 models/networks.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/models/networks.py b/models/networks.py
index 06c55ce4..8302af81 100755
--- a/models/networks.py
+++ b/models/networks.py
@@ -502,9 +502,12 @@ class BasicBlock_sam(nn.Module):
     def __init__(self, in_planes, planes, stride=1):
         super(BasicBlock_sam, self).__init__()
-        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+        self.rp1 = nn.ReflectionPad2d(1)
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=0, bias=False)
         self.bn1 = nn.InstanceNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.rp2 = nn.ReflectionPad2d(1)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=0, bias=False)
         self.bn2 = nn.InstanceNorm2d(planes)
         self.out_planes = planes
 
@@ -529,8 +532,8 @@ def __init__(self, in_planes, planes, stride=1):
             )
 
     def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = self.bn2(self.conv2(out))
+        out = F.relu(self.bn1(self.conv1(self.rp1(x))))
+        out = self.bn2(self.conv2(self.rp2(out)))
         inputt = self.shortcut(x)
         catted = torch.cat((out, inputt), 1)
         #out = F.relu(out)
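
Note (not part of the patch itself): a minimal standalone sketch of the invariant the change relies on. An nn.ReflectionPad2d(1) followed by a 3x3 convolution with padding=0 yields exactly the same output shape as a 3x3 convolution with padding=1, so the patch swaps zero-padded borders for reflected ones without altering any tensor shapes downstream. The tensor sizes and variable names below are illustrative, not taken from models/networks.py.

import torch
import torch.nn as nn

# Illustrative check: reflection pad + padding=0 conv matches the spatial
# size of a padding=1 conv, for both stride settings the block can use.
x = torch.randn(1, 64, 32, 32)
for stride in (1, 2):
    zero_pad = nn.Conv2d(64, 64, kernel_size=3, stride=stride, padding=1, bias=False)
    ref_pad = nn.Sequential(
        nn.ReflectionPad2d(1),
        nn.Conv2d(64, 64, kernel_size=3, stride=stride, padding=0, bias=False),
    )
    assert zero_pad(x).shape == ref_pad(x).shape  # e.g. (1, 64, 32, 32) at stride 1
# Only the border values differ: edges are reflected instead of zero-filled,
# a common way to reduce border artifacts in image-generation networks.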