torchsight.experiments.retrievers.sanity
module
A sanity check of the instance retriever.
We are going to look for an object of an image between 3 images (included itself!) so any decent model must get the real object.
Source code
"""A sanity check of the instance retriever.
We are going to look for an object of an image between 3 images (included itself!)
so any decent model must get the real object.
"""
import os
import torch
from PIL import Image
from torchsight.retrievers.resnet import ResnetRetriever
def main():
query_boxes = torch.Tensor([[[269, 160, 168, 251]],
[[47, 53, 399, 249]]]) # (2, 1, 4)
x1, y1, w, h = query_boxes[:, :, 0], query_boxes[:, :, 1], query_boxes[:, :, 2], query_boxes[:, :, 3]
x2, y2 = x1 + w, y1 + h
query_boxes = torch.stack([x1, y1, x2, y2], dim=2)
root = '/home/souto/datasets/flickr32/sanity_check/'
images = [Image.open(os.path.join(root, image)) for image in ['apple.jpg', 'adidas.jpg']]
retriever = ResnetRetriever(root=root)
distances, boxes, paths, _ = retriever.query(images, query_boxes, k=5, device='cpu')
for i, query_image in enumerate(images):
query_image = retriever.image_transform({'image': query_image})
box_with_dist = torch.zeros((1, 5))
box_with_dist[:, :4] = torch.Tensor(query_boxes[i])
box_with_dist[:, 4] = 0
retriever.visualize(query_image, distances[i], boxes[i], paths[i], query_box=box_with_dist)
if __name__ == '__main__':
main()
Functions
def main()
-
Source code
def main(): query_boxes = torch.Tensor([[[269, 160, 168, 251]], [[47, 53, 399, 249]]]) # (2, 1, 4) x1, y1, w, h = query_boxes[:, :, 0], query_boxes[:, :, 1], query_boxes[:, :, 2], query_boxes[:, :, 3] x2, y2 = x1 + w, y1 + h query_boxes = torch.stack([x1, y1, x2, y2], dim=2) root = '/home/souto/datasets/flickr32/sanity_check/' images = [Image.open(os.path.join(root, image)) for image in ['apple.jpg', 'adidas.jpg']] retriever = ResnetRetriever(root=root) distances, boxes, paths, _ = retriever.query(images, query_boxes, k=5, device='cpu') for i, query_image in enumerate(images): query_image = retriever.image_transform({'image': query_image}) box_with_dist = torch.zeros((1, 5)) box_with_dist[:, :4] = torch.Tensor(query_boxes[i]) box_with_dist[:, 4] = 0 retriever.visualize(query_image, distances[i], boxes[i], paths[i], query_box=box_with_dist)