Skip to content
Snippets Groups Projects
data.py 2.17 KiB
Newer Older
  • Learn to ignore specific revisions
  • Bobholamovic's avatar
    Bobholamovic committed
    import os.path
    from abc import ABCMeta, abstractmethod
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.utils.data as data
    
    
    # Data builder utilities
    def build_train_dataloader(cls, configs, C):
        """Create a shuffling training ``DataLoader`` around a freshly built dataset.

        Args:
            cls: Callable (normally a dataset class) invoked as ``cls(**configs)``.
            configs: Mapping of keyword arguments forwarded to ``cls``.
            C: Mapping read for ``'batch_size'``, ``'num_workers'`` and ``'device'``.

        Returns:
            A ``DataLoader`` that shuffles, drops the last incomplete batch and
            requests pinned memory whenever ``C['device']`` is not ``'cpu'``.
        """
        # Build the dataset before touching C so lookups happen in the same
        # order as the arguments are listed.
        dataset = cls(**configs)
        loader_options = {
            'batch_size': C['batch_size'],
            'shuffle': True,
            'num_workers': C['num_workers'],
            'pin_memory': C['device'] != 'cpu',
            'drop_last': True,
        }
        return data.DataLoader(dataset, **loader_options)
    
    
    def build_eval_dataloader(cls, configs):
        """Create a deterministic, one-sample-per-batch evaluation ``DataLoader``.

        Args:
            cls: Callable (normally a dataset class) invoked as ``cls(**configs)``.
            configs: Mapping of keyword arguments forwarded to ``cls``.

        Returns:
            A ``DataLoader`` with batch size 1, a single worker, no shuffling,
            no pinned memory, and no dropped samples.
        """
        dataset = cls(**configs)
        # Fixed evaluation settings: keep sample order and visit every sample.
        fixed_options = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': 1,
            'pin_memory': False,
            'drop_last': False,
        }
        return data.DataLoader(dataset, **fixed_options)
    
    
    def get_common_train_configs(C):
        """Return the dataset keyword arguments shared by all training datasets.

        Args:
            C: Mapping read for ``'repeats'``.

        Returns:
            A new dict with ``phase`` fixed to ``'train'`` and ``repeats`` taken
            from ``C``.
        """
        train_configs = {'phase': 'train', 'repeats': C['repeats']}
        return train_configs
    
    
    def get_common_eval_configs(C):
        """Return the dataset keyword arguments shared by all evaluation datasets.

        Args:
            C: Mapping read for ``'subset'``.

        Returns:
            A new dict with ``phase`` fixed to ``'eval'``, ``transforms`` set to
            a fresh list of three ``None`` entries, and ``subset`` taken from ``C``.
        """
        eval_configs = {
            'phase': 'eval',
            'transforms': [None, None, None],
            'subset': C['subset'],
        }
        return eval_configs
    
    
    # Dataset prototype
    class DatasetBase(data.Dataset, metaclass=ABCMeta):
        """Abstract dataset prototype whose samples can be virtually repeated.

        Subclasses must implement ``fetch_and_preprocess`` and must assign
        ``self.len`` (the number of distinct samples): ``__len__`` and
        ``__getitem__`` read it, but this base class never sets it.
        """

        def __init__(
            self,
            root, phase,
            transforms,
            repeats,
            subset
        ):
            """Store the common dataset settings.

            Args:
                root: Dataset root path; ``~`` is expanded.
                phase: Working mode, ``'train'`` for training and ``'eval'``
                    for validating or testing.
                transforms: Stored as-is on ``self.transforms``.
                repeats: Multiplier applied to ``self.len`` in ``__len__``.
                subset: Sub-dataset to use; ignored when ``phase`` is
                    ``'train'``, in which case ``'train'`` is used.

            Raises:
                FileNotFoundError: If the expanded ``root`` does not exist.
            """
            super().__init__()
            self.root = os.path.expanduser(root)
            if not os.path.exists(self.root):
                # Name the offending path so the failure is actionable.
                raise FileNotFoundError(
                    "Dataset root does not exist: {}".format(self.root)
                )
            # phase stands for the working mode,
            # 'train' for training and 'eval' for validating or testing.
            # if phase not in ('train', 'eval'):
            #     raise ValueError("Invalid phase")
            # subset is the sub-dataset to use.
            # For some datasets there are three subsets,
            # while for others there are only train and test(val).
            # if subset not in ('train', 'val', 'test'):
            #     raise ValueError("Invalid subset")
            self.phase = phase
            self.transforms = transforms
            self.repeats = repeats
            # Use 'train' subset during training.
            self.subset = 'train' if self.phase == 'train' else subset

        def __len__(self):
            """Return the virtual length: distinct samples times ``repeats``."""
            return self.len * self.repeats

        def __getitem__(self, index):
            """Return the sample at ``index``, wrapping around the real samples.

            Raises:
                IndexError: If ``index`` is not below the virtual length.
            """
            if index >= len(self):
                raise IndexError
            # Map the virtual (repeated) index back onto a real sample index.
            index = index % self.len

            item = self.fetch_and_preprocess(index)

            return item

        @abstractmethod
        def fetch_and_preprocess(self, index):
            """Load and preprocess the sample at real index ``index``.

            Must be overridden by subclasses; the base implementation
            returns ``None``.
            """
            return None