Skip to content
Snippets Groups Projects
Commit a616caca authored by OskarK's avatar OskarK
Browse files

Minor changes to docstrings & hyperparameters.

parent 9aedc0f5
No related branches found
No related tags found
No related merge requests found
Source diff could not be displayed: it is too large. Options to address this: view the blob.
......@@ -27,8 +27,7 @@ class UNet(nn.Module):
ValueError: If `size` is not one of 'small', 'medium', or 'large'.
Example:
unet = qim_UNet(size='large')
model = unet()
model = UNet(size='large')
"""
def __init__(self, size = 'medium',
dropout = 0,
......@@ -103,15 +102,15 @@ class Hyperparameters:
Example:
# Create hyperparameters instance
hyperparams = qim_hyperparameters(model=my_model, n_epochs=20, learning_rate=0.001)
hyperparams = Hyperparameters(model=my_model, n_epochs=20, learning_rate=0.001)
# Get the hyperparameters
params = hyperparams()
params_dict = hyperparams()
# Access the optimizer and criterion
optimizer = params['optimizer']
criterion = params['criterion']
n_epochs = params['n_epochs']
optimizer = params_dict['optimizer']
criterion = params_dict['criterion']
n_epochs = params_dict['n_epochs']
"""
def __init__(self,
model,
......
......@@ -10,12 +10,12 @@ from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
def train_model(model, qim_hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True):
def train_model(model, hyperparameters, train_loader, val_loader, eval_every = 1, print_every = 5, plot = True):
""" Function for training Neural Network models.
Args:
model (torch.nn.Module): PyTorch model.
qim_hyperparameters (dict): dictionary with n_epochs, optimizer and criterion.
hyperparameters (Hyperparameters): instance that, when called, returns a dictionary with n_epochs, optimizer and criterion.
train_loader (torch.utils.data.DataLoader): DataLoader for the training data.
val_loader (torch.utils.data.DataLoader): DataLoader for the validation data.
eval_every (int, optional): frequency of model evaluation. Defaults to every epoch.
......@@ -28,23 +28,22 @@ def train_model(model, qim_hyperparameters, train_loader, val_loader, eval_every
Example:
# defining the model.
model = qim3d.qim3d.utils.qim_UNet()
model = qim3d.utils.UNet()
# choosing the hyperparameters
qim_hyper = qim3d.qim3d.utils.qim_hyperparameters(model)
hyper_dict = qim_hyper()
hyperparameters = qim3d.utils.Hyperparameters(model)
# DataLoaders
train_loader = MyTrainLoader()
val_loader = MyValLoader()
# training the model.
train_loss,val_loss = train_model(model, hyper_dict, train_loader, val_loader)
train_loss,val_loss = train_model(model, hyperparameters, train_loader, val_loader)
"""
n_epochs = qim_hyperparameters['n_epochs']
optimizer = qim_hyperparameters['optimizer']
criterion = qim_hyperparameters['criterion']
params_dict = hyperparameters()
n_epochs = params_dict['n_epochs']
optimizer = params_dict['optimizer']
criterion = params_dict['criterion']
# Choosing best device available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
......
......@@ -91,7 +91,6 @@ def grid_overview(data, num_images=7, cmap_im="gray", cmap_segm="viridis", alpha
else:
ax.imshow(plot_data[col][row].squeeze(), cmap=cmap_im)
ax.axis("off")
fig.show()
......
......@@ -21,8 +21,8 @@ def plot_metrics(metric, color = 'blue', linestyle = '-', batch_linestyle = 'dot
Example:
train_loss = {'epoch_loss' : [...], 'batch_loss': [...]}
plot_metrics(train_loss, color = 'red', label='Train')
"""
# plotting parameters
snb.set_style('darkgrid')
snb.set(font_scale=1.5)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment