Skip to content
Snippets Groups Projects
Commit d9b70f14 authored by fima's avatar fima :beers:
Browse files

Merge branch 'data' into 'main'

First version of dataset class

See merge request !3
parents 73ad3022 13ce0c30
No related branches found
No related tags found
1 merge request: !3 "First version of dataset class"
%% Cell type:code id:dd6781ce tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id:fa88080a tags:
``` python
from glob import glob
from os.path import join
import os
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imread
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from monai.networks.nets import UNet
from torchvision import transforms
from monai.losses import FocalLoss, DiceLoss
import qim3d
import albumentations as A
from albumentations.pytorch import ToTensorV2
%matplotlib inline
```
%% Cell type:code id:d0a5eade tags:
``` python
# Define function for getting dataset path from string
def get_dataset_path(name: str) -> str:
    """Return the storage path for a named dataset.

    Args:
        name: Dataset identifier, e.g. ``'stan2020_2d'``.

    Returns:
        Absolute directory path of the dataset on the DTU share.

    Raises:
        ValueError: If ``name`` is not a known dataset identifier.
    """
    # Map identifiers directly to paths; a dict keeps each name and its path
    # together instead of two parallel lists that must stay index-aligned.
    dataset_paths = {
        'belialev2020_side': '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Belialev2020/side',
        'gaudez2022_3d': '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Gaudez2022/3d',
        'guo2023_2d': '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Guo2023/2d/',
        'stan2020_2d': '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Stan2020/2d',
        'reichardt2021_2d': '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/Reichardt2021/2d',
        'testcircles_2dbinary': '/dtu/3d-imaging-center/projects/2023_STUDIOS_SD/analysis/data/TestCircles/2d_binary',
    }
    # Raise (not assert) so the validation survives `python -O`.
    if name not in dataset_paths:
        raise ValueError('Dataset name must be ' + ' or '.join(dataset_paths))
    return dataset_paths[name]
```
%% Cell type:markdown id:ee235f48 tags:
# Data loading
%% Cell type:markdown id:c088ceb8 tags:
### Check out https://albumentations.ai/docs/getting_started/transforms_and_targets/
%% Cell type:code id:ddfef29e tags:
``` python
# Training set transformation
# Albumentations pipeline: resize to a fixed size, random 90-degree rotations
# for augmentation, normalize, then convert to a torch tensor.
aug_train = A.Compose([
A.Resize(832,832),
A.RandomRotate90(),
A.Normalize(mean=(0.5),std=(0.5)), # Normalize to [-1, 1]
ToTensorV2()
])
# Validation/test set transformation
# Same pipeline without the random rotation so evaluation is deterministic.
aug_val_test = A.Compose([
A.Resize(832,832),
A.Normalize(mean=(0.5),std=(0.5)), # Normalize to [-1, 1]
ToTensorV2()
])
```
%% Cell type:code id:766f0f8e tags:
``` python
### Possible datasets ####
# 'belialev2020_side'
# 'gaudez2022_3d'
# 'guo2023_2d'
# 'stan2020_2d'
# 'reichardt2021_2d'
# 'testcircles_2dbinary'
# Choose dataset
dataset = 'stan2020_2d'
# Define class instances. First, both train and validation set is defined from train
# folder with different transformations and below divided into non-overlapping subsets
# NOTE(review): the doubled namespace `qim3d.qim3d` looks unusual -- confirm the
# package really exposes `utils` under qim3d.qim3d rather than qim3d.utils.
train_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset),transform=aug_train)
val_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset),transform=aug_val_test)
test_set = qim3d.qim3d.utils.Dataset(root_path=get_dataset_path(dataset),split='test',transform=aug_val_test)
# Define fraction of training set used for validation
VAL_FRACTION = 0.3
split_idx = int(np.floor(VAL_FRACTION * len(train_set)))
# Define seed
# NOTE(review): seed is commented out, so the train/val split differs on every run.
# torch.manual_seed(42)
# Get randomly permuted indices
indices = torch.randperm(len(train_set))
# Define train and validation sets as subsets
# train_set and val_set wrap the same underlying folder with different
# transforms; the disjoint index slices keep the two subsets non-overlapping.
train_set = torch.utils.data.Subset(train_set, indices[split_idx:])
val_set = torch.utils.data.Subset(val_set, indices[:split_idx])
```
%% Cell type:markdown id:321495cc tags:
### Data overview
%% Cell type:code id:a794b739 tags:
``` python
# Check if data has mask
# NOTE(review): hard-coded to False; the commented-out expression suggests masks
# would arrive as a third element of each sample -- confirm before enabling.
has_mask= False #True if train_set[0][-1] is not None else False
# Print dataset sizes and the dtype/shape of the first sample for a sanity check.
print(f'No. of train images={len(train_set)}')
print(f'No. of validation images={len(val_set)}')
print(f'No. of test images={len(test_set)}')
print(f'{train_set[0][0].dtype=}')
print(f'{train_set[0][1].dtype=}')
print(f'image shape={train_set[0][0].shape}')
print(f'label shape={train_set[0][1].shape}')
print(f'Labels={np.unique(train_set[0][1])}')
print(f'Masked data? {has_mask}')
```
%% Cell type:markdown id:5efa7d33 tags:
### Data visualization
Display the first few images, labels, and masks if they exist
%% Cell type:code id:170577d3 tags:
``` python
# Show a grid of training images with their labels overlaid (alpha = overlay opacity).
qim3d.qim3d.viz.grid_overview(train_set,num_images=6,alpha=1)
```
%% Cell type:code id:33368063 tags:
``` python
# Define batch sizes
TRAIN_BATCH_SIZE = 4
VAL_BATCH_SIZE = 4
TEST_BATCH_SIZE = 4
# Define dataloaders
# Only the training loader shuffles; pin_memory speeds up host-to-GPU transfers.
train_loader = DataLoader(dataset=train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(dataset=val_set, batch_size=VAL_BATCH_SIZE, num_workers=8, pin_memory=True)
test_loader = DataLoader(dataset=test_set, batch_size=TEST_BATCH_SIZE, num_workers=8, pin_memory=True)
```
%% Cell type:markdown id:35e83e38 tags:
# Train U-Net
%% Cell type:code id:36685b25 tags:
``` python
# Define model: 2-D U-Net with single-channel input/output and five
# resolution levels.
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(64, 128, 256, 512, 1024),
    strides=(2, 2, 2, 2),
)

# Snapshot the initial weights so the model can be reset before each training
# run. state_dict() returns *references* to the live parameter tensors, so
# without a deep copy the "saved" state would be silently updated in place by
# training and reloading it would not actually reset the model.
from copy import deepcopy
orig_state = deepcopy(model.state_dict())

# Define loss function
#loss_fn = nn.CrossEntropyLoss()
loss_fn = FocalLoss()

# Define device: use the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
%% Cell type:markdown id:137be29b tags:
### Run training
%% Cell type:code id:13d8a9f3 tags:
``` python
# Define hyperparameters
NUM_EPOCHS = 5
EVAL_EVERY = 1   # run validation every EVAL_EVERY epochs
PRINT_EVERY = 1  # log losses every PRINT_EVERY epochs
LR = 3e-3

model.load_state_dict(orig_state)  # Restart training every time
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

all_losses = []
all_val_loss = []
# Initialize so the log line below cannot hit a NameError when a print epoch
# occurs before the first validation pass (e.g. PRINT_EVERY < EVAL_EVERY).
val_loss = float("nan")
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0
    step = 0
    for data in train_loader:
        if has_mask:
            inputs, targets, masks = data
            masks = masks.to(device).float()
        else:
            inputs, targets = data
        inputs = inputs.to(device)
        targets = targets.to(device).float().unsqueeze(1)

        # Forward -> Backward -> Step
        optimizer.zero_grad()
        outputs = model(inputs)
        #print(f'input {inputs.shape}, target: {targets.shape}, output: {outputs.shape}')
        # When a mask is present, zero out unlabeled pixels in both prediction
        # and target before computing the loss.
        loss = loss_fn(outputs*masks, targets*masks) if has_mask else loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.detach()
        step += 1

    # Log and store average epoch loss
    epoch_loss = epoch_loss.item() / step
    all_losses.append(epoch_loss)

    if epoch % EVAL_EVERY == 0:
        model.eval()
        with torch.no_grad():  # Do not need gradients for this part
            loss_sum = 0
            step = 0
            for data in val_loader:
                if has_mask:
                    inputs, targets, masks = data
                    masks = masks.to(device).float()
                else:
                    inputs, targets = data
                inputs = inputs.to(device)
                targets = targets.to(device).float().unsqueeze(1)
                outputs = model(inputs)
                loss_sum += loss_fn(outputs*masks, targets*masks) if has_mask else loss_fn(outputs, targets)
                step += 1
            val_loss = loss_sum.item() / step
            all_val_loss.append(val_loss)

    # Log and store average accuracy
    if epoch % PRINT_EVERY == 0:
        print(f'Epoch {epoch: 3}, train loss: {epoch_loss:.4f}, val loss: {val_loss:.4f}')

