# Source code for gluoncv.model_zoo.matchers

# pylint: disable=arguments-differ
"""Matchers for target assignment.
Matchers are commonly used in object-detection for anchor-groundtruth matching.
The matching process is a prerequisite to training target assignment.
Matching is usually not required during testing.
"""
from __future__ import absolute_import
from mxnet import gluon


class CompositeMatcher(gluon.HybridBlock):
    """A Matcher that combines multiple strategies.

    Parameters
    ----------
    matchers : list of Matcher
        Matcher is a Block/HybridBlock used to match two groups of boxes.
        Any iterable is accepted; it is copied into a list. Order matters:
        results of earlier matchers take precedence over later ones.

    """
    def __init__(self, matchers):
        super(CompositeMatcher, self).__init__()
        # Materialize first so that generators/tuples work with len() and so
        # that later mutation of the caller's container does not silently
        # change this matcher.
        matchers = list(matchers)
        assert len(matchers) > 0, "At least one matcher required."
        for matcher in matchers:
            # gluon.HybridBlock subclasses gluon.Block, so this accepts both.
            assert isinstance(matcher, gluon.Block)
        self._matchers = matchers
        # Blocks kept inside a plain Python list are not registered as
        # children, so collect_params()/hybridize()/repr would not reach
        # them. Register the HybridBlock matchers explicitly. Plain Blocks are
        # only kept in the list, as before, because HybridBlock.register_child
        # rejects non-hybrid children (MXNet 1.x behaviour).
        for matcher in matchers:
            if isinstance(matcher, gluon.HybridBlock):
                self.register_child(matcher)

    def hybrid_forward(self, F, x):
        """Run every matcher on ``x`` and merge their outputs.

        Parameters
        ----------
        F : module
            Gluon operator namespace (ndarray or symbol), supplied by Gluon.
        x : NDArray or Symbol
            Input passed unchanged to every matcher.

        Returns
        -------
        NDArray or Symbol
            The merged match result produced by ``_compose_matches``.

        """
        matches = [matcher(x) for matcher in self._matchers]
        return self._compose_matches(F, matches)

    def _compose_matches(self, F, matches):
        """Given multiple match results, compose the final match result.

        The order of matches matters. Only entries that are still unmatched
        (-1) in the current result are substituted with the corresponding
        entry of the next match result.

        Parameters
        ----------
        matches : list of NDArrays
            N match results, each the output of a different Matcher. All are
            assumed to share the same shape.

        Returns
        -------
        NDArray or Symbol
            One match result, with the same shape as each input match.

        """
        result = matches[0]
        for match in matches[1:]:
            # Matched entries are indices >= 0 and unmatched ones are -1, so
            # -0.5 separates them robustly for float-typed indices.
            result = F.where(result > -0.5, result, match)
        return result
class BipartiteMatcher(gluon.HybridBlock):
    """Match two groups of boxes using a bipartite matching strategy.

    Parameters
    ----------
    threshold : float, default 1e-12
        Threshold used to ignore invalid paddings; forwarded unchanged to
        ``F.contrib.bipartite_matching``.
    is_ascend : bool, default False
        Whether to sort the matching order in ascending order; forwarded
        unchanged to ``F.contrib.bipartite_matching``.

    """
    def __init__(self, threshold=1e-12, is_ascend=False):
        super(BipartiteMatcher, self).__init__()
        self._threshold, self._is_ascend = threshold, is_ascend

    def hybrid_forward(self, F, x):
        """Apply ``F.contrib.bipartite_matching`` to ``x``.

        Parameters
        ----------
        F : module
            Gluon operator namespace (ndarray or symbol), supplied by Gluon.
        x : NDArray or Symbol
            Input handed directly to the bipartite matching operator.

        Returns
        -------
        NDArray or Symbol
            The first output of the operator.

        """
        outputs = F.contrib.bipartite_matching(
            x, threshold=self._threshold, is_ascend=self._is_ascend)
        # The operator yields more than one output; only the first is kept.
        first_output = outputs[0]
        return first_output
class MaximumMatcher(gluon.HybridBlock):
    """Match each entry to its highest-scoring candidate, if good enough.

    Parameters
    ----------
    threshold : float
        Matching threshold. A match is kept only when the best score is
        strictly greater than this value.

    """
    def __init__(self, threshold):
        super(MaximumMatcher, self).__init__()
        self._threshold = threshold

    def hybrid_forward(self, F, x):
        """Pick the arg-max along the last axis of ``x``, or -1 if too weak.

        Parameters
        ----------
        F : module
            Gluon operator namespace (ndarray or symbol), supplied by Gluon.
        x : NDArray or Symbol
            Scores; the maximum is taken over the last axis.

        Returns
        -------
        NDArray or Symbol
            Index of the best candidate along the last axis where its score
            exceeds ``threshold``, and -1 elsewhere.

        """
        best_idx = F.argmax(x, axis=-1)
        best_score = F.pick(x, best_idx, axis=-1)
        # Fill value used for entries whose best score is not high enough.
        unmatched = F.ones_like(best_idx) * -1
        keep = best_score > self._threshold
        return F.where(keep, best_idx, unmatched)