# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable= arguments-differ,unused-argument
"""ResNets, implemented in Gluon."""
from __future__ import division
# Public API: the generic constructor plus one convenience builder per
# (depth, version) combination.
__all__ = [
    'get_cifar_resnet',
    'cifar_resnet20_v1', 'cifar_resnet56_v1', 'cifar_resnet110_v1',
    'cifar_resnet20_v2', 'cifar_resnet56_v2', 'cifar_resnet110_v2',
]
import os
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet import cpu
# Helpers
def _conv3x3(channels, stride, in_channels):
    """Return a bias-free 3x3 ``nn.Conv2D`` with padding 1.

    With a 3x3 kernel and padding 1 the spatial size is preserved when
    ``stride`` is 1. The bias is omitted because every use in this file is
    followed by a BatchNorm.

    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Convolution stride.
    in_channels : int
        Number of input channels (0 lets Gluon infer it from the graph).
    """
    conv_options = dict(kernel_size=3, strides=stride, padding=1,
                        use_bias=False, in_channels=in_channels)
    return nn.Conv2D(channels, **conv_options)
# Blocks
class CIFARBasicBlockV1(HybridBlock):
    r"""BasicBlock V1 from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.

    Post-activation residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN,
    added to the (optionally projected) shortcut and followed by a final ReLU.
    This is the block paired with ``CIFARResNetV1`` for the 6*n+2 layer
    CIFAR networks.

    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Stride of the first 3x3 convolution (and of the shortcut projection).
    downsample : bool, default False
        Whether to add a 1x1 convolution + BatchNorm projection on the shortcut
        so that it matches the body's output shape.
    in_channels : int, default 0
        Number of input channels. Default is 0, to infer from the graph.
    """
    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
        super(CIFARBasicBlockV1, self).__init__(**kwargs)
        # NOTE: no ``self.name_scope()`` here on purpose; children are named in
        # the caller's scope. Changing this would rename the parameters.
        self.body = nn.HybridSequential(prefix='')
        self.body.add(_conv3x3(channels, stride, in_channels))
        self.body.add(nn.BatchNorm())
        self.body.add(nn.Activation('relu'))
        self.body.add(_conv3x3(channels, 1, channels))
        self.body.add(nn.BatchNorm())
        if downsample:
            # Projection shortcut: 1x1 conv with the same stride, then BN.
            self.downsample = nn.HybridSequential(prefix='')
            self.downsample.add(nn.Conv2D(channels, kernel_size=1, strides=stride,
                                          use_bias=False, in_channels=in_channels))
            self.downsample.add(nn.BatchNorm())
        else:
            self.downsample = None

    def hybrid_forward(self, F, x):
        """Return ``relu(body(x) + shortcut(x))``."""
        residual = x
        x = self.body(x)
        # Explicit ``is not None``: ``self.downsample`` is a container block,
        # so plain truthiness would depend on its ``__len__``.
        if self.downsample is not None:
            residual = self.downsample(residual)
        x = F.Activation(residual+x, act_type='relu')
        return x
