Module datatap.torch

The `torch` module provides utilities for using dataTap with PyTorch.

Please note that to use this module, you will need to either install PyTorch manually or install dataTap with the PyTorch extra:

```bash
pip install 'datatap[torch]'
```

The `torch` module provides both a `torch.utils.data.IterableDataset` implementation and a convenience method to create a `torch.utils.data.DataLoader` using it. Here is an example of how to use these:

```py
import itertools
from datatap import Api
from datatap.torch import create_dataloader

import torchvision.transforms as T

api = Api()
dataset = api.get_default_database().get_dataset_list()[0]
latest_version = dataset.latest_version

transforms = T.Compose([
    T.Resize((128, 128)),
    T.ColorJitter(hue=0.2),
    T.ToTensor(),
])

dataloader = create_dataloader(latest_version, "training", batch_size = 4, image_transform = transforms)
for batch in itertools.islice(dataloader, 3):
    print(batch.boxes, batch.labels)
```
Source code:

````py
"""
The `torch` module provides utilities for using dataTap with PyTorch.

Please note that if you want to be able to use this module, you will
either need to install PyTorch manually, or install dataTap with the
PyTorch extra:

```bash
pip install 'datatap[torch]'
```

The `torch` module provides both a `torch.utils.data.IterableDataset` implementation,
and a convenience method to create a `torch.utils.data.DataLoader` using it. Here is
an example of how to use these:

```py
import itertools
from datatap import Api
from datatap.torch import create_dataloader

import torchvision.transforms as T

api = Api()
dataset = api.get_default_database().get_dataset_list()[0]
latest_version = dataset.latest_version

transforms = T.Compose([
    T.Resize((128, 128)),
    T.ColorJitter(hue=0.2),
    T.ToTensor(),
])

dataloader = create_dataloader(latest_version, "training", batch_size = 4, image_transform = transforms)
for batch in itertools.islice(dataloader, 3):
    print(batch.boxes, batch.labels)
```

"""

from ._patch_torch import patch_all as _patch_all
_patch_all()

from .dataset import DatasetElement, DatasetBatch, IterableDataset
from .dataloader import create_dataloader
from .utils import torch_to_image_annotation

__all__ = [
    "DatasetElement",
    "DatasetBatch",
    "IterableDataset",
    "create_dataloader",
    "torch_to_image_annotation",
]
````

Sub-modules

datatap.torch.dataloader
datatap.torch.dataset
datatap.torch.utils

Functions

def create_dataloader(dataset: Dataset, split: str, batch_size: int = 1, num_workers: int = cpu_count() or 0, *, image_transform: Callable[[PIL.Image.Image], torch.Tensor] = <function to_tensor>, class_mapping: Optional[Dict[str, int]] = None, device: torch.device = device(type='cpu')) -> torch.utils.data.dataloader.DataLoader

Creates a PyTorch `DataLoader` that yields batches of annotations.

This `DataLoader` uses `datatap.torch.IterableDataset` under the hood, so all of the same restrictions apply; most notably, the `image_transform` function must ultimately return a `torch.Tensor` of dimensionality `(..., H, W)`.
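
For example, building on the module example above, a `DataLoader` with an explicit class mapping and a GPU device might be created as follows. This is only a sketch: the `"person"`/`"vehicle"` class names are hypothetical and must match the dataset template's classes, as must the `"validation"` split name.

```py
import torch
import torchvision.transforms as T
from datatap.torch import create_dataloader

# Hypothetical class mapping; keys must match the dataset template's classes
class_mapping = {"person": 0, "vehicle": 1}

dataloader = create_dataloader(
    latest_version,  # the dataset version fetched in the module example
    "validation",    # hypothetical split name
    batch_size = 8,
    image_transform = T.Compose([T.Resize((256, 256)), T.ToTensor()]),
    class_mapping = class_mapping,
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)
```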

Source code:

```py
def create_dataloader(
    dataset: Dataset,
    split: str,
    batch_size: int = 1,
    num_workers: int = cpu_count() or 0,
    *,
    image_transform: Callable[[PIL.Image.Image], torch.Tensor] = TF.to_tensor,
    class_mapping: Optional[Dict[str, int]] = None,
    device: torch.device = torch.device("cpu")
) -> DataLoader[DatasetBatch]:
    """
    Creates a PyTorch `DataLoader` that yields batches of annotations.

    This `DataLoader` uses `datatap.torch.IterableDataset` under the hood, so
    all of the same restrictions apply, most notably that the `image_transform`
    function must ultimately return a `torch.Tensor` of dimensionality
    `(..., H, W)`.
    """
    if torch.multiprocessing.get_start_method(allow_none = True) is None:
        torch.multiprocessing.set_start_method("spawn")

    torch_dataset = IterableDataset(dataset, split, image_transform = image_transform, class_mapping = class_mapping, device = device)
    dataloader = cast(
        DataLoader[DatasetBatch],
        DataLoader(
            torch_dataset,
            batch_size,
            collate_fn = collate, # type: ignore (Torch's types are off)
            num_workers = num_workers,
        )
    )

    return dataloader
```

def torch_to_image_annotation(image: torch.Tensor, class_map: Dict[str, int], *, labels: torch.Tensor, boxes: torch.Tensor, scores: torch.Tensor, serialize_image: bool = False, uid: Optional[str] = None) -> ImageAnnotation

Creates an `ImageAnnotation` from a canonical tensor representation.

This function assumes the following:

  1. Image is of dimensionality `(..., height, width)`
  2. Labels are an `int`/`uint` tensor of size `[n]`
  3. Scores are a `float` tensor of size `[n]`
  4. Boxes are a `float` tensor of size `[n, 4]`
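
Under those assumptions, a minimal call might look like the following sketch; the image, class map, and detection values are all hypothetical. Note that the implementation normalizes the boxes by the image's width and height, so the boxes here are given in pixel-space `(min-x, min-y, max-x, max-y)` coordinates.

```py
import torch
from datatap.torch import torch_to_image_annotation

# Hypothetical 3-channel image and two detections
image = torch.rand(3, 480, 640)
class_map = {"person": 0, "vehicle": 1}

annotation = torch_to_image_annotation(
    image,
    class_map,
    labels = torch.tensor([0, 1]),       # int tensor of size [n]
    scores = torch.tensor([0.9, 0.75]),  # float tensor of size [n]
    boxes = torch.tensor([               # float tensor of size [n, 4], pixel coordinates
        [ 10.0,  20.0, 100.0, 200.0],
        [300.0,  40.0, 400.0, 240.0],
    ]),
)
```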
Source code:

```py
def torch_to_image_annotation(
    image: torch.Tensor,
    class_map: Dict[str, int],
    *,
    labels: torch.Tensor,
    boxes: torch.Tensor,
    scores: torch.Tensor,
    serialize_image: bool = False,
    uid: Optional[str] = None,
) -> ImageAnnotation:
    """
    Creates an `ImageAnnotation` from a canonical tensor representation.

    This function assumes the following:

    1. Image is of dimensionality `(..., height, width)`
    2. Labels are an `int`/`uint` tensor of size `[n]`
    3. Scores are a `float` tensor of size `[n]`
    4. Boxes are a `float` tensor of size `[n, 4]`
    """
    inverted_class_map = {
        i: cls
        for cls, i in class_map.items()
    }

    height, width = image.shape[-2:]

    # First construct the image. If we are asked to serialize it, then
    # use the tensor to construct a cached PIL image
    if serialize_image:
        pil_image = TF.to_pil_image(image, "RGB")
        droplet_image = Image.from_pil(pil_image)
    else:
        droplet_image = Image(paths = [])

    # Then, compute each of the class annotations
    class_annotations: Dict[str, List[Instance]] = {}

    boxes = boxes.cpu() / torch.tensor([width, height, width, height])

    for i, label in enumerate(labels.cpu()):
        class_name = inverted_class_map.get(int(label))
        if class_name is None:
            continue

        if class_name not in class_annotations:
            class_annotations[class_name] = []

        class_annotations[class_name].append(
            Instance(
                bounding_box = BoundingBox(
                    tensor_to_rectangle(boxes[i]),
                    confidence = float(scores[i]),
                )
            )
        )

    # Finally, construct the image annotation

    return ImageAnnotation(
        uid = uid,
        image = droplet_image,
        classes = {
            cls: ClassAnnotation(instances = instances, multi_instances = [])
            for cls, instances in class_annotations.items()
        }
    )
```

Classes

class DatasetBatch (original_annotations: List[ImageAnnotation], images: List[torch.Tensor], boxes: List[torch.Tensor], labels: List[torch.Tensor])

Represents a batch of images as produced by a `DataLoader`.
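
The fields of a batch are parallel lists with one entry per image. As a sketch of how they might be consumed, assuming a hypothetical `model` that follows the torchvision detection-model convention of taking a list of image tensors and a list of target dicts:

```py
for batch in dataloader:
    # Pair each image's boxes with its labels
    targets = [
        {"boxes": boxes, "labels": labels}
        for boxes, labels in zip(batch.boxes, batch.labels)
    ]
    loss_dict = model(batch.images, targets)  # hypothetical detection model
```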

Source code:

```py
class DatasetBatch():
    """
    Represents a batch of images as produced by a `DataLoader`.
    """

    original_annotations: List[ImageAnnotation]
    """
    The original annotations from this batch.
    """

    images: List[torch.Tensor]
    """
    A list of the images in this batch.
    """

    boxes: List[torch.Tensor]
    """
    A list of all the per-image bounding boxes in this batch.
    """

    labels: List[torch.Tensor]
    """
    A list of all the per-image labels in this batch.
    """

    def __init__(self, original_annotations: List[ImageAnnotation], images: List[torch.Tensor], boxes: List[torch.Tensor], labels: List[torch.Tensor]):
        self.original_annotations = original_annotations
        self.images = images
        self.boxes = boxes
        self.labels = labels
```

Class variables

var boxes : List[torch.Tensor]

A list of all the per-image bounding boxes in this batch.

var images : List[torch.Tensor]

A list of the images in this batch.

var labels : List[torch.Tensor]

A list of all the per-image labels in this batch.

var original_annotations : List[ImageAnnotation]

The original annotations from this batch.

class DatasetElement (original_annotation: ImageAnnotation, image: torch.Tensor, boxes: torch.Tensor, labels: torch.Tensor)

Represents a single element from the dataset.

Source code:

```py
class DatasetElement():
    """
    Represents a single element from the dataset.
    """

    original_annotation: ImageAnnotation
    """
    The original, untransformed annotation.
    """

    image: torch.Tensor
    """
    The image as transformed by the dataset.
    """

    boxes: torch.Tensor
    """
    The bounding boxes. They are specified in xyxy format `(min-x, min-y, max-x, max-y)`.
    """

    labels: torch.Tensor
    """
    The labels. They are a tensor of unsigned integers.
    """

    def __init__(self, original_annotation: ImageAnnotation, image: torch.Tensor, boxes: torch.Tensor, labels: torch.Tensor):
        self.original_annotation = original_annotation
        self.image = image
        self.boxes = boxes
        self.labels = labels
```

Class variables

var boxes : torch.Tensor

The bounding boxes. They are specified in xyxy format (min-x, min-y, max-x, max-y).

var image : torch.Tensor

The image as transformed by the dataset.

var labels : torch.Tensor

The labels. They are a tensor of unsigned integers.

var original_annotation : ImageAnnotation

The original, untransformed annotation.

class IterableDataset (dataset: Dataset, split: str, class_mapping: Optional[Dict[str, int]] = None, image_transform: Callable[[PIL.Image.Image], torch.Tensor] = <function to_tensor>, device: torch.device = device(type='cpu'))

A PyTorch `IterableDataset` that yields all of the annotations from a given `DatasetVersion`. It provides functionality for automatically applying transforms to images and then scaling the annotations to the new dimensions.

Note that the transformation is required to produce an image tensor of dimensionality `[..., H, W]`. One way of doing this is to use `torchvision.transforms.functional.to_tensor` as the final step of the transform.
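
The dataset can also be consumed directly, without a `DataLoader`; each iteration yields a `DatasetElement`. A brief sketch, reusing `latest_version` from the module example:

```py
from datatap.torch import IterableDataset

torch_dataset = IterableDataset(latest_version, "training")

for element in torch_dataset:
    # element.image is the transformed tensor; element.boxes is scaled to its dimensions
    print(element.image.shape, element.boxes.shape, element.labels)
    break
```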

Source code:

```py
class IterableDataset(TorchIterableDataset[DatasetElement]):
    """
    A PyTorch `IterableDataset` that yields all of the annotations from a
    given `DatasetVersion`. Provides functionality for automatically applying
    transforms to images, and then scaling the annotations to the new dimensions.

    Note, it is required that the transformation produce an image tensor of
    dimensionality `[..., H, W]`. One way of doing this is using
    `torchvision.transforms.functional.to_tensor` as the final step of the transform.
    """

    _dataset: Dataset
    _split: str
    _class_mapping: Dict[str, int]
    _class_names: Dict[int, str]
    _device: torch.device

    def __init__(
        self,
        dataset: Dataset,
        split: str,
        class_mapping: Optional[Dict[str, int]] = None,
        image_transform: Callable[[PIL.Image.Image], torch.Tensor] = TF.to_tensor,
        device: torch.device = torch.device("cpu")
    ):
        self._dataset = dataset
        self._split = split
        self._image_transform = image_transform
        self._device = device

        template_classes = dataset.template.classes.keys()
        if class_mapping is not None:
            if set(class_mapping.keys()) != set(template_classes):
                print(
                    "[WARNING]: Potentially invalid class mapping. Provided classes ",
                    set(class_mapping.keys()),
                    " but needed ",
                    set(template_classes)
                )
            self._class_mapping = class_mapping
        else:
            self._class_mapping = {
                cls: i
                for i, cls in enumerate(sorted(template_classes))
            }

        self._class_names = {
            i: cls
            for cls, i in self._class_mapping.items()
        }

    def _get_generator(self):
        worker_info: Optional[Any] = get_worker_info()

        if worker_info is None:
            return self._dataset.stream_split(self._split, 0, 1)
        else:
            num_workers: int = worker_info.num_workers
            worker_id: int = worker_info.id

            return self._dataset.stream_split(self._split, worker_id, num_workers)

    def __iter__(self) -> Iterator[DatasetElement]:
        for annotation in self._get_generator():
            img = annotation.image.get_pil_image(True).convert("RGB")
            transformed_img = self._image_transform(img).to(self._device)
            h, w = transformed_img.shape[-2:]

            instance_boxes = [
                (
                    instance.bounding_box.rectangle.p1.x * w,
                    instance.bounding_box.rectangle.p1.y * h,
                    instance.bounding_box.rectangle.p2.x * w,
                    instance.bounding_box.rectangle.p2.y * h,
                )
                for class_name in annotation.classes.keys()
                for instance in annotation.classes[class_name].instances
                if instance.bounding_box is not None
            ]

            instance_labels = [
                self._class_mapping[class_name]
                for class_name in annotation.classes.keys()
                for _ in annotation.classes[class_name].instances
                if class_name in self._class_mapping
            ]

            target = torch.tensor(instance_boxes).reshape((-1, 4)).to(self._device)
            labels = torch.tensor(instance_labels, dtype = torch.int64).to(self._device)

            element = DatasetElement(annotation, transformed_img, target, labels)

            yield element
```

Ancestors

  • abc.IterableDataset[datatap.torch.dataset.DatasetElement]
  • torch.utils.data.dataset.IterableDataset
  • torch.utils.data.dataset.Dataset
  • typing.Generic