"""Dilated_ResNetV0s, implemented in Gluon."""
# pylint: disable=arguments-differ,unused-argument,missing-docstring
from __future__ import division
from mxnet.context import cpu
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet.gluon.nn import BatchNorm
__all__ = ['DilatedResNetV0', 'dilated_resnet18', 'dilated_resnet34',
'dilated_resnet50', 'dilated_resnet101',
'dilated_resnet152', 'DilatedBasicBlockV0', 'DilatedBottleneckV0']

class DilatedBasicBlockV0(HybridBlock):
    """DilatedResNetV0 basic block: two dilated 3x3 convolutions with a
    residual connection.
"""
expansion = 1
def __init__(self, inplanes, planes, strides=1, dilation=1, downsample=None, first_dilation=1,
norm_layer=None, **kwargs):
        super(DilatedBasicBlockV0, self).__init__()
        norm_layer = norm_layer or BatchNorm  # fall back when no norm layer is given
        self.conv1 = nn.Conv2D(in_channels=inplanes, channels=planes,
                               kernel_size=3, strides=strides,
                               padding=dilation, dilation=dilation, use_bias=False)
        self.bn1 = norm_layer(in_channels=planes)
        self.relu = nn.Activation('relu')
        self.conv2 = nn.Conv2D(in_channels=planes, channels=planes, kernel_size=3, strides=1,
                               padding=first_dilation, dilation=first_dilation, use_bias=False)
        self.bn2 = norm_layer(in_channels=planes)
self.downsample = downsample
self.strides = strides
    def hybrid_forward(self, F, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out = out + residual
out = self.relu(out)
return out

class DilatedBottleneckV0(HybridBlock):
    """DilatedResNetV0 bottleneck block (1x1-3x3-1x1 convolutions,
    expansion 4) with configurable dilation.
"""
# pylint: disable=unused-argument
expansion = 4
def __init__(self, inplanes, planes, strides=1, dilation=1,
downsample=None, first_dilation=1, norm_layer=None, **kwargs):
        super(DilatedBottleneckV0, self).__init__()
        norm_layer = norm_layer or BatchNorm  # fall back when no norm layer is given
        self.conv1 = nn.Conv2D(in_channels=inplanes, channels=planes, kernel_size=1, use_bias=False)
        self.bn1 = norm_layer(in_channels=planes)
        self.conv2 = nn.Conv2D(
            in_channels=planes, channels=planes, kernel_size=3, strides=strides,
            padding=dilation, dilation=dilation, use_bias=False)
        self.bn2 = norm_layer(in_channels=planes)
        self.conv3 = nn.Conv2D(
            in_channels=planes, channels=planes * 4, kernel_size=1, use_bias=False)
        self.bn3 = norm_layer(in_channels=planes * 4)
self.relu = nn.Activation('relu')
self.downsample = downsample
self.dilation = dilation
self.strides = strides
    def hybrid_forward(self, F, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out = out + residual
out = self.relu(out)
return out

class DilatedResNetV0(HybridBlock):
    """Dilated pre-trained DilatedResNetV0 model, which produces stride-8
    feature maps at conv5.

Parameters
----------
    block : Block
        Class for the residual block. Options are DilatedBasicBlockV0,
        DilatedBottleneckV0.
    layers : list of int
        Number of residual blocks in each of the four stages.
num_classes : int, default 1000
Number of classification classes.
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
        can be replaced by, e.g., Synchronized Cross-GPU BatchNormalization).

    Reference:

        - He, Kaiming, et al. "Deep residual learning for image recognition."
          Proceedings of the IEEE conference on computer vision and pattern
          recognition. 2016.

        - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by
          dilated convolutions." ICLR, 2016.
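
    Examples
    --------
    A minimal sketch of using the dilated trunk as a dense-prediction
    backbone (the 224x224 input is illustrative); with output stride 8,
    conv5 yields a 28x28 feature map:

    >>> import mxnet as mx
    >>> net = DilatedResNetV0(DilatedBottleneckV0, [3, 4, 6, 3])
    >>> net.initialize()
    >>> x = mx.nd.random.uniform(shape=(1, 3, 224, 224))
    >>> y = net.maxpool(net.relu(net.bn1(net.conv1(x))))
    >>> y = net.layer4(net.layer3(net.layer2(net.layer1(y))))
    >>> y.shape
    (1, 2048, 28, 28)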
"""
# pylint: disable=unused-variable
    def __init__(self, block, layers, num_classes=1000, norm_layer=BatchNorm, **kwargs):
        super(DilatedResNetV0, self).__init__()
        self.inplanes = 64
with self.name_scope():
self.conv1 = nn.Conv2D(in_channels=3, channels=64, kernel_size=7, strides=2, padding=3,
use_bias=False)
self.bn1 = norm_layer(in_channels=64)
self.relu = nn.Activation('relu')
self.maxpool = nn.MaxPool2D(pool_size=3, strides=2, padding=1)
self.layer1 = self._make_layer(1, block, 64, layers[0], norm_layer=norm_layer)
self.layer2 = self._make_layer(2, block, 128, layers[1], strides=2,
norm_layer=norm_layer)
self.layer3 = self._make_layer(3, block, 256, layers[2], strides=1, dilation=2,
norm_layer=norm_layer)
self.layer4 = self._make_layer(4, block, 512, layers[3], strides=1, dilation=4,
norm_layer=norm_layer)
            # Global average pooling keeps the classifier head shape-agnostic:
            # with the dilated stride-8 trunk, a fixed 7x7 pool would not
            # reduce a 28x28 conv5 map to 1x1 for the Dense layer below.
            self.avgpool = nn.GlobalAvgPool2D()
self.flat = nn.Flatten()
self.fc = nn.Dense(in_units=512 * block.expansion, units=num_classes)
    def _make_layer(self, stage_index, block, planes, blocks, strides=1, dilation=1,
                    norm_layer=None):
        """Stack `blocks` residual blocks into one stage. The first block may
        downsample; for dilation 4 it starts at dilation 2, so the receptive
        field grows gradually across the stage."""
        downsample = None
if strides != 1 or self.inplanes != planes * block.expansion:
downsample = nn.HybridSequential(prefix='down%d_'%stage_index)
with downsample.name_scope():
downsample.add(nn.Conv2D(in_channels=self.inplanes,
channels=planes * block.expansion,
kernel_size=1, strides=strides, use_bias=False))
downsample.add(norm_layer(in_channels=planes * block.expansion))
layers = nn.HybridSequential(prefix='layers%d_'%stage_index)
with layers.name_scope():
if dilation == 1 or dilation == 2:
layers.add(block(self.inplanes, planes, strides, dilation=1,
downsample=downsample, first_dilation=dilation,
norm_layer=norm_layer))
elif dilation == 4:
layers.add(block(self.inplanes, planes, strides, dilation=2,
downsample=downsample, first_dilation=dilation,
norm_layer=norm_layer))
else:
raise RuntimeError("=> unknown dilation size: {}".format(dilation))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.add(block(self.inplanes, planes, dilation=dilation, first_dilation=dilation,
norm_layer=norm_layer))
return layers
    def hybrid_forward(self, F, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = self.flat(x)
x = self.fc(x)
return x

def dilated_resnet18(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
"""Constructs a DilatedResNetV0-18 model.
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
ctx : Context, default CPU
The context in which to load the pretrained weights.
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
        can be replaced by, e.g., Synchronized Cross-GPU BatchNormalization).
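
    Examples
    --------
    A minimal sketch with random weights (no pretrained file needed); the
    224x224 input shape is illustrative:

    >>> import mxnet as mx
    >>> net = dilated_resnet18()
    >>> net.initialize()
    >>> net(mx.nd.random.uniform(shape=(1, 3, 224, 224))).shape
    (1, 1000)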
"""
model = DilatedResNetV0(DilatedBasicBlockV0, [2, 2, 2, 2], **kwargs)
if pretrained:
from ..model_store import get_model_file
model.load_params(get_model_file('resnet%d_v%d'%(18, 0),
root=root), ctx=ctx)
return model

def dilated_resnet34(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
"""Constructs a DilatedResNetV0-34 model.
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
ctx : Context, default CPU
The context in which to load the pretrained weights.
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
        can be replaced by, e.g., Synchronized Cross-GPU BatchNormalization).
"""
model = DilatedResNetV0(DilatedBasicBlockV0, [3, 4, 6, 3], **kwargs)
if pretrained:
from ..model_store import get_model_file
model.load_params(get_model_file('resnet%d_v%d'%(34, 0),
root=root), ctx=ctx)
return model

def dilated_resnet50(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
"""Constructs a DilatedResNetV0-50 model.
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
ctx : Context, default CPU
The context in which to load the pretrained weights.
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
        can be replaced by, e.g., Synchronized Cross-GPU BatchNormalization).
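
    Examples
    --------
    The model is a HybridBlock, so it can be compiled to a static graph; a
    minimal sketch with random weights:

    >>> import mxnet as mx
    >>> net = dilated_resnet50()
    >>> net.initialize()
    >>> net.hybridize()
    >>> net(mx.nd.zeros((1, 3, 224, 224))).shape
    (1, 1000)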
"""
model = DilatedResNetV0(DilatedBottleneckV0, [3, 4, 6, 3], **kwargs)
if pretrained:
from ..model_store import get_model_file
model.load_params(get_model_file('resnet%d_v%d'%(50, 0),
root=root), ctx=ctx)
return model

def dilated_resnet101(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
"""Constructs a DilatedResNetV0-101 model.
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
ctx : Context, default CPU
The context in which to load the pretrained weights.
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
        can be replaced by, e.g., Synchronized Cross-GPU BatchNormalization).
"""
model = DilatedResNetV0(DilatedBottleneckV0, [3, 4, 23, 3], **kwargs)
if pretrained:
from ..model_store import get_model_file
model.load_params(get_model_file('resnet%d_v%d'%(101, 0),
root=root), ctx=ctx)
return model

def dilated_resnet152(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
"""Constructs a DilatedResNetV0-152 model.
Parameters
----------
pretrained : bool, default False
Whether to load the pretrained weights for model.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
ctx : Context, default CPU
The context in which to load the pretrained weights.
    norm_layer : object
        Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`;
        can be replaced by, e.g., Synchronized Cross-GPU BatchNormalization).
"""
model = DilatedResNetV0(DilatedBottleneckV0, [3, 8, 36, 3], **kwargs)
if pretrained:
from ..model_store import get_model_file
model.load_params(get_model_file('resnet%d_v%d'%(152, 0),
root=root), ctx=ctx)
return model