diff --git a/docs/notebooks/Unet.ipynb b/docs/notebooks/Unet.ipynb index 38a06fc9fd86f42a6abf4f7c384be051235d1205..9f975f02eaa185298e43a12fb3b23a61c166b25e 100644 --- a/docs/notebooks/Unet.ipynb +++ b/docs/notebooks/Unet.ipynb @@ -25,6 +25,16 @@ "%matplotlib inline" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "813a3454", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" + ] + }, { "cell_type": "code", "execution_count": null, @@ -74,7 +84,7 @@ "outputs": [], "source": [ "datasets = ['belialev2020_side', 'gaudez2022_3d', 'guo2023_2d', 'stan2020_2d', 'reichardt2021_2d', 'testcircles_2dbinary']\n", - "dataset = datasets[3] \n", + "dataset = datasets[-1] \n", "root = get_dataset_path(dataset,datasets)\n", "\n", "# should not use gaudez2022: 3d image\n", @@ -97,7 +107,7 @@ "outputs": [], "source": [ "# defining model\n", - "my_model = qim3d.models.UNet(size = 'medium', dropout = 0.25)\n", + "my_model = qim3d.models.UNet(size = 'small', dropout = 0.25)\n", "# defining augmentation\n", "my_aug = qim3d.utils.Augmentation(resize = 'crop', transform_train = 'light')" ] @@ -122,7 +132,7 @@ "\n", "# datasets and dataloaders\n", "train_set, val_set, test_set = qim3d.utils.prepare_datasets(path = root, model = my_model , augmentation = my_aug,\n", - " val_fraction = 0.3,test_fraction = 0.1,\n", + " val_fraction = 0.3, test_fraction = 0.1,\n", " train_folder='train', test_folder='test')\n", "\n", "train_loader, val_loader, test_loader = qim3d.utils.prepare_dataloaders(train_set, val_set,\n", @@ -167,7 +177,7 @@ "outputs": [], "source": [ "# model hyperparameters\n", - "my_hyperparameters = qim3d.models.Hyperparameters(my_model, n_epochs=5,\n", + "my_hyperparameters = qim3d.models.Hyperparameters(my_model, n_epochs=20,\n", " learning_rate = 5e-3, loss_function='DiceCE',weight_decay=1e-3)\n", "\n", "# training model\n", @@ -210,7 +220,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.9.11" } }, "nbformat": 4, diff --git a/qim3d/utils/data.py b/qim3d/utils/data.py index e50a3b87b6c2c3bbcadbbf2baf2e45fb5b3ca547..846b13f6432e1830e9e120a7ce0e425cbe0217eb 100644 --- a/qim3d/utils/data.py +++ b/qim3d/utils/data.py @@ -171,7 +171,7 @@ class Dataset(torch.utils.data.Dataset): # if the first folder contains the targets: if any(ext in self.folder_names[0].lower() for ext in target_folder_names): - images = self.folders_names[1] + images = self.folder_names[1] targets = self.folder_names[0] # if the second folder contains the targets: