# __init__.py (2.08 KiB) -- retrieved from a repository web view; page navigation text omitted.
from os.path import join, expanduser, basename, exists, splitext
# (web-view commit annotation: Bobholamovic committed)

import torch
import torch.utils.data as data
import numpy as np

from .common import (default_loader, to_tensor)


class CDDataset(data.Dataset):
    """Base dataset yielding (t1, t2, label) samples read from disk.

    Each sample consists of two images (``t1`` and ``t2``) and one label, all
    addressed by parallel path lists that a subclass supplies through
    :meth:`_read_file_paths`. (Presumably "CD" stands for change detection
    and t1/t2 are the two acquisition dates -- confirm with the subclasses.)

    The dataset can be virtually enlarged: it reports ``repeats`` times the
    number of labels as its length and maps every index back onto the
    underlying sample list with a modulo.

    Args:
        root: Dataset root directory; ``~`` is expanded.
        phase: Phase name. ``'train'`` makes :meth:`__getitem__` return
            ``(t1, t2, label)``; any other value additionally prepends the
            sample name.
        transforms: Iterable of up to three callables (or ``None`` entries):
            ``[0]`` is applied to ``(t1, t2, label)``, ``[1]`` to
            ``(t1, t2)`` and ``[2]`` to ``label``. Missing entries are padded
            with ``None``; ``None`` itself means "no transforms".
        repeats: Number of times the sample list is repeated (cast to int).

    Raises:
        FileNotFoundError: If ``root`` does not exist.
        ValueError: If ``repeats`` is negative.
    """

    def __init__(
        self,
        root, phase,
        transforms,
        repeats
    ):
        super().__init__()
        self.root = expanduser(root)
        if not exists(self.root):
            raise FileNotFoundError(
                "Dataset root does not exist: {}".format(self.root)
            )
        self.phase = phase
        # Accept None as "no transforms" and pad so the three slots always exist.
        self.transforms = list(transforms) if transforms is not None else []
        self.transforms += [None]*(3-len(self.transforms))
        self.repeats = int(repeats)
        if self.repeats < 0:
            # A negative value would otherwise only surface later as an
            # obscure ValueError raised by len().
            raise ValueError(
                "repeats must be non-negative, got {}".format(self.repeats)
            )

        self.t1_list, self.t2_list, self.label_list = self._read_file_paths()
        self.len = len(self.label_list)

    def __len__(self):
        """Return the number of labels multiplied by ``repeats``."""
        return self.len * self.repeats

    def __getitem__(self, index):
        """Load and preprocess the sample at ``index``.

        Negative indices count from the end, as for built-in sequences.

        Returns:
            ``(t1, t2, label)`` when ``phase == 'train'``, otherwise
            ``(name, t1, t2, label)``.

        Raises:
            IndexError: If ``index`` lies outside ``[-len(self), len(self))``.
        """
        total = len(self)
        position = index + total if index < 0 else index
        # Checking both bounds before the modulo keeps an empty dataset from
        # dividing by zero and stops far-negative indices from wrapping.
        if not 0 <= position < total:
            raise IndexError(
                "index {} is out of range for dataset of size {}".format(
                    index, total
                )
            )
        # Fold the repeated index space back onto the real sample list.
        index = position % self.len

        t1 = self.fetch_image(self.t1_list[index])
        t2 = self.fetch_image(self.t2_list[index])
        label = self.fetch_label(self.label_list[index])
        t1, t2, label = self.preprocess(t1, t2, label)
        if self.phase == 'train':
            return t1, t2, label
        else:
            return self.get_name(index), t1, t2, label

    def _read_file_paths(self):
        """Return ``(t1_list, t2_list, label_list)``; subclasses must override."""
        raise NotImplementedError

    def fetch_label(self, label_path):
        """Load a label from ``label_path`` via ``default_loader``."""
        return default_loader(label_path)

    def fetch_image(self, image_path):
        """Load an image from ``image_path`` via ``default_loader``."""
        return default_loader(image_path)

    def get_name(self, index):
        """Return the label file's base name with its extension replaced by ``.bmp``."""
        return splitext(basename(self.label_list[index]))[0]+'.bmp'

    def preprocess(self, t1, t2, label):
        """Apply the configured transforms and convert the sample with ``to_tensor``.

        Returns:
            ``(t1, t2, label)`` after ``to_tensor``; the images are cast with
            ``.float()`` and the label with ``.long()``.
        """
        if self.transforms[0] is not None:
            # Applied on all
            t1, t2, label = self.transforms[0](t1, t2, label)
        if self.transforms[1] is not None:
            # For images solely
            t1, t2 = self.transforms[1](t1, t2)
        if self.transforms[2] is not None:
            # For labels solely
            label = self.transforms[2](label)

        return to_tensor(t1).float(), to_tensor(t2).float(), to_tensor(label).long()