torchsight.retrievers.resnet module

A module for Resnet retrievers.

Source code
"""A module for Resnet retrievers."""
from torchsight.models import ResnetDetector
from torchsight.transforms.augmentation import AugmentDetection
from torchsight.utils import JsonObject

from .slow import SlowInstanceRetriver


class ResnetRetriever(SlowInstanceRetriver):
    """A retriever that uses the dummy Resnet object detector."""

    def __init__(self, *args, params=None, **kwargs):
        """Initialize the retriever.

        Arguments:
            params (JsonObject or dict, optional): The parameters for the model and the transforms.
                They are merged over the defaults returned by `get_params()`.

            The rest of the arguments are the same as the SlowInstanceRetriever, only the distance
            is fixed to 'l2'. Passing `distance='l2'` explicitly is accepted; any other value is
            rejected.

        Raises:
            ValueError: if a `distance` other than 'l2' is given.
        """
        # The distance is fixed for this retriever. Pop it so an explicit (and redundant)
        # distance='l2' does not collide with the value we pass to the parent below.
        distance = kwargs.pop('distance', 'l2')
        if distance != 'l2':
            raise ValueError(
                "ResnetRetriever only supports the 'l2' distance, got {!r}.".format(distance))

        # Set the params before calling the parent initializer.
        # NOTE(review): presumably the parent builds the model and transforms through
        # `_get_model()` / `_get_transforms()` during __init__, and both read `self.params`. Confirm.
        self.params = self.get_params().merge(params)
        super().__init__(*args, distance='l2', **kwargs)

    @staticmethod
    def get_params():
        """Get the base params for the model.

        The 'model' section is passed as keyword arguments to `ResnetDetector` and the
        'transform' section is passed to `AugmentDetection`.

        Returns:
            JsonObject: with the parameters for the model.
        """
        return JsonObject({
            'model': {
                'resnet': 18,
                'dim': 512,
                'pool': 'avg',
                'kernels': [2, 4, 8, 16]
            },
            'transform': {
                'LongestMaxSize': {
                    'max_size': 512
                },
                'PadIfNeeded': {
                    'min_height': 512,
                    'min_width': 512
                }
            }
        })

    def _get_model(self):
        """Get the ResnetDetector.

        Returns:
            ResnetDetector: built with the keyword arguments in `self.params.model`.
        """
        return ResnetDetector(**self.params.model)

    def _get_transforms(self):
        """Get the transformations to apply to the images in the dataset and in the queries.

        Returns:
            callable: a transformation for only images (the images where we are going to search).
            callable: a transformation for images and bounding boxes (the query images with their
                bounding boxes indicating the instances to search).
        """
        # The same AugmentDetection instance (evaluation mode, with normalization) is used
        # for both the dataset images and the query images.
        transform = AugmentDetection(self.params.transform, evaluation=True, normalize=True)

        return transform, transform

Classes

class ResnetRetriever (ancestors: SlowInstanceRetriver, InstanceRetriever, PrintMixin)

A retriever that uses the dummy Resnet object detector.

Source code
class ResnetRetriever(SlowInstanceRetriver):
    """A retriever that uses the dummy Resnet object detector."""

    def __init__(self, *args, params=None, **kwargs):
        """Initialize the retriever.

        Arguments:
            params (JsonObject or dict, optional): The parameters for the model and the transforms.
                They are merged over the defaults returned by `get_params()`.

            The rest of the arguments are the same as the SlowInstanceRetriever, only the distance
            is fixed to 'l2'. Passing `distance='l2'` explicitly is accepted; any other value is
            rejected.

        Raises:
            ValueError: if a `distance` other than 'l2' is given.
        """
        # The distance is fixed for this retriever. Pop it so an explicit (and redundant)
        # distance='l2' does not collide with the value we pass to the parent below.
        distance = kwargs.pop('distance', 'l2')
        if distance != 'l2':
            raise ValueError(
                "ResnetRetriever only supports the 'l2' distance, got {!r}.".format(distance))

        # Set the params before calling the parent initializer.
        # NOTE(review): presumably the parent builds the model and transforms through
        # `_get_model()` / `_get_transforms()` during __init__, and both read `self.params`. Confirm.
        self.params = self.get_params().merge(params)
        super().__init__(*args, distance='l2', **kwargs)

    @staticmethod
    def get_params():
        """Get the base params for the model.

        The 'model' section is passed as keyword arguments to `ResnetDetector` and the
        'transform' section is passed to `AugmentDetection`.

        Returns:
            JsonObject: with the parameters for the model.
        """
        return JsonObject({
            'model': {
                'resnet': 18,
                'dim': 512,
                'pool': 'avg',
                'kernels': [2, 4, 8, 16]
            },
            'transform': {
                'LongestMaxSize': {
                    'max_size': 512
                },
                'PadIfNeeded': {
                    'min_height': 512,
                    'min_width': 512
                }
            }
        })

    def _get_model(self):
        """Get the ResnetDetector.

        Returns:
            ResnetDetector: built with the keyword arguments in `self.params.model`.
        """
        return ResnetDetector(**self.params.model)

    def _get_transforms(self):
        """Get the transformations to apply to the images in the dataset and in the queries.

        Returns:
            callable: a transformation for only images (the images where we are going to search).
            callable: a transformation for images and bounding boxes (the query images with their
                bounding boxes indicating the instances to search).
        """
        # The same AugmentDetection instance (evaluation mode, with normalization) is used
        # for both the dataset images and the query images.
        transform = AugmentDetection(self.params.transform, evaluation=True, normalize=True)

        return transform, transform

Static methods

def get_params()

Get the base params for the model.

Returns

JsonObject
with the parameters for the model.
Source code
@staticmethod
def get_params():
    """Build the default parameters of the retriever.

    The result has two sections: 'model', with the backbone configuration, and
    'transform', with the resize-and-pad settings applied to the images.

    Returns:
        JsonObject: with the parameters for the model.
    """
    model_section = {
        'resnet': 18,
        'dim': 512,
        'pool': 'avg',
        'kernels': [2, 4, 8, 16],
    }
    side = 512
    transform_section = {
        'LongestMaxSize': {'max_size': side},
        'PadIfNeeded': {'min_height': side, 'min_width': side},
    }
    return JsonObject({'model': model_section, 'transform': transform_section})

Methods

def __init__(self, *args, params=None, **kwargs)

Initialize the retriever.

Arguments

params : JsonObject or dict, optional
The parameters for the model and the transforms.

The rest of the arguments are the same as the SlowInstanceRetriever, only the distance is fixed to 'l2'.

Source code
def __init__(self, *args, params=None, **kwargs):
    """Initialize the retriever.

    Arguments:
        params (JsonObject or dict, optional): The parameters for the model and the transforms.
            They are merged over the defaults returned by `get_params()`.

        The rest of the arguments are the same as the SlowInstanceRetriever, only the distance
        is fixed to 'l2'. Passing `distance='l2'` explicitly is accepted; any other value is
        rejected.

    Raises:
        ValueError: if a `distance` other than 'l2' is given.
    """
    # The distance is fixed for this retriever. Pop it so an explicit (and redundant)
    # distance='l2' does not collide with the value we pass to the parent below.
    distance = kwargs.pop('distance', 'l2')
    if distance != 'l2':
        raise ValueError(
            "ResnetRetriever only supports the 'l2' distance, got {!r}.".format(distance))

    # Set the params before calling the parent initializer.
    # NOTE(review): presumably the parent builds the model and transforms through
    # `_get_model()` / `_get_transforms()` during __init__, and both read `self.params`. Confirm.
    self.params = self.get_params().merge(params)
    super().__init__(*args, distance='l2', **kwargs)

Inherited members