torchsight.retrievers.slow module

Module with slow retrievers but with a good memory footprint.

Source code
"""Module with slow retrievers but with good memory footprint."""
import time

import torch

from .retriever import InstanceRetriever


class SlowInstanceRetriver(InstanceRetriever):
    """An implementation of an InstanceRetriever that, to keep the server's memory footprint low,
    computes the embeddings for all the images every time a query is made.

    The algorithm is:
    - Query the k nearest instances for Q query objects.
    - Generate the Q query embeddings.
    - Iterate through the images in batches, computing the distances to the query embeddings and
      updating the k nearest instances found so far.

    Returns the final k*Q nearest instances.
    """

    def __init__(self, *args, distance='l2', **kwargs):
        """Initialize the retriver.

        Arguments:
            distance (str, optional): The distance to use. One of 'l2' or 'cos'.

            The rest of the arguments are the same as InstanceRetriever.
        """
        if distance not in ['l2', 'cos']:
            raise ValueError('Distance "{}" not supported. Available: {}'.format(distance, ['l2', 'cos']))

        self._distance = self._l2_distance if distance == 'l2' else self._cos_distance

        super().__init__(*args, **kwargs)

    @staticmethod
    def _l2_distance(queries, embeddings):
        """Compute the L2 distance between the queries and the embeddings.

        Arguments:
            queries (torch.Tensor): with shape `(q, dim)`.
            embeddings (torch.Tensor): with shape `(b, e, dim)`.

        Returns:
            torch.Tensor: with the distances with shape `(q, b, e)`.
        """
        queries = queries.unsqueeze(dim=1).unsqueeze(dim=2)  # (q, 1, 1, dim)
        embeddings = embeddings.unsqueeze(dim=0)  # (1, b, e, dim)

        return ((queries - embeddings) ** 2).sum(dim=3).sqrt()  # (q, b, e)

    @staticmethod
    def _cos_distance(queries, embeddings):
        """Compute the cosine distance between the queries and the embeddings.

        Arguments:
            queries (torch.Tensor): with shape `(q, dim)`.
            embeddings (torch.Tensor): with shape `(b, e, dim)`.

        Returns:
            torch.Tensor: with the distances with shape `(q, b, e)`.
        """
        queries_norm = queries.norm(dim=1).unsqueeze(dim=1).unsqueeze(dim=2)          # (q, 1, 1)
        embeddings_norm = embeddings.norm(dim=2).unsqueeze(dim=0)                     # (1, b, e)
        norms = queries_norm * embeddings_norm                                        # (q, b, e)
        queries = queries.unsqueeze(dim=1).unsqueeze(dim=2).unsqueeze(dim=3)          # (q, 1, 1,   1, dim)
        embeddings = embeddings.unsqueeze(dim=0).unsqueeze(dim=3)                     # (1, b, e, dim,   1)
        similarity = torch.matmul(queries, embeddings).squeeze(dim=4).squeeze(dim=3)  # (q, b, e)
        similarity /= norms
        # TODO: Split this computation, it cannot handle very big matrix multiplications

        return 1 - similarity

    def _search(self, queries, k):
        """Search in the dataset and get the tensor with the distances, bounding boxes and the paths
        of the images.

        Arguments:
            queries (torch.Tensor): the embeddings generated for each query object.
                Shape `(number of instances to search, embedding dim)`.
            k (int): The number of results to get.

        Returns:
            torch.Tensor: The distances between the query embeddings and the found objects, in ascending
                order, so the nearest result to the query embedding `i` has distance `distances[i, 0]`,
                and so on. To get the distance between the `i` embedding and its `j`-th result you can
                do `distances[i, j]`.
                Shape `(num of query objects, k)`.
            torch.Tensor: The bounding boxes for each result. Shape `(num of query objects, k, 4)`.
            list of list of str: A list with one entry per query object that contains the paths of
                the images where the results were found.
                If you want to know the path of the result that is in the `j`-th position for the
                `i` embedding you can do `paths[i][j]`.
        """
        num_queries = queries.shape[0]
        distances = 1e8 * queries.new_ones(num_queries, k)
        boxes = queries.new_zeros(num_queries, k, 4)
        paths = [[None for _ in range(k)] for _ in range(num_queries)]

        num_batches = len(self.dataloader)
        total_imgs = 0
        init = time.time()

        with torch.no_grad():
            self.model.to(self.device)
            for i, (images, batch_paths) in enumerate(self.dataloader):
                batch_size = images.shape[0]
                images = images.to(self.device)
                embeddings, batch_boxes = self.model(images)  # (b, e, d), (b, e, 4)
                num_embeddings = embeddings.shape[1]
                actual_distances = self._distance(queries, embeddings)  # (q, b, e)

                # Iterate over the queries
                for q in range(num_queries):
                    # Iterate over the batch items
                    for b in range(batch_size):
                        # Iterate over the embeddings of a given batch item
                        for e in range(num_embeddings):
                            dis = actual_distances[q, b, e]
                            # Get the insertion index by counting how many stored distances are below this one
                            index = int((distances[q] < dis).sum())
                            if index >= k:
                                continue
                            # Shift the worse results one position to the right to keep the top k sorted
                            distances[q, index + 1:] = distances[q, index:k - 1].clone()
                            boxes[q, index + 1:, :] = boxes[q, index:k - 1, :].clone()
                            paths[q].insert(index, batch_paths[b])
                            paths[q].pop()
                            distances[q, index] = dis
                            boxes[q, index, :] = batch_boxes[b, e]

                # Show some stats about the progress
                total_imgs += images.shape[0]
                self.logger.log({
                    'Batch': '{}/{}'.format(i + 1, num_batches),
                    'Time': '{:.3f} s'.format(time.time() - init),
                    'Images': total_imgs
                })

        return distances, boxes, paths
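
Both distance helpers broadcast a (q, dim) query tensor against a (b, e, dim) embedding tensor and return a (q, b, e) tensor of distances. A minimal shape check, calling the private static methods directly with arbitrary sizes (3 queries, 2 images, 5 embeddings per image, 256 dimensions):

import torch

from torchsight.retrievers.slow import SlowInstanceRetriver

queries = torch.randn(3, 256)        # (q, dim): 3 query embeddings
embeddings = torch.randn(2, 5, 256)  # (b, e, dim): 2 images with 5 embeddings each

l2 = SlowInstanceRetriver._l2_distance(queries, embeddings)    # (3, 2, 5)
cos = SlowInstanceRetriver._cos_distance(queries, embeddings)  # (3, 2, 5)
print(l2.shape, cos.shape)  # torch.Size([3, 2, 5]) torch.Size([3, 2, 5])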

Classes

class SlowInstanceRetriver (ancestors: InstanceRetriever, PrintMixin)

An implementation of an InstanceRetriever that, to keep the server's memory footprint low, computes the embeddings for all the images every time a query is made.

The algorithm is:

- Query the k nearest instances for Q query objects.
- Generate the Q query embeddings.
- Iterate through the images in batches, computing the distances to the query embeddings and updating the k nearest instances found so far (see the sketch below).

Returns the final k*Q nearest instances.
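
The loop below is a minimal sketch of that per-batch update, assuming only a model that returns per-image embeddings and boxes and a distance callable with the same signature as the ones in this module; the actual implementation is SlowInstanceRetriver._search, which also keeps the bounding boxes and the image paths of the results.

import torch

def nearest_distances(model, dataloader, queries, k, distance):
    """Sketch: keep only the k smallest distances per query while iterating in batches."""
    num_queries = queries.shape[0]
    best = queries.new_full((num_queries, k), 1e8)  # running top k distances per query

    with torch.no_grad():
        for images, paths in dataloader:
            embeddings, boxes = model(images)               # (b, e, dim), (b, e, 4)
            distances = distance(queries, embeddings)       # (q, b, e)
            distances = distances.reshape(num_queries, -1)  # flatten the batch dimensions
            merged = torch.cat([best, distances], dim=1)
            best, _ = merged.topk(k, dim=1, largest=False)  # keep the k smallest

    return best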


Methods

def __init__(self, *args, distance='l2', **kwargs)

Initialize the retriever.

Arguments

distance : str, optional
The distance to use. One of 'l2' or 'cos'.

The rest of the arguments are the same as InstanceRetriever.
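
For illustration, the distance validation can be exercised on its own; the positional arguments that InstanceRetriever expects are omitted here because an unsupported distance is rejected before they are used:

from torchsight.retrievers.slow import SlowInstanceRetriver

try:
    SlowInstanceRetriver(distance='manhattan')  # InstanceRetriever arguments omitted
except ValueError as error:
    print(error)  # Distance "manhattan" not supported. Available: ['l2', 'cos']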


Inherited members