torchsight.retrievers.dldenet
module
A module with a retriever based on the DLDENet.
Source code
"""A module with a retriever based on the DLDENet."""
import torch
from torchsight.models.dlde.extractor import DLDENetExtractor
from torchsight.transforms.augmentation import AugmentDetection
from torchsight.utils import JsonObject
from .slow import SlowInstanceRetriver
class DLDENetRetriever(SlowInstanceRetriver):
    """A retriever that uses the DLDENet extractor.

    The model is created from a checkpoint with `DLDENetExtractor.from_checkpoint`, and the
    same evaluation transform (built from `params.transform`) is used both for the images
    where we search and for the query images. The distance is always 'cos'.
    """

    def __init__(self, checkpoint, *args, params=None, device=None, **kwargs):
        """Initialize the retriever.

        Arguments:
            checkpoint: The checkpoint given to `DLDENetExtractor.from_checkpoint` to build
                the model.
            params (JsonObject or dict, optional): The parameters for the model and the
                transforms. They are merged into the defaults from `get_params` with
                `JsonObject.merge`.
            device (str, optional): The device passed to `DLDENetExtractor.from_checkpoint`.
                When None it is 'cuda:0' if CUDA is available and 'cpu' otherwise.

        The rest of the arguments are the same as the SlowInstanceRetriever, only the
        distance is fixed to 'cos'; passing a `distance` keyword raises a TypeError
        (the keyword would be given twice to the parent initializer).
        """
        # These attributes are set BEFORE calling the parent initializer because
        # `_get_model` and `_get_transforms` read them. NOTE(review): presumably the parent
        # calls those hooks from its own __init__ -- confirm in SlowInstanceRetriver.
        self.params = self.get_params().merge(params)
        self.checkpoint = checkpoint
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = device
        super().__init__(*args, **kwargs, distance='cos')

    @staticmethod
    def get_params():
        """Get the base params for the model.

        Returns:
            JsonObject: with a 'transform' section configuring `LongestMaxSize`
                (max_size 512) and `PadIfNeeded` (min_height and min_width 512).
        """
        # Single side length so the resize limit and the padding target stay in sync.
        side = 512
        return JsonObject({
            'transform': {
                'LongestMaxSize': {'max_size': side},
                'PadIfNeeded': {'min_height': side, 'min_width': side},
            }
        })

    def _get_model(self):
        """Get the DLDENetExtractor loaded from the checkpoint.

        Returns:
            DLDENetExtractor: built with `from_checkpoint(self.checkpoint, self.device)`.
        """
        return DLDENetExtractor.from_checkpoint(self.checkpoint, self.device)

    def _get_transforms(self):
        """Get the transformations to apply to the images in the dataset and in the queries.

        Returns:
            callable: a transformation for only images (the images where we are going to search).
            callable: a transformation for images and bounding boxes (the query images with their
                bounding boxes indicating the instances to search).
        """
        # One AugmentDetection instance serves both roles, so dataset images and query
        # images are preprocessed identically. NOTE(review): evaluation=True presumably
        # disables random augmentations -- confirm in AugmentDetection.
        transform = AugmentDetection(self.params.transform, evaluation=True, normalize=True)
        return transform, transform
Classes
class DLDENetRetriever (ancestors: SlowInstanceRetriver, InstanceRetriever, PrintMixin)
-
A retriever that uses the DLDENet extractor.
Source code
class DLDENetRetriever(SlowInstanceRetriver):
    """A retriever that uses the DLDENet extractor.

    The model is created from a checkpoint with `DLDENetExtractor.from_checkpoint`, and the
    same evaluation transform (built from `params.transform`) is used both for the images
    where we search and for the query images. The distance is always 'cos'.
    """

    def __init__(self, checkpoint, *args, params=None, device=None, **kwargs):
        """Initialize the retriever.

        Arguments:
            checkpoint: The checkpoint given to `DLDENetExtractor.from_checkpoint` to build
                the model.
            params (JsonObject or dict, optional): The parameters for the model and the
                transforms. They are merged into the defaults from `get_params` with
                `JsonObject.merge`.
            device (str, optional): The device passed to `DLDENetExtractor.from_checkpoint`.
                When None it is 'cuda:0' if CUDA is available and 'cpu' otherwise.

        The rest of the arguments are the same as the SlowInstanceRetriever, only the
        distance is fixed to 'cos'; passing a `distance` keyword raises a TypeError
        (the keyword would be given twice to the parent initializer).
        """
        # These attributes are set BEFORE calling the parent initializer because
        # `_get_model` and `_get_transforms` read them. NOTE(review): presumably the parent
        # calls those hooks from its own __init__ -- confirm in SlowInstanceRetriver.
        self.params = self.get_params().merge(params)
        self.checkpoint = checkpoint
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = device
        super().__init__(*args, **kwargs, distance='cos')

    @staticmethod
    def get_params():
        """Get the base params for the model.

        Returns:
            JsonObject: with a 'transform' section configuring `LongestMaxSize`
                (max_size 512) and `PadIfNeeded` (min_height and min_width 512).
        """
        # Single side length so the resize limit and the padding target stay in sync.
        side = 512
        return JsonObject({
            'transform': {
                'LongestMaxSize': {'max_size': side},
                'PadIfNeeded': {'min_height': side, 'min_width': side},
            }
        })

    def _get_model(self):
        """Get the DLDENetExtractor loaded from the checkpoint.

        Returns:
            DLDENetExtractor: built with `from_checkpoint(self.checkpoint, self.device)`.
        """
        return DLDENetExtractor.from_checkpoint(self.checkpoint, self.device)

    def _get_transforms(self):
        """Get the transformations to apply to the images in the dataset and in the queries.

        Returns:
            callable: a transformation for only images (the images where we are going to search).
            callable: a transformation for images and bounding boxes (the query images with their
                bounding boxes indicating the instances to search).
        """
        # One AugmentDetection instance serves both roles, so dataset images and query
        # images are preprocessed identically. NOTE(review): evaluation=True presumably
        # disables random augmentations -- confirm in AugmentDetection.
        transform = AugmentDetection(self.params.transform, evaluation=True, normalize=True)
        return transform, transform
Static methods
def get_params()
-
Get the base params for the model.
Returns
JsonObject
- with the parameters for the model.
Source code
@staticmethod
def get_params():
    """Return the default configuration of the retriever.

    Returns:
        JsonObject: a 'transform' section with the settings for `LongestMaxSize`
            (max_size 512) and `PadIfNeeded` (min_height and min_width 512).
    """
    # One side length shared by the resize limit and the padding target.
    side = 512
    transform = {
        'LongestMaxSize': {'max_size': side},
        'PadIfNeeded': {'min_height': side, 'min_width': side},
    }
    return JsonObject({'transform': transform})
Methods
def __init__(self, checkpoint, *args, params=None, device=None, **kwargs)
-
Initialize the retriever.
Arguments
params (JsonObject or dict, optional)
- The parameters for the model and the transforms.
The rest of the arguments are the same as the SlowInstanceRetriever, only the distance is fixed to 'cos'.
Source code
def __init__(self, checkpoint, *args, params=None, device=None, **kwargs):
    """Initialize the retriever.

    Arguments:
        checkpoint: The checkpoint given to `DLDENetExtractor.from_checkpoint` to build
            the model.
        params (JsonObject or dict, optional): The parameters for the model and the
            transforms. They are merged into the defaults from `get_params` with
            `JsonObject.merge`.
        device (str, optional): The device passed to `DLDENetExtractor.from_checkpoint`.
            When None it is 'cuda:0' if CUDA is available and 'cpu' otherwise.

    The rest of the arguments are the same as the SlowInstanceRetriever, only the
    distance is fixed to 'cos'; passing a `distance` keyword raises a TypeError
    (the keyword would be given twice to the parent initializer).
    """
    # These attributes are set BEFORE calling the parent initializer because
    # `_get_model` and `_get_transforms` read them. NOTE(review): presumably the parent
    # calls those hooks from its own __init__ -- confirm in SlowInstanceRetriver.
    self.params = self.get_params().merge(params)
    self.checkpoint = checkpoint
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    self.device = device
    super().__init__(*args, **kwargs, distance='cos')
Inherited members