class CIFARBasicBlockV2(HybridBlock):
    r"""BasicBlock V2 from
    `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    Pre-activation residual block: (BN -> ReLU -> conv3x3) twice, added to the
    (optionally projected) shortcut with no activation after the addition.
    This is the block paired with ``CIFARResNetV2`` for the 6*n+2 layer
    CIFAR networks.

    Parameters
    ----------
    channels : int
        Number of output channels.
    stride : int
        Stride of the first 3x3 convolution (and of the shortcut projection).
    downsample : bool, default False
        Whether to add a bias-free 1x1 convolution (no BatchNorm) on the
        shortcut so that it matches the body's output shape.
    in_channels : int, default 0
        Number of input channels. Default is 0, to infer from the graph.
    """
    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
        super(CIFARBasicBlockV2, self).__init__(**kwargs)
        self.bn1 = nn.BatchNorm()
        self.conv1 = _conv3x3(channels, stride, in_channels)
        self.bn2 = nn.BatchNorm()
        self.conv2 = _conv3x3(channels, 1, channels)
        if downsample:
            self.downsample = nn.Conv2D(channels, 1, stride, use_bias=False,
                                        in_channels=in_channels)
        else:
            self.downsample = None

    def hybrid_forward(self, F, x):
        """Return ``body(x) + shortcut(x)`` (pre-activation, no final ReLU)."""
        residual = x
        x = self.bn1(x)
        x = F.Activation(x, act_type='relu')
        x = self.conv1(x)
        x = self.bn2(x)
        x = F.Activation(x, act_type='relu')
        x = self.conv2(x)
        if self.downsample is not None:
            # NOTE(review): the projection is applied to the raw block input,
            # not to the BN+ReLU pre-activated tensor. Changing this would
            # alter the computation for any existing trained weights, so
            # confirm before modifying.
            residual = self.downsample(residual)
        return x + residual
# Nets
class CIFARResNetV1(HybridBlock):
    r"""ResNet V1 model for CIFAR from
    `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.

    Parameters
    ----------
    block : HybridBlock
        Class for the residual block, e.g. CIFARBasicBlockV1. It is called as
        ``block(channels, stride, downsample, in_channels=..., prefix='')``.
    layers : list of int
        Number of residual blocks in each stage. At least one block is always
        built per stage, even if an entry is 0 or negative.
    channels : list of int
        ``channels[0]`` is the width of the stem convolution; ``channels[i+1]``
        is the width of stage ``i``. Length should be one larger than ``layers``.
    classes : int, default 10
        Number of classification classes.
    """
    def __init__(self, block, layers, channels, classes=10, **kwargs):
        super(CIFARResNetV1, self).__init__(**kwargs)
        assert len(layers) == len(channels) - 1, \
            "channels must have exactly one more entry than layers"
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            # Stem: 3x3 conv (stride 1, padding 1) followed by BatchNorm.
            # NOTE(review): no ReLU follows this BatchNorm. Left as-is because
            # any existing weights were presumably trained with this exact
            # graph; confirm before changing.
            self.features.add(nn.Conv2D(channels[0], 3, 1, 1, use_bias=False))
            self.features.add(nn.BatchNorm())
            for i, num_layer in enumerate(layers):
                # The first stage keeps the spatial size; later ones halve it.
                stride = 1 if i == 0 else 2
                self.features.add(self._make_layer(block, num_layer, channels[i+1],
                                                   stride, i+1, in_channels=channels[i]))
            self.features.add(nn.GlobalAvgPool2D())
            self.output = nn.Dense(classes, in_units=channels[-1])

    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0):
        """Build stage ``stage_index`` as a HybridSequential prefixed 'stage<k>_'.

        The first block applies ``stride`` and maps ``in_channels`` to
        ``channels``; the remaining ``layers - 1`` blocks keep the shape.
        """
        layer = nn.HybridSequential(prefix='stage%d_'%stage_index)
        with layer.name_scope():
            # The shortcut needs a projection whenever the first block changes
            # the tensor shape: either the channel count or the spatial size.
            downsample = stride != 1 or channels != in_channels
            layer.add(block(channels, stride, downsample, in_channels=in_channels,
                            prefix=''))
            for _ in range(layers-1):
                layer.add(block(channels, 1, False, in_channels=channels, prefix=''))
        return layer

    def hybrid_forward(self, F, x):
        """Return unnormalized class scores ``output(features(x))``."""
        x = self.features(x)
        x = self.output(x)
        return x
