.. _sphx_glr_build_examples_detection_train_ssd_voc.py:

2. Train SSD on Pascal VOC dataset
==================================

This tutorial goes through the basic building blocks of object detection
provided by GluonCV. Specifically, we show how to build a state-of-the-art
Single Shot Multibox Detection [Liu16]_ model by stacking GluonCV components.
It is also a good starting point for your own object detection project.

.. hint::

    You can skip the rest of this tutorial and start training your SSD model
    right away by downloading this script:
    :download:`Download train_ssd.py<../../../scripts/detection/ssd/train_ssd.py>`

    Example usage:

    Train a default vgg16_atrous 300x300 model with Pascal VOC on GPU 0:

    .. code-block:: bash

        python train_ssd.py

    Train a resnet50_v1 512x512 model on GPUs 0, 1, 2 and 3:

    .. code-block:: bash

        python train_ssd.py --gpus 0,1,2,3 --network resnet50_v1 --data-shape 512

    Check the supported arguments:

    .. code-block:: bash

        python train_ssd.py --help

Dataset
-------

Please first go through the :ref:`sphx_glr_build_examples_datasets_pascal_voc.py`
tutorial to set up the Pascal VOC dataset on your disk.
Then we are ready to load the training and validation images.

.. code-block:: python

    from gluoncv.data import VOCDetection

    # typically we use the 2007+2012 trainval splits as training data
    train_dataset = VOCDetection(splits=[(2007, 'trainval'), (2012, 'trainval')])
    # and the 2007 test split as validation data
    val_dataset = VOCDetection(splits=[(2007, 'test')])

    print('Training images:', len(train_dataset))
    print('Validation images:', len(val_dataset))

.. rst-class:: sphx-glr-script-out

 Out::

    Training images: 16551
    Validation images: 4952

Data transform
--------------

We can read an image-label pair from the training dataset:

.. code-block:: python

    train_image, train_label = train_dataset[0]
    bboxes = train_label[:, :4]
    cids = train_label[:, 4:5]
    print('image:', train_image.shape)
    print('bboxes:', bboxes.shape, 'class ids:', cids.shape)

.. rst-class:: sphx-glr-script-out

 Out::

    image: (375, 500, 3)
    bboxes: (5, 4) class ids: (5, 1)

Plot the image, together with the bounding box labels:

.. code-block:: python

    from matplotlib import pyplot as plt
    from gluoncv.utils import viz

    ax = viz.plot_bbox(train_image.asnumpy(), bboxes, labels=cids,
                       class_names=train_dataset.classes)
    plt.show()

.. image:: /build/examples_detection/images/sphx_glr_train_ssd_voc_001.png
    :align: center

Validation images look quite similar to training images, because the two sets
were split randomly from the same pool:

.. code-block:: python

    val_image, val_label = val_dataset[0]
    bboxes = val_label[:, :4]
    cids = val_label[:, 4:5]
    ax = viz.plot_bbox(val_image.asnumpy(), bboxes, labels=cids,
                       class_names=train_dataset.classes)
    plt.show()

.. image:: /build/examples_detection/images/sphx_glr_train_ssd_voc_002.png
    :align: center

For SSD networks, data augmentation is critical to accuracy (see the
explanations in the paper [Liu16]_). GluonCV provides a large set of image and
bounding box transform functions for this purpose, and they are very
convenient to use.

.. code-block:: python

    from gluoncv.data.transforms import presets
    from gluoncv import utils
    from mxnet import nd

.. code-block:: python

    width, height = 512, 512  # suppose we use 512 as the base training size
    train_transform = presets.ssd.SSDDefaultTrainTransform(width, height)
    val_transform = presets.ssd.SSDDefaultValTransform(width, height)

.. code-block:: python

    utils.random.seed(233)  # fix the seed in this tutorial
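Each augmentation step is also exposed as a standalone function, so you can
compose your own pipeline. As a small illustration, here is a sketch using the
bounding box helpers in ``gluoncv.data.transforms.bbox`` (we assume the
``flip`` and ``resize`` helpers of that module; boxes are in corner format
``(xmin, ymin, xmax, ymax)``):

.. code-block:: python

    from gluoncv.data.transforms import bbox as tbbox
    import numpy as np

    # a toy box on a 500x375 (width x height) image
    boxes = np.array([[100., 50., 300., 200.]])
    # mirror the box horizontally; size is the (width, height) of the image
    flipped = tbbox.flip(boxes, size=(500, 375), flip_x=True)
    # rescale coordinates from a 500x375 image to 512x512
    resized = tbbox.resize(boxes, in_size=(500, 375), out_size=(512, 512))
    print('flipped:', flipped)
    print('resized:', resized)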
Apply the transforms to a training image:

.. code-block:: python

    train_image2, train_label2 = train_transform(train_image, train_label)
    print('tensor shape:', train_image2.shape)

.. rst-class:: sphx-glr-script-out

 Out::

    tensor shape: (3, 512, 512)

Images in tensor format are hard to visualize directly, because their values
have been normalized and no longer sit in the (0, 255) range. Let's convert
them back so we can see them clearly:

.. code-block:: python

    train_image2 = train_image2.transpose(
        (1, 2, 0)) * nd.array((0.229, 0.224, 0.225)) + nd.array((0.485, 0.456, 0.406))
    train_image2 = (train_image2 * 255).clip(0, 255)
    ax = viz.plot_bbox(train_image2.asnumpy(), train_label2[:, :4],
                       labels=train_label2[:, 4:5],
                       class_names=train_dataset.classes)
    plt.show()

.. image:: /build/examples_detection/images/sphx_glr_train_ssd_voc_003.png
    :align: center

Apply the transforms to a validation image:

.. code-block:: python

    val_image2, val_label2 = val_transform(val_image, val_label)
    val_image2 = val_image2.transpose(
        (1, 2, 0)) * nd.array((0.229, 0.224, 0.225)) + nd.array((0.485, 0.456, 0.406))
    val_image2 = (val_image2 * 255).clip(0, 255)
    ax = viz.plot_bbox(val_image2.asnumpy(), val_label2[:, :4],
                       labels=val_label2[:, 4:5],
                       class_names=train_dataset.classes)
    plt.show()

.. image:: /build/examples_detection/images/sphx_glr_train_ssd_voc_004.png
    :align: center

Transforms used in training include random expansion, random cropping, color
distortion, random flipping, etc. In comparison, validation transforms are
simpler: only resizing and color normalization are used.

Data Loader
-----------

We will iterate through the entire dataset many times during training.
Keep in mind that raw images have to be transformed into tensors (MXNet uses
the BCHW format) before they are fed into neural networks. In addition, to
run in mini-batches, images must be resized to the same shape.

A handy DataLoader is convenient for applying different transforms and
aggregating data into mini-batches. Because the number of objects varies a lot
across images, we also have varying label sizes, so labels must be padded to
the same size within a batch. To deal with this, GluonCV provides
``DetectionDataLoader``, which handles the padding automatically.

.. code-block:: python

    from gluoncv.data import DetectionDataLoader

    batch_size = 4  # for this tutorial we use a small batch size
    # you can make num_workers larger (if your CPU has more cores)
    # to accelerate data loading
    num_workers = 0

    train_loader = DetectionDataLoader(train_dataset.transform(train_transform),
                                       batch_size, shuffle=True,
                                       last_batch='rollover',
                                       num_workers=num_workers)
    val_loader = DetectionDataLoader(val_dataset.transform(val_transform),
                                     batch_size, shuffle=False,
                                     last_batch='keep',
                                     num_workers=num_workers)
    for ib, batch in enumerate(train_loader):
        if ib > 5:
            break
        print('data:', batch[0].shape, 'label:', batch[1].shape)

.. rst-class:: sphx-glr-script-out

 Out::

    data: (4, 3, 512, 512) label: (4, 4, 6)
    data: (4, 3, 512, 512) label: (4, 4, 6)
    data: (4, 3, 512, 512) label: (4, 7, 6)
    data: (4, 3, 512, 512) label: (4, 2, 6)
    data: (4, 3, 512, 512) label: (4, 8, 6)
    data: (4, 3, 512, 512) label: (4, 6, 6)

SSD Network
-----------

GluonCV's SSD implementation is a composite Gluon HybridBlock, which means it
can be exported to a symbol to run in C++, Scala and other language bindings
(we will cover this usage in future tutorials). In terms of structure, SSD
networks are composed of a base feature extraction network, anchor generators,
class predictors and bounding box offset predictors.
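To see how the class and box predictors produce their outputs, here is a
minimal sketch of the per-scale prediction head idea (plain Gluon layers, not
GluonCV's actual ``ConvPredictor`` class): for every cell of a feature map and
every one of its anchors, a 3x3 convolution predicts one score per class
(plus background) and four box offsets.

.. code-block:: python

    from mxnet import nd
    from mxnet.gluon import nn

    num_classes, num_anchors = 20, 4  # VOC has 20 classes; 4 anchors per cell
    # one output channel per (anchor, class) pair, including background
    cls_pred = nn.Conv2D(num_anchors * (num_classes + 1), kernel_size=3, padding=1)
    # one output channel per (anchor, offset) pair
    box_pred = nn.Conv2D(num_anchors * 4, kernel_size=3, padding=1)
    cls_pred.initialize()
    box_pred.initialize()

    feat = nd.zeros((1, 512, 38, 38))  # a hypothetical feature map
    print(cls_pred(feat).shape)  # (1, 84, 38, 38): 84 = 4 anchors * 21 classes
    print(box_pred(feat).shape)  # (1, 16, 38, 38): 16 = 4 anchors * 4 offsets

Note how the 84- and 16-channel convolutions in the network printed below
match this arithmetic.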
For more details on how the SSD detector works, please refer to our
introductory `tutorial <http://gluon.mxnet.io/chapter08_computer-vision/object-detection.html>`__.
You can also refer to the original paper to learn more about the intuitions
behind SSD.

The `Gluon Model Zoo <../../model_zoo/index.html>`__ has a lot of built-in SSD
networks. You can load your favorite one with one simple line of code:

.. code-block:: python

    from gluoncv import model_zoo
    net = model_zoo.get_model('ssd_300_vgg16_atrous_voc', pretrained_base=False)
    print(net)

.. rst-class:: sphx-glr-script-out

 Out::

    SSD(
      (features): VGGAtrousExtractor(
        (stages): HybridSequential(
          (0): HybridSequential(
            (0): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Activation(relu)
            (2): Conv2D(None -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): Activation(relu)
          )
          (1): HybridSequential(
            (0): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Activation(relu)
            (2): Conv2D(None -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): Activation(relu)
          )
          (2): HybridSequential(
            (0): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Activation(relu)
            (2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): Activation(relu)
            (4): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (5): Activation(relu)
          )
          (3): HybridSequential(
            (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Activation(relu)
            (2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): Activation(relu)
            (4): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (5): Activation(relu)
          )
          (4): HybridSequential(
            (0): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Activation(relu)
            (2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): Activation(relu)
            (4): Conv2D(None -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (5): Activation(relu)
          )
          (5): HybridSequential(
            (0): Conv2D(None -> 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
            (1): Activation(relu)
            (2): Conv2D(None -> 1024, kernel_size=(1, 1), stride=(1, 1))
            (3): Activation(relu)
          )
        )
        (norm4): Normalize(
        )
        (extras): HybridSequential(
          (0): HybridSequential(
            (0): Conv2D(None -> 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Activation(relu)
            (2): Conv2D(None -> 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (3): Activation(relu)
          )
          (1): HybridSequential(
            (0): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Activation(relu)
            (2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (3): Activation(relu)
          )
          (2): HybridSequential(
            (0): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Activation(relu)
            (2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1))
            (3): Activation(relu)
          )
          (3): HybridSequential(
            (0): Conv2D(None -> 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Activation(relu)
            (2): Conv2D(None -> 256, kernel_size=(3, 3), stride=(1, 1))
            (3): Activation(relu)
          )
        )
      )
      (class_predictors): HybridSequential(
        (0): ConvPredictor(
          (predictor): Conv2D(None -> 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1): ConvPredictor(
          (predictor): Conv2D(None -> 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (2): ConvPredictor(
          (predictor): Conv2D(None -> 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (3): ConvPredictor(
          (predictor): Conv2D(None -> 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (4): ConvPredictor(
          (predictor): Conv2D(None -> 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (5): ConvPredictor(
          (predictor): Conv2D(None -> 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (box_predictors): HybridSequential(
        (0): ConvPredictor(
          (predictor): Conv2D(None -> 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (1): ConvPredictor(
          (predictor): Conv2D(None -> 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (2): ConvPredictor(
          (predictor): Conv2D(None -> 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (3): ConvPredictor(
          (predictor): Conv2D(None -> 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (4): ConvPredictor(
          (predictor): Conv2D(None -> 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (5): ConvPredictor(
          (predictor): Conv2D(None -> 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (anchor_generators): HybridSequential(
        (0): SSDAnchorGenerator(
        )
        (1): SSDAnchorGenerator(
        )
        (2): SSDAnchorGenerator(
        )
        (3): SSDAnchorGenerator(
        )
        (4): SSDAnchorGenerator(
        )
        (5): SSDAnchorGenerator(
        )
      )
      (bbox_decoder): NormalizedBoxCenterDecoder(
      )
      (cls_decoder): MultiPerClassDecoder(
      )
    )

An SSD network is a HybridBlock, as mentioned above. You can call it with an
input tensor:

.. code-block:: python

    import mxnet as mx
    x = mx.nd.zeros(shape=(1, 3, 300, 300))
    net.initialize()
    cids, scores, bboxes = net(x)

SSD returns three values: ``cids`` are the class labels, ``scores`` are the
confidence scores of each prediction, and ``bboxes`` are the absolute
coordinates of the corresponding bounding boxes.

Training targets
----------------

Unlike the single ``SoftmaxCrossEntropyLoss`` used in image classification,
the loss used in SSD is more complicated. Don't worry though, because we have
these modules available out of the box. Check out the ``target_generator`` in
SSD networks:

.. code-block:: python

    print(net.target_generator)

.. rst-class:: sphx-glr-script-out

 Out::

    SSDTargetGenerator(
      (_matcher): CompositeMatcher(
      )
      (_sampler): OHEMSampler(
      )
      (_cls_encoder): MultiClassEncoder(
      )
      (_box_encoder): NormalizedBoxCenterEncoder(
        (corner_to_center): BBoxCornerToCenter(
        )
      )
      (_center_to_corner): BBoxCenterToCorner(
      )
    )

You can observe that it contains:

- A bounding box encoder, which transfers raw coordinates into bbox
  prediction targets.
- A class encoder, which generates class labels for each anchor box.
- A matcher and samplers, used to apply the various advanced strategies
  described in the paper.
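To build intuition for what the matcher does, here is a conceptual sketch of
anchor matching (plain NumPy, not the GluonCV API): each anchor is compared
against the ground-truth boxes by intersection-over-union (IoU), and anchors
with sufficient overlap are assigned that box's class, while the rest become
background. The 0.5 threshold follows the paper [Liu16]_.

.. code-block:: python

    import numpy as np

    def box_iou(a, b):
        """IoU of two boxes in corner format (xmin, ymin, xmax, ymax)."""
        iw = max(0., min(a[2], b[2]) - max(a[0], b[0]))
        ih = max(0., min(a[3], b[3]) - max(a[1], b[1]))
        inter = iw * ih
        union = ((a[2] - a[0]) * (a[3] - a[1])
                 + (b[2] - b[0]) * (b[3] - b[1]) - inter)
        return inter / union

    anchors = np.array([[0., 0., 100., 100.], [50., 50., 200., 200.]])
    gt_box, gt_id = np.array([60., 60., 190., 190.]), 7  # a toy ground truth

    for i, anchor in enumerate(anchors):
        iou = box_iou(anchor, gt_box)
        # anchors overlapping a ground-truth box by >= 0.5 become positives
        cls_target = gt_id if iou >= 0.5 else 0  # 0 marks background here
        print('anchor %d: iou=%.2f -> class target %d' % (i, iou, cls_target))

The real ``SSDTargetGenerator`` additionally encodes box offsets relative to
the matched anchors and uses online hard example mining (the ``OHEMSampler``
above) to balance positive and negative samples.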
.. hint::

    Please check out the full :download:`training script
    <../../../scripts/detection/ssd/train_ssd.py>` for the complete
    implementation.

References
----------

.. [Liu16] Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy,
   Scott Reed, Cheng-Yang Fu, Alexander C. Berg. SSD: Single Shot MultiBox
   Detector. ECCV 2016.

**Total running time of the script:** ( 0 minutes 4.083 seconds)

.. only:: html

 .. container:: sphx-glr-footer

  .. container:: sphx-glr-download

     :download:`Download Python source code: train_ssd_voc.py <train_ssd_voc.py>`

  .. container:: sphx-glr-download

     :download:`Download Jupyter notebook: train_ssd_voc.ipynb <train_ssd_voc.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_