torchsight.retrievers.faiss
module
Retrievers using FAISS as database.
Source code
"""Retrievers using FAISS as database."""
import os
import time
import faiss
import torch
from .retriever import InstanceRetriever
class FaissInstanceRetriever(InstanceRetriever):
"""A retriver that looks for instance of objects in a database of images.
You must provide a model in the `get_model()` method.
You can call the `create_database()` method to create the database,
and then query instances of objects using the `query()` method.
"""
def __init__(self, *args, storage='./databases', index='IndexFlatIP', **kwargs):
"""Initialize the retriever.
Arguments:
storage (str, optional): The path to the directory where to store the data.
index (str, optional): The index of FAISS to use to store the embeddings.
The default one is the FlatIP (Inner Product) that performs the cosine distance
so your embeddings must be normalized beforehand.
You can find more indexes here:
https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
The rest of the parameters are the same as InstanceRetriever.
"""
self.database = None
self.storage = storage
self.embeddings_file = os.path.join(self.storage, 'embeddings.index')
self.boxes_file = os.path.join(self.storage, 'boxes.index')
self.index = index
self.embeddings = None # FAISS index for the embeddings
self.boxes = None # FAISS index for the boxes
self.paths = {} # A dict to map between FAISS ids and images' paths
##############################
### SETTERS ####
##############################
def _set_indexes(self, dim):
"""Set the FAISS index.
The embedding could have any size but the bounding boxes must have size 4.
Arguments:
dim (int): The dimension of the embeddings.
"""
if self.index == 'IndexFlatL2':
self.embeddings = faiss.IndexFlatL2(dim)
self.boxes = faiss.IndexFlatL2(4)
elif self.index == 'IndexFlatIP':
self.embeddings = faiss.IndexFlatIP(dim)
self.boxes = faiss.IndexFlatIP(4)
else:
raise ValueError('Index "{}" not supported.'.format(self.index))
######################################
### DATABASE METHODS ###
######################################
def create_database(self, batch_size=8, num_workers=8):
"""Generates the database to insert in an index of FAISS.
Arguments:
batch_size (int): The batch size to use to compute in parallel the images.
num_workers (int): The number of process to use to load the images and generate
the batches.
"""
self.print('Creating database ...')
dataloader = self.dataset.get_dataloader(batch_size, num_workers)
num_batches = len(dataloader)
total_embs = 0
total_imgs = 0
init = time.time()
with torch.no_grad():
for i, (images, paths) in enumerate(dataloader):
embeddings, boxes = self.model(images)
# Create the indexes if they are not created yet
if self.embeddings is None or self.boxes is None:
self._set_indexes(embeddings.shape[1])
# Add the vectors to the indexes
self.embeddings.add(embeddings)
self.boxes.add(boxes)
# Map the id of the vectors to their image path
for j, path in paths:
self.paths[(i*batch_size) + j] = path
# Show some stats about the progress
total_imgs += images.shape[0]
total_embs += embeddings.shape[0]
self.logger.log({
'Batch': '{}/{}'.format(i + 1, num_batches),
'Time': '{:.3f} s'.format(time.time() - init),
'Images': total_imgs,
'Embeddings': total_embs,
})
self.save()
def query(self, images, boxes=None, strategy='max_iou', k=100):
"""TODO:"""
raise NotImplementedError()
###################################
### SAVING/LOADING ###
###################################
def save(self):
"""Save the indexes in the storage directory."""
self.print('Saving indexes ...')
faiss.write_index(self.embeddings, self.embeddings_file)
faiss.write_index(self.boxes, self.boxes_file)
def load(self):
"""Load the indexes from the storage directory."""
self.print('Loading indexes ...')
if not os.path.exists(self.embeddings_file):
raise ValueError('There is no ')
self.embeddings = faiss.read_index(self.embeddings_file)
self.boxes = faiss.read_index(self.boxes_file)
Classes
class FaissInstanceRetriever (ancestors: InstanceRetriever, PrintMixin)
-
A retriver that looks for instance of objects in a database of images.
You must provide a model in the
get_model()
method.You can call the
create_database()
method to create the database, and then query instances of objects using thequery()
method.Source code
class FaissInstanceRetriever(InstanceRetriever): """A retriver that looks for instance of objects in a database of images. You must provide a model in the `get_model()` method. You can call the `create_database()` method to create the database, and then query instances of objects using the `query()` method. """ def __init__(self, *args, storage='./databases', index='IndexFlatIP', **kwargs): """Initialize the retriever. Arguments: storage (str, optional): The path to the directory where to store the data. index (str, optional): The index of FAISS to use to store the embeddings. The default one is the FlatIP (Inner Product) that performs the cosine distance so your embeddings must be normalized beforehand. You can find more indexes here: https://github.com/facebookresearch/faiss/wiki/Faiss-indexes The rest of the parameters are the same as InstanceRetriever. """ self.database = None self.storage = storage self.embeddings_file = os.path.join(self.storage, 'embeddings.index') self.boxes_file = os.path.join(self.storage, 'boxes.index') self.index = index self.embeddings = None # FAISS index for the embeddings self.boxes = None # FAISS index for the boxes self.paths = {} # A dict to map between FAISS ids and images' paths ############################## ### SETTERS #### ############################## def _set_indexes(self, dim): """Set the FAISS index. The embedding could have any size but the bounding boxes must have size 4. Arguments: dim (int): The dimension of the embeddings. """ if self.index == 'IndexFlatL2': self.embeddings = faiss.IndexFlatL2(dim) self.boxes = faiss.IndexFlatL2(4) elif self.index == 'IndexFlatIP': self.embeddings = faiss.IndexFlatIP(dim) self.boxes = faiss.IndexFlatIP(4) else: raise ValueError('Index "{}" not supported.'.format(self.index)) ###################################### ### DATABASE METHODS ### ###################################### def create_database(self, batch_size=8, num_workers=8): """Generates the database to insert in an index of FAISS. Arguments: batch_size (int): The batch size to use to compute in parallel the images. num_workers (int): The number of process to use to load the images and generate the batches. """ self.print('Creating database ...') dataloader = self.dataset.get_dataloader(batch_size, num_workers) num_batches = len(dataloader) total_embs = 0 total_imgs = 0 init = time.time() with torch.no_grad(): for i, (images, paths) in enumerate(dataloader): embeddings, boxes = self.model(images) # Create the indexes if they are not created yet if self.embeddings is None or self.boxes is None: self._set_indexes(embeddings.shape[1]) # Add the vectors to the indexes self.embeddings.add(embeddings) self.boxes.add(boxes) # Map the id of the vectors to their image path for j, path in paths: self.paths[(i*batch_size) + j] = path # Show some stats about the progress total_imgs += images.shape[0] total_embs += embeddings.shape[0] self.logger.log({ 'Batch': '{}/{}'.format(i + 1, num_batches), 'Time': '{:.3f} s'.format(time.time() - init), 'Images': total_imgs, 'Embeddings': total_embs, }) self.save() def query(self, images, boxes=None, strategy='max_iou', k=100): """TODO:""" raise NotImplementedError() ################################### ### SAVING/LOADING ### ################################### def save(self): """Save the indexes in the storage directory.""" self.print('Saving indexes ...') faiss.write_index(self.embeddings, self.embeddings_file) faiss.write_index(self.boxes, self.boxes_file) def load(self): """Load the indexes from the storage directory.""" self.print('Loading indexes ...') if not os.path.exists(self.embeddings_file): raise ValueError('There is no ') self.embeddings = faiss.read_index(self.embeddings_file) self.boxes = faiss.read_index(self.boxes_file)
Methods
def __init__(self, *args, storage='./databases', index='IndexFlatIP', **kwargs)
-
Initialize the retriever.
Arguments
storage
:str
, optional- The path to the directory where to store the data.
index
:str
, optional- The index of FAISS to use to store the embeddings. The default one is the FlatIP (Inner Product) that performs the cosine distance so your embeddings must be normalized beforehand. You can find more indexes here: https://github.com/facebookresearch/faiss/wiki/Faiss-indexes
The rest of the parameters are the same as InstanceRetriever.
Source code
def __init__(self, *args, storage='./databases', index='IndexFlatIP', **kwargs): """Initialize the retriever. Arguments: storage (str, optional): The path to the directory where to store the data. index (str, optional): The index of FAISS to use to store the embeddings. The default one is the FlatIP (Inner Product) that performs the cosine distance so your embeddings must be normalized beforehand. You can find more indexes here: https://github.com/facebookresearch/faiss/wiki/Faiss-indexes The rest of the parameters are the same as InstanceRetriever. """ self.database = None self.storage = storage self.embeddings_file = os.path.join(self.storage, 'embeddings.index') self.boxes_file = os.path.join(self.storage, 'boxes.index') self.index = index self.embeddings = None # FAISS index for the embeddings self.boxes = None # FAISS index for the boxes self.paths = {} # A dict to map between FAISS ids and images' paths
def create_database(self, batch_size=8, num_workers=8)
-
Generates the database to insert in an index of FAISS.
Arguments
batch_size
:int
- The batch size to use to compute in parallel the images.
num_workers
:int
- The number of process to use to load the images and generate the batches.
Source code
def create_database(self, batch_size=8, num_workers=8): """Generates the database to insert in an index of FAISS. Arguments: batch_size (int): The batch size to use to compute in parallel the images. num_workers (int): The number of process to use to load the images and generate the batches. """ self.print('Creating database ...') dataloader = self.dataset.get_dataloader(batch_size, num_workers) num_batches = len(dataloader) total_embs = 0 total_imgs = 0 init = time.time() with torch.no_grad(): for i, (images, paths) in enumerate(dataloader): embeddings, boxes = self.model(images) # Create the indexes if they are not created yet if self.embeddings is None or self.boxes is None: self._set_indexes(embeddings.shape[1]) # Add the vectors to the indexes self.embeddings.add(embeddings) self.boxes.add(boxes) # Map the id of the vectors to their image path for j, path in paths: self.paths[(i*batch_size) + j] = path # Show some stats about the progress total_imgs += images.shape[0] total_embs += embeddings.shape[0] self.logger.log({ 'Batch': '{}/{}'.format(i + 1, num_batches), 'Time': '{:.3f} s'.format(time.time() - init), 'Images': total_imgs, 'Embeddings': total_embs, }) self.save()
def load(self)
-
Load the indexes from the storage directory.
Source code
def load(self): """Load the indexes from the storage directory.""" self.print('Loading indexes ...') if not os.path.exists(self.embeddings_file): raise ValueError('There is no ') self.embeddings = faiss.read_index(self.embeddings_file) self.boxes = faiss.read_index(self.boxes_file)
def query(self, images, boxes=None, strategy='max_iou', k=100)
-
TODO:
Source code
def query(self, images, boxes=None, strategy='max_iou', k=100): """TODO:""" raise NotImplementedError()
def save(self)
-
Save the indexes in the storage directory.
Source code
def save(self): """Save the indexes in the storage directory.""" self.print('Saving indexes ...') faiss.write_index(self.embeddings, self.embeddings_file) faiss.write_index(self.boxes, self.boxes_file)
Inherited members