class CIFARResNetV2(HybridBlock):
    r"""ResNet V2 (pre-activation) model for CIFAR from
    `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    Parameters
    ----------
    block : HybridBlock
        Class for the residual block, e.g. CIFARBasicBlockV2. It is called as
        ``block(channels, stride, downsample, in_channels=..., prefix='')``.
    layers : list of int
        Number of residual blocks in each stage. At least one block is always
        built per stage, even if an entry is 0 or negative.
    channels : list of int
        ``channels[0]`` is the width of the stem convolution; ``channels[i+1]``
        is the width of stage ``i``. Length should be one larger than ``layers``.
    classes : int, default 10
        Number of classification classes.
    """
    def __init__(self, block, layers, channels, classes=10, **kwargs):
        super(CIFARResNetV2, self).__init__(**kwargs)
        assert len(layers) == len(channels) - 1, \
            "channels must have exactly one more entry than layers"
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')
            # Input normalization: BatchNorm without learned scale/shift.
            self.features.add(nn.BatchNorm(scale=False, center=False))
            self.features.add(nn.Conv2D(channels[0], 3, 1, 1, use_bias=False))
            in_channels = channels[0]
            for i, num_layer in enumerate(layers):
                # The first stage keeps the spatial size; later ones halve it.
                stride = 1 if i == 0 else 2
                self.features.add(self._make_layer(block, num_layer, channels[i+1],
                                                   stride, i+1, in_channels=in_channels))
                in_channels = channels[i+1]
            # Pre-activation blocks end without an activation, so apply the
            # final BN + ReLU before pooling.
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            # Global pooling instead of a fixed 8x8 window. This is identical for
            # 32x32 inputs (8x8 final maps), but also works for other input
            # sizes, matching CIFARResNetV1. It has no parameters, so parameter
            # names are unaffected.
            self.features.add(nn.GlobalAvgPool2D())
            self.features.add(nn.Flatten())
            self.output = nn.Dense(classes, in_units=in_channels)

    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0):
        """Build stage ``stage_index`` as a HybridSequential prefixed 'stage<k>_'.

        The first block applies ``stride`` and maps ``in_channels`` to
        ``channels``; the remaining ``layers - 1`` blocks keep the shape.
        """
        layer = nn.HybridSequential(prefix='stage%d_'%stage_index)
        with layer.name_scope():
            # The shortcut needs a projection whenever the first block changes
            # the tensor shape: either the channel count or the spatial size.
            downsample = stride != 1 or channels != in_channels
            layer.add(block(channels, stride, downsample, in_channels=in_channels,
                            prefix=''))
            for _ in range(layers-1):
                layer.add(block(channels, 1, False, in_channels=channels, prefix=''))
        return layer

    def hybrid_forward(self, F, x):
        """Return unnormalized class scores ``output(features(x))``."""
        x = self.features(x)
        x = self.output(x)
        return x
# Specification
# Lookup tables indexed by ``version - 1``: entry 0 is ResNet V1, entry 1 is V2.
# The network class and its residual-block class share the same index.
resnet_net_versions = [
    CIFARResNetV1,
    CIFARResNetV2,
]
resnet_block_versions = [
    CIFARBasicBlockV1,
    CIFARBasicBlockV2,
]
def _get_resnet_spec(num_layers):
    """Return ``(layers, channels)`` for a CIFAR ResNet of depth ``num_layers``.

    The CIFAR ResNets have three stages of ``n`` two-convolution blocks plus
    the stem convolution and the final dense layer, so the depth is
    ``6 * n + 2``.

    Parameters
    ----------
    num_layers : int
        Network depth; must be of the form ``6 * n + 2`` with ``n >= 1``
        (e.g. 20, 56, 110, 164).

    Returns
    -------
    layers : list of int
        Number of residual blocks per stage, ``[n, n, n]``.
    channels : list of int
        Stem width followed by each stage's width, ``[16, 16, 32, 64]``.

    Raises
    ------
    ValueError
        If ``num_layers`` is not of the form ``6 * n + 2`` with ``n >= 1``.
        (``num_layers=2`` would give ``n=0``, yet one block per stage would
        still be built, silently producing an 8-layer network.)
    """
    n, remainder = divmod(num_layers - 2, 6)
    if remainder != 0 or n < 1:
        raise ValueError(
            "num_layers must be of the form 6*n+2 with n >= 1 "
            "(e.g. 20, 56, 110), got %r" % (num_layers,))
    channels = [16, 16, 32, 64]
    layers = [n] * (len(channels) - 1)
    return layers, channels
# Constructor
def get_cifar_resnet(version, num_layers, pretrained=False, ctx=cpu(),
                     root=os.path.join('~', '.mxnet', 'models'), **kwargs):
    r"""ResNet V1 model from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.
    ResNet V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    Build a CIFAR ResNet of the requested version and depth, optionally
    loading pretrained weights.

    Parameters
    ----------
    version : int
        Version of ResNet. Options are 1, 2.
    num_layers : int
        Number of layers. Needs to be an integer of the form 6*n+2, e.g. 20, 56, 110, 164.
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        Passed through to the network constructor (e.g. ``classes``).

    Returns
    -------
    HybridBlock
        The constructed CIFARResNetV1 or CIFARResNetV2 network.

    Raises
    ------
    ValueError
        If ``version`` is not one of the supported versions. An invalid
        ``num_layers`` is rejected by ``_get_resnet_spec``.
    """
    num_versions = len(resnet_net_versions)
    # Validate explicitly: ``resnet_net_versions[version-1]`` on its own would
    # silently accept version=0 or negative values through negative indexing.
    if not 1 <= version <= num_versions:
        raise ValueError("Invalid ResNet version %r; expected an integer in [1, %d]."
                         % (version, num_versions))
    layers, channels = _get_resnet_spec(num_layers)
    resnet_class = resnet_net_versions[version-1]
    block_class = resnet_block_versions[version-1]
    net = resnet_class(block_class, layers, channels, **kwargs)
    if pretrained:
        # Imported lazily; only needed when loading pretrained weights.
        from .model_store import get_model_file
        # NOTE(review): newer MXNet releases deprecate ``load_params`` in favour
        # of ``load_parameters``; confirm against the MXNet version targeted.
        net.load_params(get_model_file('cifar_resnet%d_v%d'%(num_layers, version),
                                       root=root), ctx=ctx)
    return net
def cifar_resnet20_v1(**kwargs):
    r"""Build the ResNet-20 V1 network for CIFAR10, as described in
    `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_.

    Equivalent to ``get_cifar_resnet(1, 20, **kwargs)``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    """
    return get_cifar_resnet(version=1, num_layers=20, **kwargs)
def cifar_resnet56_v1(**kwargs):
    r"""Build the ResNet-56 V1 network for CIFAR10, as described in
    `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_.

    Equivalent to ``get_cifar_resnet(1, 56, **kwargs)``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    """
    return get_cifar_resnet(version=1, num_layers=56, **kwargs)
def cifar_resnet110_v1(**kwargs):
    r"""Build the ResNet-110 V1 network for CIFAR10, as described in
    `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_.

    Equivalent to ``get_cifar_resnet(1, 110, **kwargs)``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    """
    return get_cifar_resnet(version=1, num_layers=110, **kwargs)
def cifar_resnet20_v2(**kwargs):
    r"""Build the ResNet-20 V2 network for CIFAR10, as described in
    `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_.

    Equivalent to ``get_cifar_resnet(2, 20, **kwargs)``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    """
    return get_cifar_resnet(version=2, num_layers=20, **kwargs)
def cifar_resnet56_v2(**kwargs):
    r"""Build the ResNet-56 V2 network for CIFAR10, as described in
    `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_.

    Equivalent to ``get_cifar_resnet(2, 56, **kwargs)``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    """
    return get_cifar_resnet(version=2, num_layers=56, **kwargs)
def cifar_resnet110_v2(**kwargs):
    r"""Build the ResNet-110 V2 network for CIFAR10, as described in
    `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_.

    Equivalent to ``get_cifar_resnet(2, 110, **kwargs)``.

    Parameters
    ----------
    pretrained : bool, default False
        Whether to load the pretrained weights for model.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    """
    return get_cifar_resnet(version=2, num_layers=110, **kwargs)