torchsight.retrievers.datasets module

Datasets for the InstanceRetrievers.

Source code
"""Datasets for the InstanceRetrievers."""
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset


class ImagesDataset(Dataset):
    """A dataset that loads images from disk and yields each one with its path.

    The images are either listed explicitly through `paths` or discovered by walking
    the `root` directory, optionally keeping only the files with the given extensions.
    """

    def __init__(self, root=None, paths=None, transform=None, extensions=None):
        """Initialize the dataset.

        You must provide the root of the directory that contains the images or the paths of the images.
        If both are given, `paths` is used and `root` is not walked.

        Arguments:
            root (str): The path to the root directory that contains the images
                to generate the database.
            paths (list of str): The list of paths of the images to load.
            transform (callable, optional): The transform to apply to the image. It is called
                with a dict that holds the loaded image under the 'image' key.
            extensions (str or list of str, optional): If given it will load only files whose
                name ends with one of the given extensions. A single value is wrapped in a list.

        Raises:
            ValueError: If neither `root` nor `paths` is given, or if `root` has to be walked
                and it is not an existing directory.
        """
        if root is None and paths is None:
            raise ValueError('You must provide the "root" directory of the images or the "paths" of the images.')

        self.root = root
        self.transform = transform
        self.extensions = extensions
        if extensions is not None:
            self.extensions = extensions if isinstance(extensions, (list, tuple)) else [extensions]
        self.images = paths if paths is not None else self.get_images_paths()

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images)

    def __getitem__(self, i):
        """Load an image.

        Arguments:
            i (int): The index of the image to load.

        Returns:
            tuple: The image and its path. The image is the object returned by `Image.open`
                or, when a transform was given, whatever the transform returns.
        """
        path = self.images[i]
        image = Image.open(path)

        if self.transform is not None:
            # The transform receives a dict, not the bare image.
            image = self.transform({'image': image})

        return image, path

    def get_images_paths(self):
        """Get all the paths of the images that are in the given directory
        and its subdirectories.

        The paths are returned in a deterministic (sorted) traversal order so the index of
        an image does not depend on the order in which the filesystem lists its entries.

        Returns:
            list of str: The paths of the files accepted by `is_valid`.

        Raises:
            ValueError: If `root` is not an existing directory.
        """
        # isdir (not exists): a regular file as root would make os.walk yield nothing
        # and silently produce an empty dataset.
        if not os.path.isdir(self.root):
            raise ValueError('The directory "{}" does not exists.'.format(self.root))

        images = []
        for dirpath, dirnames, files in os.walk(self.root):
            # Sorting in place makes os.walk visit the subdirectories in sorted order.
            dirnames.sort()
            images.extend(os.path.join(dirpath, file) for file in sorted(files) if self.is_valid(file))

        return images

    def is_valid(self, file):
        """Check if the file has a correct extension. If we don't have extensions to check
        it always returns True.

        The comparison is a case-sensitive suffix match on the file's name.

        Arguments:
            file (str): The file's name to check.

        Returns:
            bool: True if there are no extensions to check or the name ends with one of them.
        """
        if self.extensions is None:
            return True

        return any(file.endswith(ext) for ext in self.extensions)

    @staticmethod
    def _collate(items):
        """Merge a list of (image, path) items into a batch.

        When the images are tensors they are indexed as (channels, height, width), padded
        with zeros at the bottom and at the right up to the biggest height and width of the
        batch, and stacked along a new first dimension. Otherwise they are returned as a list.

        Arguments:
            items (list of tuple): The (image, path) pairs produced by `__getitem__`.

        Returns:
            tuple: The batched images (a tensor or a list) and the list with their paths.
        """
        images = [item[0] for item in items]
        paths = [item[1] for item in items]

        if torch.is_tensor(images[0]):
            max_height = max(image.shape[1] for image in images)
            max_width = max(image.shape[2] for image in images)

            def pad_image(image):
                # new_zeros keeps the dtype and the device of the image being padded.
                padded = image.new_zeros((image.shape[0], max_height, max_width))
                padded[:, :image.shape[1], :image.shape[2]] = image
                return padded

            images = torch.stack([pad_image(image) for image in images], dim=0)

        return images, paths

    def get_dataloader(self, batch_size, num_workers):
        """Get the dataloader for this dataset.

        Arguments:
            batch_size (int): Forwarded to the DataLoader.
            num_workers (int): Forwarded to the DataLoader.

        Returns:
            DataLoader: the dataloader using the given parameters.
        """
        # The collate function is a static method instead of a closure so it can be
        # pickled by reference when the loader is handed to worker processes.
        return DataLoader(self, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)

Classes

class ImagesDataset (ancestors: torch.utils.data.dataset.Dataset)

A dataset to load the images.

Source code
class ImagesDataset(Dataset):
    """A dataset that loads images from disk and yields each one with its path.

    The images are either listed explicitly through `paths` or discovered by walking
    the `root` directory, optionally keeping only the files with the given extensions.
    """

    def __init__(self, root=None, paths=None, transform=None, extensions=None):
        """Initialize the dataset.

        You must provide the root of the directory that contains the images or the paths of the images.
        If both are given, `paths` is used and `root` is not walked.

        Arguments:
            root (str): The path to the root directory that contains the images
                to generate the database.
            paths (list of str): The list of paths of the images to load.
            transform (callable, optional): The transform to apply to the image. It is called
                with a dict that holds the loaded image under the 'image' key.
            extensions (str or list of str, optional): If given it will load only files whose
                name ends with one of the given extensions. A single value is wrapped in a list.

        Raises:
            ValueError: If neither `root` nor `paths` is given, or if `root` has to be walked
                and it is not an existing directory.
        """
        if root is None and paths is None:
            raise ValueError('You must provide the "root" directory of the images or the "paths" of the images.')

        self.root = root
        self.transform = transform
        self.extensions = extensions
        if extensions is not None:
            self.extensions = extensions if isinstance(extensions, (list, tuple)) else [extensions]
        self.images = paths if paths is not None else self.get_images_paths()

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images)

    def __getitem__(self, i):
        """Load an image.

        Arguments:
            i (int): The index of the image to load.

        Returns:
            tuple: The image and its path. The image is the object returned by `Image.open`
                or, when a transform was given, whatever the transform returns.
        """
        path = self.images[i]
        image = Image.open(path)

        if self.transform is not None:
            # The transform receives a dict, not the bare image.
            image = self.transform({'image': image})

        return image, path

    def get_images_paths(self):
        """Get all the paths of the images that are in the given directory
        and its subdirectories.

        The paths are returned in a deterministic (sorted) traversal order so the index of
        an image does not depend on the order in which the filesystem lists its entries.

        Returns:
            list of str: The paths of the files accepted by `is_valid`.

        Raises:
            ValueError: If `root` is not an existing directory.
        """
        # isdir (not exists): a regular file as root would make os.walk yield nothing
        # and silently produce an empty dataset.
        if not os.path.isdir(self.root):
            raise ValueError('The directory "{}" does not exists.'.format(self.root))

        images = []
        for dirpath, dirnames, files in os.walk(self.root):
            # Sorting in place makes os.walk visit the subdirectories in sorted order.
            dirnames.sort()
            images.extend(os.path.join(dirpath, file) for file in sorted(files) if self.is_valid(file))

        return images

    def is_valid(self, file):
        """Check if the file has a correct extension. If we don't have extensions to check
        it always returns True.

        The comparison is a case-sensitive suffix match on the file's name.

        Arguments:
            file (str): The file's name to check.

        Returns:
            bool: True if there are no extensions to check or the name ends with one of them.
        """
        if self.extensions is None:
            return True

        return any(file.endswith(ext) for ext in self.extensions)

    @staticmethod
    def _collate(items):
        """Merge a list of (image, path) items into a batch.

        When the images are tensors they are indexed as (channels, height, width), padded
        with zeros at the bottom and at the right up to the biggest height and width of the
        batch, and stacked along a new first dimension. Otherwise they are returned as a list.

        Arguments:
            items (list of tuple): The (image, path) pairs produced by `__getitem__`.

        Returns:
            tuple: The batched images (a tensor or a list) and the list with their paths.
        """
        images = [item[0] for item in items]
        paths = [item[1] for item in items]

        if torch.is_tensor(images[0]):
            max_height = max(image.shape[1] for image in images)
            max_width = max(image.shape[2] for image in images)

            def pad_image(image):
                # new_zeros keeps the dtype and the device of the image being padded.
                padded = image.new_zeros((image.shape[0], max_height, max_width))
                padded[:, :image.shape[1], :image.shape[2]] = image
                return padded

            images = torch.stack([pad_image(image) for image in images], dim=0)

        return images, paths

    def get_dataloader(self, batch_size, num_workers):
        """Get the dataloader for this dataset.

        Arguments:
            batch_size (int): Forwarded to the DataLoader.
            num_workers (int): Forwarded to the DataLoader.

        Returns:
            DataLoader: the dataloader using the given parameters.
        """
        # The collate function is a static method instead of a closure so it can be
        # pickled by reference when the loader is handed to worker processes.
        return DataLoader(self, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)

Methods

def __init__(self, root=None, paths=None, transform=None, extensions=None)

Initialize the dataset.

