diff --git a/qim3d/ml/models/__init__.py b/qim3d/ml/models/__init__.py index 4624be5fb111dac4b820b52f5ab072bf744c9154..65493e60d531f65687895dddff1c38fea11c5271 100644 --- a/qim3d/ml/models/__init__.py +++ b/qim3d/ml/models/__init__.py @@ -1 +1 @@ -from ._unet import UNet, Hyperparameters +from ._unet import UNet2D, Hyperparameters diff --git a/qim3d/ml/models/_unet.py b/qim3d/ml/models/_unet.py index fa41c0214d3dedccc916f79812f08e06b1eb03a2..40006ad4ffa8f5aa2814747dfb7957ea5cdcf94e 100644 --- a/qim3d/ml/models/_unet.py +++ b/qim3d/ml/models/_unet.py @@ -5,7 +5,7 @@ import torch.nn as nn from qim3d.utils import log -class UNet(nn.Module): +class UNet2D(nn.Module): """ 2D UNet model for QIM imaging.