Skip to content
Snippets Groups Projects

Implementation of Deep Learning unit tests, as well as paths to the 2d data for windows users in the UNet jupyter notebook.

1 file
+ 3
11
Compare changes
  • Side-by-side
  • Inline
+ 10
5
@@ -10,7 +10,7 @@ from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True):
def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True, return_loss = False):
""" Function for training Neural Network models.
Args:
@@ -21,6 +21,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
eval_every (int, optional): frequency of model evaluation. Defaults to every epoch.
print_every (int, optional): frequency of log for model performance. Defaults to every 5 epochs.
Returns:
tuple:
train_loss (dict): dictionary with average losses and batch losses for training loop.
@@ -65,10 +66,11 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
for data in train_loader:
inputs, targets = data
inputs = inputs.to(device)
targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
targets = targets.to(device).unsqueeze(1)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
# Backpropagation
@@ -94,7 +96,7 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
for data in val_loader:
inputs, targets = data
inputs = inputs.to(device)
targets = targets.to(device).type(torch.cuda.FloatTensor).unsqueeze(1)
targets = targets.to(device).unsqueeze(1)
with torch.no_grad():
outputs = model(inputs)
@@ -122,6 +124,9 @@ def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1
plot_metrics(val_loss,color = 'orange', label = 'Valid.')
fig.show()
if return_loss:
return train_loss,val_loss
def model_summary(dataloader,model):
"""Prints the summary of a PyTorch model.
@@ -196,7 +201,7 @@ def inference(data,model):
else:
raise ValueError("Input image must be (C,H,W) format")
model.to(device)
model.eval()
# Make new list such that possible augmentations remain identical for all three rows
Loading