print('Min val loss:', min(all_val_loss))
```
%% Cell type:markdown id:a7a8e9d7 tags:
### Plot train and validation loss
%% Cell type:code id:851463c8 tags:
``` python
# Plot per-epoch training and validation loss curves.
plt.figure(figsize=(16, 3))
plt.plot(all_losses, '-', label='Train')
plt.plot(all_val_loss, '-', label='Val.')
plt.legend()
plt.show()
```
%% Cell type:markdown id:1a700f8a tags:
### Inspecting the Predicted Segmentations on training data
%% Cell type:code id:2ac83638 tags:
``` python
# Visualize model predictions on training samples.
qim3d.qim3d.viz.grid_pred(train_set,model,num_images=5,alpha=1)
```
%% Cell type:markdown id:a176ff96 tags:
### Inspecting the Predicted Segmentations on test data
%% Cell type:code id:ffb261c2 tags:
``` python
# Visualize model predictions on the held-out test set.
qim3d.qim3d.viz.grid_pred(test_set,model,alpha=1)
```
import qim3d.io
import qim3d.gui
import qim3d.tools
import qim3d.utils
import qim3d.viz
import logging
\ No newline at end of file
import gradio as gr
import numpy as np
import os
from qim3d.tools import internal_tools
from qim3d.utils import internal_tools
from qim3d.io import DataLoader
from qim3d.io.logger import log
import tifffile
......
import gradio as gr
import numpy as np
import os
from qim3d.tools import internal_tools
from qim3d.utils import internal_tools
from qim3d.io import DataLoader
from qim3d.io.logger import log
import plotly.graph_objects as go
......@@ -44,7 +44,6 @@ class Interface:
return None
def load_data(self, filepath):
# TODO: Add support for multiple files
self.vol = DataLoader().load_tiff(filepath)
......
import gradio as gr
import numpy as np
import os
from qim3d.tools import internal_tools
from qim3d.utils import internal_tools
from qim3d.io import DataLoader
from qim3d.io.logger import log
import tifffile
......
......@@ -6,7 +6,7 @@ import difflib
import tifffile
import h5py
from qim3d.io.logger import log
from qim3d.tools.internal_tools import sizeof
from qim3d.utils.internal_tools import sizeof
class DataLoader:
......
from . import *
from . import internal_tools
from .data import Dataset
\ No newline at end of file
"""Provides a custom Dataset class for building a PyTorch dataset"""
from pathlib import Path
from PIL import Image
import torch
import numpy as np
class Dataset(torch.utils.data.Dataset):
    """
    Custom Dataset class for building a PyTorch dataset.

    Args:
        root_path (str): The root directory path of the dataset.
        split (str, optional): The split of the dataset, either "train" or "test".
            Default is "train".
        transform (callable, optional): A callable function or transformation to
            be applied to the data. Default is None.

    Raises:
        ValueError: If the provided split is not valid (neither "train" nor
            "test"), or if the number of images does not match the number of
            target label files.

    Attributes:
        split (str): The split of the dataset ("train" or "test").
        transform (callable): The transformation to be applied to the data.
        sample_images (list): A list containing the paths to the sample images in the dataset.
        sample_targets (list): A list containing the paths to the corresponding target images
            in the dataset.

    Methods:
        __len__(): Returns the total number of samples in the dataset.
        __getitem__(idx): Returns the image and its target segmentation at the given index.

    Usage:
        dataset = Dataset(root_path="path/to/dataset", split="train",
                          transform=albumentations.Compose([ToTensorV2()]))
        image, target = dataset[idx]
    """

    def __init__(self, root_path: str, split: str = "train", transform=None):
        super().__init__()

        # Check if split is valid
        if split not in ["train", "test"]:
            raise ValueError("Split must be either train or test")

        self.split = split
        self.transform = transform

        path = Path(root_path) / split
        # sorted() already returns a list; pairing images with targets relies
        # on both folders using the same filename ordering.
        self.sample_images = sorted((path / "images").iterdir())
        self.sample_targets = sorted((path / "labels").iterdir())
        # Raise (not assert) so the check survives `python -O`.
        if len(self.sample_images) != len(self.sample_targets):
            raise ValueError(
                f"Unequal number of images ({len(self.sample_images)}) and "
                f"labels ({len(self.sample_targets)}) in {path}"
            )

    def __len__(self):
        return len(self.sample_images)

    def __getitem__(self, idx):
        image_path = self.sample_images[idx]
        target_path = self.sample_targets[idx]

        # Load as numpy arrays; the keyword-style transform call below matches
        # the albumentations API (joint image/mask transformation).
        image = np.array(Image.open(str(image_path)))
        target = np.array(Image.open(str(target_path)))

        if self.transform:
            transformed = self.transform(image=image, mask=target)
            image = transformed["image"]
            target = transformed["mask"]

        return image, target
File moved
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment