Skip to content

Commit

Permalink
ref pad added
Browse files Browse the repository at this point in the history
  • Loading branch information
Samet Hicsonmez committed Jan 5, 2019
1 parent 7ad01d0 commit 3498dfa
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,12 @@ class BasicBlock_sam(nn.Module):

def __init__(self, in_planes, planes, stride=1):
super(BasicBlock_sam, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

self.rp1 = nn.ReflectionPad2d(1)
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=0, bias=False)
self.bn1 = nn.InstanceNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.rp2 = nn.ReflectionPad2d(1)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=0, bias=False)
self.bn2 = nn.InstanceNorm2d(planes)
self.out_planes = planes

Expand All @@ -529,8 +532,8 @@ def __init__(self, in_planes, planes, stride=1):
)

def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out = F.relu(self.bn1(self.conv1(self.rp1(x))))
out = self.bn2(self.conv2(self.rp2(out)))
inputt = self.shortcut(x)
catted = torch.cat((out, inputt), 1)
#out = F.relu(out)
Expand Down

0 comments on commit 3498dfa

Please sign in to comment.