Module datatap.torch.dataloader
Expand source code
from __future__ import annotations
from os import cpu_count
from typing import Callable, Dict, Generator, Optional, TypeVar, cast, TYPE_CHECKING
import torch
import PIL.Image
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader as TorchDataLoader
from datatap.api.entities import Dataset
from .dataset import IterableDataset, DatasetBatch, collate
_T = TypeVar("_T")

if TYPE_CHECKING:
    # Only needed for the annotation below; never imported at runtime.
    from typing import Iterator

    class DataLoader(TorchDataLoader[_T]):
        """
        Type-checking-only redeclaration of ``torch.utils.data.DataLoader``
        whose ``__iter__`` is typed to yield ``_T`` items.

        Torch's own ``DataLoader.__iter__`` returns an iterator object (not a
        generator: it has no ``send``/``throw``/``close``), so it is annotated
        here as ``Iterator[_T]``. This class is never created at runtime; see
        the ``else`` branch.
        """

        def __iter__(self) -> Iterator[_T]: ...
else:
    # At runtime the name is simply torch's DataLoader, so this shim adds no
    # overhead and isinstance checks against torch's class keep working.
    DataLoader = TorchDataLoader
def create_dataloader(
    dataset: Dataset,
    split: str,
    batch_size: int = 1,
    num_workers: int = cpu_count() or 0,
    *,
    image_transform: Callable[[PIL.Image.Image], torch.Tensor] = TF.to_tensor,
    class_mapping: Optional[Dict[str, int]] = None,
    device: torch.device = torch.device("cpu"),
) -> DataLoader[DatasetBatch]:
    """
    Build a PyTorch ``DataLoader`` over one split of a dataTap ``Dataset``.

    The loader wraps an ``IterableDataset`` (from ``datatap.torch.dataset``)
    and batches its items with ``collate``, so it is typed as yielding
    ``DatasetBatch`` objects. All of ``IterableDataset``'s restrictions apply;
    most notably, ``image_transform`` must ultimately return a
    ``torch.Tensor`` of dimensionality ``(..., H, W)``.

    Side effect: if the process has no multiprocessing start method yet,
    this sets it to ``"spawn"`` for the whole process.

    Args:
        dataset: The dataTap dataset to read from.
        split: Name of the split, forwarded to ``IterableDataset``.
        batch_size: Number of items per batch.
        num_workers: Number of loader worker processes. The default,
            ``cpu_count() or 0``, is evaluated once at import time.
        image_transform: Converts each PIL image to a tensor. Defaults to
            ``torchvision``'s ``to_tensor``.
        class_mapping: Optional mapping forwarded to ``IterableDataset``.
            Presumably maps class names to integer labels; TODO confirm.
        device: Forwarded to ``IterableDataset``. Presumably the device the
            produced tensors are placed on; TODO confirm.

    Returns:
        The configured ``DataLoader``.
    """
    mp = torch.multiprocessing
    # Only choose a start method when none is fixed yet: calling
    # set_start_method after one is set raises unless force=True.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method("spawn")

    source = IterableDataset(
        dataset,
        split,
        image_transform=image_transform,
        class_mapping=class_mapping,
        device=device,
    )
    loader = DataLoader(
        source,
        batch_size=batch_size,
        collate_fn=collate,  # type: ignore  # torch's collate_fn typing is too narrow
        num_workers=num_workers,
    )
    return cast(DataLoader[DatasetBatch], loader)
Functions
def create_dataloader(dataset: Dataset, split: str, batch_size: int = 1, num_workers: int = cpu_count() or 0, *, image_transform: Callable[[PIL.Image.Image], torch.Tensor] = TF.to_tensor, class_mapping: Optional[Dict[str, int]] = None, device: torch.device = torch.device("cpu")) -> torch.utils.data.dataloader.DataLoader
-
Creates a PyTorch `DataLoader` that yields batches of annotations. This `DataLoader` uses `datatap.torch.Dataset` under the hood, so all of the same restrictions apply, most notably that the `image_transform` function must ultimately return a `torch.Tensor` of dimensionality `(..., H, W)`.
Expand source code
def create_dataloader(
    dataset: Dataset,
    split: str,
    batch_size: int = 1,
    num_workers: int = cpu_count() or 0,
    *,
    image_transform: Callable[[PIL.Image.Image], torch.Tensor] = TF.to_tensor,
    class_mapping: Optional[Dict[str, int]] = None,
    device: torch.device = torch.device("cpu"),
) -> DataLoader[DatasetBatch]:
    """
    Creates a PyTorch ``DataLoader`` that yields batches of annotations.

    The loader is built on ``IterableDataset`` (from ``datatap.torch.dataset``)
    with ``collate`` as its collate function, so all of the same restrictions
    apply. Most notably, the ``image_transform`` function must ultimately
    return a ``torch.Tensor`` of dimensionality ``(..., H, W)``.

    Note: if no multiprocessing start method has been set in this process,
    calling this function sets it to ``"spawn"`` globally.

    Args:
        dataset: The dataTap dataset to load from.
        split: Name of the split, forwarded to ``IterableDataset``.
        batch_size: Number of items per batch.
        num_workers: Number of worker processes. The default is
            ``cpu_count() or 0``, evaluated once when the module is imported.
        image_transform: Callable turning a PIL image into a tensor.
        class_mapping: Optional mapping forwarded to ``IterableDataset``.
            Presumably maps class names to label indices; confirm.
        device: Forwarded to ``IterableDataset``. Presumably the target
            device for produced tensors; confirm.

    Returns:
        A ``DataLoader`` typed as yielding ``DatasetBatch`` items.
    """
    # Avoid calling set_start_method when one is already fixed; it would
    # raise RuntimeError without force=True.
    if torch.multiprocessing.get_start_method(allow_none=True) is None:
        torch.multiprocessing.set_start_method("spawn")

    torch_dataset = IterableDataset(
        dataset,
        split,
        image_transform=image_transform,
        class_mapping=class_mapping,
        device=device,
    )
    dataloader = cast(
        DataLoader[DatasetBatch],
        DataLoader(
            torch_dataset,
            batch_size,
            collate_fn=collate,  # type: ignore  # Torch's types are off
            num_workers=num_workers,
        ),
    )
    return dataloader