You must provide the root of the directory that contains the images or the paths of the images.

Arguments

root : str
The path to the root directory that contains the images to generate the database.
paths : list of str
The list of image paths to search.
transform : callable, optional
The transform to apply to the image.
extensions : list of str
If given it will load only files with the given extensions.
Source code
def __init__(self, root=None, paths=None, transform=None, extensions=None):
    """Set up the dataset from a root directory or from an explicit list of paths.

    At least one of `root` and `paths` is required. When `paths` is given it is stored
    as is; otherwise the paths are collected by calling `get_images_paths()`.

    Arguments:
        root (str): The path to the root directory that contains the images.
        paths (list of str): The paths of the images to use.
        transform (callable, optional): The transform to apply to each image.
        extensions (list of str, optional): If given, only files with these extensions
            are loaded. A value that is not a list or a tuple is wrapped in a list.

    Raises:
        ValueError: If both `root` and `paths` are None.
    """
    nothing_to_load = root is None and paths is None
    if nothing_to_load:
        raise ValueError('You must provide the "root" directory of the images or the "paths" of the images.')

    # Normalize a single extension into a one-element list; lists and tuples are kept untouched.
    if extensions is not None and not isinstance(extensions, (list, tuple)):
        extensions = [extensions]

    self.root = root
    self.transform = transform
    self.extensions = extensions

    # root and extensions must already be set here: the directory walk reads them.
    if paths is None:
        paths = self.get_images_paths()
    self.images = paths
def get_dataloader(self, batch_size, num_workers)

Get the dataloader for this dataset.

Returns

DataLoader
the dataloader using the given parameters.
Source code
def get_dataloader(self, batch_size, num_workers):
    """Get the dataloader for this dataset.

    The batches are built with a custom collate function. When the images are tensors they
    are indexed as (channels, height, width), padded with zeros at the bottom and at the
    right up to the biggest height and width of the batch, and stacked along a new first
    dimension. Otherwise the images are returned as a plain list.

    Arguments:
        batch_size (int): Forwarded to the DataLoader.
        num_workers (int): Forwarded to the DataLoader.

    Returns:
        DataLoader: the dataloader using the given parameters. Each batch is a tuple with
            the images (a tensor or a list) and the list with their paths.
    """
    # NOTE(review): `collate` is a local function, so it cannot be pickled by reference;
    # this presumably breaks num_workers > 0 where workers are spawned -- confirm.
    def collate(items):
        images = [item[0] for item in items]
        paths = [item[1] for item in items]

        if torch.is_tensor(images[0]):
            max_height = max(image.shape[1] for image in images)
            max_width = max(image.shape[2] for image in images)

            def pad_image(image):
                # new_zeros keeps the dtype and the device of the image being padded,
                # instead of always producing a default-dtype CPU tensor.
                padded = image.new_zeros((image.shape[0], max_height, max_width))
                padded[:, :image.shape[1], :image.shape[2]] = image
                return padded

            images = torch.stack([pad_image(image) for image in images], dim=0)

        return images, paths

    return DataLoader(self, batch_size=batch_size, num_workers=num_workers, collate_fn=collate)
def get_images_paths(self)

Get all the paths of the images that are in the given directory and its subdirectories.

Source code
def get_images_paths(self):
    """Get all the paths of the images that are in the given directory
    and its subdirectories.

    The paths are returned in a deterministic (sorted) traversal order so the index of
    an image does not depend on the order in which the filesystem lists its entries.

    Returns:
        list of str: The paths of the files accepted by `is_valid`.

    Raises:
        ValueError: If `root` is not an existing directory.
    """
    # isdir (not exists): a regular file as root would make os.walk yield nothing
    # and silently produce an empty list.
    if not os.path.isdir(self.root):
        raise ValueError('The directory "{}" does not exists.'.format(self.root))

    images = []
    for dirpath, dirnames, files in os.walk(self.root):
        # Sorting in place makes os.walk visit the subdirectories in sorted order.
        dirnames.sort()
        images.extend(os.path.join(dirpath, file) for file in sorted(files) if self.is_valid(file))

    return images
def is_valid(self, file)

Check if the file has a correct extension. If we don't have extensions to check it always returns True.

Arguments

file : str
The file's name to check.
Source code
def is_valid(self, file):
    """Tell whether the file's name ends with one of the configured extensions.

    When there are no extensions to check every file is accepted.

    Arguments:
        file (str): The file's name to check.

    Returns:
        bool: True if no extensions are set or the name ends with any of them.
    """
    allowed = self.extensions
    if allowed is None:
        return True

    # Stop at the first extension that matches the end of the name.
    for extension in allowed:
        if file.endswith(extension):
            return True

    return False