torchsight.retrievers.datasets
module
Datasets for the InstanceRetrievers.
Source code
"""Datasets for the InstanceRetrievers."""
import os
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
class ImagesDataset(Dataset):
    """A dataset to load the images.

    The images come either from an explicit list of paths or from a recursive
    scan of a root directory, optionally filtered by file extension.
    """

    def __init__(self, root=None, paths=None, transform=None, extensions=None):
        """Initialize the dataset.

        You must provide the root of the directory that contains the images or
        the paths of the images. When both are given the paths are used and the
        directory is not scanned.

        Arguments:
            root (str): The path to the root directory that contains the images
                to generate the database.
            paths (list of str): The list with the path of images where to search.
            transform (callable, optional): The transform to apply to the image.
                It is called with a dict of the form `{'image': image}`.
            extensions (str or list of str, optional): If given it will load only
                files with the given extensions. A single value is wrapped in a list.

        Raises:
            ValueError: If neither `root` nor `paths` is given, or if the directory
                has to be scanned and `root` is not an existing directory.
        """
        if root is None and paths is None:
            raise ValueError('You must provide the "root" directory of the images or the "paths" of the images.')

        self.root = root
        self.transform = transform

        if extensions is None or isinstance(extensions, (list, tuple)):
            self.extensions = extensions
        else:
            # A single extension was given: normalize it to a one-element list.
            self.extensions = [extensions]

        # The scan reads self.root and self.extensions, so it must come last.
        self.images = paths if paths is not None else self.get_images_paths()

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, i):
        """Load an image.

        Arguments:
            i (int): The index of the image to load.

        Returns:
            tuple: The image (passed through the transform when there is one)
                and the path it was loaded from.
        """
        path = self.images[i]

        # Image.open() is lazy and keeps the file open until the pixel data is
        # read. Read it while we own the file so the handle is always released,
        # even when there is no transform that forces the load.
        with open(path, 'rb') as file:
            image = Image.open(file)
            image.load()

        if self.transform is not None:
            image = self.transform({'image': image})

        return image, path

    def get_images_paths(self):
        """Get all the paths of the images that are in the given directory
        and its subdirectories.

        Returns:
            list of str: The paths of the files accepted by `is_valid`.

        Raises:
            ValueError: If the root is not an existing directory.
        """
        # isdir() instead of exists(): a root that points to a regular file would
        # pass an existence check and silently produce an empty dataset.
        if not os.path.isdir(self.root):
            raise ValueError('The directory "{}" does not exists.'.format(self.root))

        return [
            os.path.join(dirpath, file)
            for dirpath, _, files in os.walk(self.root)
            for file in files
            if self.is_valid(file)
        ]

    def is_valid(self, file):
        """Check if the file has a correct extension. If we don't have extensions to check
        it always returns True.

        Arguments:
            file (str): The file's name to check.

        Returns:
            bool: True if the name ends with one of the extensions or there are no
                extensions to check.
        """
        if self.extensions is None:
            return True

        for extension in self.extensions:
            if file.endswith(extension):
                return True

        return False

    @staticmethod
    def _collate(items):
        """Merge a list of `(image, path)` items into a batch.

        When the images are tensors they are zero padded at the bottom and the
        right up to the biggest height and width of the batch and stacked in a
        single tensor. The code indexes them as (channels, height, width).
        Any other kind of image is returned as a plain list.

        Arguments:
            items (list of tuple): The items returned by `__getitem__`.

        Returns:
            tuple: The batched images and the list with their paths.
        """
        images = [item[0] for item in items]
        paths = [item[1] for item in items]

        if torch.is_tensor(images[0]):
            max_width = max(image.shape[2] for image in images)
            max_height = max(image.shape[1] for image in images)

            padded = []
            for image in images:
                aux = torch.zeros((image.shape[0], max_height, max_width))
                aux[:, :image.shape[1], :image.shape[2]] = image
                padded.append(aux)

            images = torch.stack(padded, dim=0)

        return images, paths

    def get_dataloader(self, batch_size, num_workers):
        """Get the dataloader for this dataset.

        Arguments:
            batch_size (int): The number of images per batch.
            num_workers (int): The number of worker processes that load the images.

        Returns:
            DataLoader: the dataloader using the given parameters.
        """
        # The collate function lives on the class and not inside this method
        # because a local function cannot be pickled, which breaks the worker
        # processes on platforms that start them with "spawn".
        return DataLoader(self, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
Classes
class ImagesDataset (ancestors: torch.utils.data.dataset.Dataset)
-
A dataset to load the images.
Source code
class ImagesDataset(Dataset):
    """A dataset to load the images.

    The images come either from an explicit list of paths or from a recursive
    scan of a root directory, optionally filtered by file extension.
    """

    def __init__(self, root=None, paths=None, transform=None, extensions=None):
        """Initialize the dataset.

        You must provide the root of the directory that contains the images or
        the paths of the images. When both are given the paths are used and the
        directory is not scanned.

        Arguments:
            root (str): The path to the root directory that contains the images
                to generate the database.
            paths (list of str): The list with the path of images where to search.
            transform (callable, optional): The transform to apply to the image.
                It is called with a dict of the form `{'image': image}`.
            extensions (str or list of str, optional): If given it will load only
                files with the given extensions. A single value is wrapped in a list.

        Raises:
            ValueError: If neither `root` nor `paths` is given, or if the directory
                has to be scanned and `root` is not an existing directory.
        """
        if root is None and paths is None:
            raise ValueError('You must provide the "root" directory of the images or the "paths" of the images.')

        self.root = root
        self.transform = transform

        if extensions is None or isinstance(extensions, (list, tuple)):
            self.extensions = extensions
        else:
            # A single extension was given: normalize it to a one-element list.
            self.extensions = [extensions]

        # The scan reads self.root and self.extensions, so it must come last.
        self.images = paths if paths is not None else self.get_images_paths()

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, i):
        """Load an image.

        Arguments:
            i (int): The index of the image to load.

        Returns:
            tuple: The image (passed through the transform when there is one)
                and the path it was loaded from.
        """
        path = self.images[i]

        # Image.open() is lazy and keeps the file open until the pixel data is
        # read. Read it while we own the file so the handle is always released,
        # even when there is no transform that forces the load.
        with open(path, 'rb') as file:
            image = Image.open(file)
            image.load()

        if self.transform is not None:
            image = self.transform({'image': image})

        return image, path

    def get_images_paths(self):
        """Get all the paths of the images that are in the given directory
        and its subdirectories.

        Returns:
            list of str: The paths of the files accepted by `is_valid`.

        Raises:
            ValueError: If the root is not an existing directory.
        """
        # isdir() instead of exists(): a root that points to a regular file would
        # pass an existence check and silently produce an empty dataset.
        if not os.path.isdir(self.root):
            raise ValueError('The directory "{}" does not exists.'.format(self.root))

        return [
            os.path.join(dirpath, file)
            for dirpath, _, files in os.walk(self.root)
            for file in files
            if self.is_valid(file)
        ]

    def is_valid(self, file):
        """Check if the file has a correct extension. If we don't have extensions to check
        it always returns True.

        Arguments:
            file (str): The file's name to check.

        Returns:
            bool: True if the name ends with one of the extensions or there are no
                extensions to check.
        """
        if self.extensions is None:
            return True

        for extension in self.extensions:
            if file.endswith(extension):
                return True

        return False

    @staticmethod
    def _collate(items):
        """Merge a list of `(image, path)` items into a batch.

        When the images are tensors they are zero padded at the bottom and the
        right up to the biggest height and width of the batch and stacked in a
        single tensor. The code indexes them as (channels, height, width).
        Any other kind of image is returned as a plain list.

        Arguments:
            items (list of tuple): The items returned by `__getitem__`.

        Returns:
            tuple: The batched images and the list with their paths.
        """
        images = [item[0] for item in items]
        paths = [item[1] for item in items]

        if torch.is_tensor(images[0]):
            max_width = max(image.shape[2] for image in images)
            max_height = max(image.shape[1] for image in images)

            padded = []
            for image in images:
                aux = torch.zeros((image.shape[0], max_height, max_width))
                aux[:, :image.shape[1], :image.shape[2]] = image
                padded.append(aux)

            images = torch.stack(padded, dim=0)

        return images, paths

    def get_dataloader(self, batch_size, num_workers):
        """Get the dataloader for this dataset.

        Arguments:
            batch_size (int): The number of images per batch.
            num_workers (int): The number of worker processes that load the images.

        Returns:
            DataLoader: the dataloader using the given parameters.
        """
        # The collate function lives on the class and not inside this method
        # because a local function cannot be pickled, which breaks the worker
        # processes on platforms that start them with "spawn".
        return DataLoader(self, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
Methods
def __init__(self, root=None, paths=None, transform=None, extensions=None)
-
Initialize the dataset.
You must provide the root of the directory that contains the images or the paths of the images.
Arguments
root
:str
- The path to the root directory that contains the images to generate the database.
paths
:list of str
- The list with the path of images where to search.
transform
:callable, optional
- The transform to apply to the image.
extensions
:list of str
- If given it will load only files with the given extensions.
Source code
def __init__(self, root=None, paths=None, transform=None, extensions=None): """Initialize the dataset. You must provide the root of the directory that contains the images or the paths of the images. Arguments: root (str): The path to the root directory that contains the images to generate the database. paths (list of str): The list with the path of images where to search. transform (callable, optional): The transform to apply to the image. extensions (list of str): If given it will load only files with the given extensions. """ if root is None and paths is None: raise ValueError('You must provide the "root" directory of the images or the "paths" of the images.') self.root = root self.transform = transform self.extensions = extensions if extensions is not None: self.extensions = extensions if isinstance(extensions, (list, tuple)) else [extensions] self.images = paths if paths is not None else self.get_images_paths()
def get_dataloader(self, batch_size, num_workers)
-
Get the dataloader for this dataset.
Returns
DataLoader
- the dataloader using the given parameters.
Source code
def get_dataloader(self, batch_size, num_workers):
    """Build a dataloader that iterates over this dataset.

    The batches are tuples with the images and the list of their paths. Tensor
    images are zero padded at the bottom and the right up to the biggest height
    and width of the batch and stacked; the code indexes them as
    (channels, height, width). Any other kind of image is returned as a list.

    Arguments:
        batch_size (int): Forwarded to the DataLoader.
        num_workers (int): Forwarded to the DataLoader.

    Returns:
        DataLoader: the dataloader using the given parameters.
    """
    def collate(items):
        batch = [sample[0] for sample in items]
        paths = [sample[1] for sample in items]

        # Only tensors can be padded and stacked.
        if not torch.is_tensor(batch[0]):
            return batch, paths

        width = max(tensor.shape[2] for tensor in batch)
        height = max(tensor.shape[1] for tensor in batch)

        padded = []
        for tensor in batch:
            canvas = torch.zeros((tensor.shape[0], height, width))
            canvas[:, :tensor.shape[1], :tensor.shape[2]] = tensor
            padded.append(canvas)

        return torch.stack(padded, dim=0), paths

    return DataLoader(self, batch_size=batch_size, num_workers=num_workers, collate_fn=collate)
def get_images_paths(self)
-
Get all the paths of the images that are in the given directory and its subdirectories.
Source code
def get_images_paths(self):
    """Get all the paths of the images that are in the given directory
    and its subdirectories.

    Returns:
        list of str: The paths of the files accepted by `is_valid`.

    Raises:
        ValueError: If the root is not an existing directory.
    """
    # isdir() instead of exists(): a root that points to a regular file would
    # pass an existence check and silently produce an empty list.
    if not os.path.isdir(self.root):
        raise ValueError('The directory "{}" does not exists.'.format(self.root))

    return [
        os.path.join(dirpath, file)
        for dirpath, _, files in os.walk(self.root)
        for file in files
        if self.is_valid(file)
    ]
def is_valid(self, file)
-
Check if the file has a correct extension. If we don't have extensions to check it always returns True.
Arguments
file
:str
- The file's name to check.
Source code
def is_valid(self, file):
    """Tell whether the file name ends with one of the configured extensions.

    When there are no extensions to check every file is accepted.

    Arguments:
        file (str): The file's name to check.

    Returns:
        bool: True if the file is accepted, False otherwise.
    """
    allowed = self.extensions
    if allowed is None:
        return True

    for extension in allowed:
        if file.endswith(extension):
            return True

    return False