# pylint: disable=wildcard-import, unused-wildcard-import
"""Model store which handles pretrained models from both
mxnet.gluon.model_zoo.vision and gluoncv.models
"""
from mxnet import gluon
from .ssd import *
from .fcn import *
from .cifarresnet import *
from .cifarwideresnet import *
__all__ = ['get_model']
[docs]def get_model(name, **kwargs):
"""Returns a pre-defined model by name
Parameters
----------
name : str
Name of the model.
pretrained : bool
Whether to load the pretrained weights for model.
classes : int
Number of classes for the output layer.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default '~/.mxnet/models'
Location for keeping the model parameters.
Returns
-------
HybridBlock
The model.
"""
models = {
'ssd_300_vgg16_atrous_voc': ssd_300_vgg16_atrous_voc,
'ssd_512_vgg16_atrous_voc': ssd_512_vgg16_atrous_voc,
'ssd_512_resnet18_v1_voc': ssd_512_resnet18_v1_voc,
'ssd_512_resnet50_v1_voc': ssd_512_resnet50_v1_voc,
'ssd_512_resnet101_v2_voc': ssd_512_resnet101_v2_voc,
'ssd_512_resnet152_v2_voc': ssd_512_resnet152_v2_voc,
'ssd_512_mobilenet1_0_voc': ssd_512_mobilenet1_0_voc,
'cifar_resnet20_v1': cifar_resnet20_v1,
'cifar_resnet56_v1': cifar_resnet56_v1,
'cifar_resnet110_v1': cifar_resnet110_v1,
'cifar_resnet20_v2': cifar_resnet20_v2,
'cifar_resnet56_v2': cifar_resnet56_v2,
'cifar_resnet110_v2': cifar_resnet110_v2,
'cifar_wideresnet16_10': cifar_wideresnet16_10,
'cifar_wideresnet28_10': cifar_wideresnet28_10,
'cifar_wideresnet40_8': cifar_wideresnet40_8,
'fcn_resnet50_voc' : get_fcn_voc_resnet50,
'fcn_resnet101_voc' : get_fcn_voc_resnet101,
}
try:
net = gluon.model_zoo.vision.get_model(name, **kwargs)
except ValueError as e:
name = name.lower()
if name not in models:
raise ValueError('%s\n\t%s' % (str(e), '\n\t'.join(sorted(models.keys()))))
net = models[name](**kwargs)
return net