From 8a0ea7e5db8696a466e01d0b93d996e528afe8e8 Mon Sep 17 00:00:00 2001 From: Pawel Pieta <papi@dtu.dk> Date: Wed, 29 Mar 2023 14:48:29 +0200 Subject: [PATCH] Type hints removed for better Python version compatibility --- Chapter09/netEval.py | 23 ++++++++++++++++------- Chapter09/sampleNetwork.py | 7 +++---- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/Chapter09/netEval.py b/Chapter09/netEval.py index b0a934f..6f3af21 100644 --- a/Chapter09/netEval.py +++ b/Chapter09/netEval.py @@ -11,20 +11,29 @@ import os import glob import skimage.io import time -from collections.abc import Callable from sampleNetwork import netFunc -def netEval(netFunc: Callable[np.array], dataPath: str, targetPath: str) -> tuple[float,float]: - """ +def netEval(netFunc, dataPath, targetPath): + ''' Evaluates the accuracy of a classification model using provided dataset. Parameters ---------- - netFunc - function of the network that takes an image and outputs a predicted class,\n - dataPath - path to a folder with data,\n - targetPath - path to a file with target labels (either .txt or .npy). - """ + netFunc : function + Function of the network that takes an image and outputs a predicted class + dataPath: string + Path to a folder with data + targetPath: string + Path to a file with target labels (either .txt or .npy) + + Returns + ------- + accuracy: float + Accuracy of the network on provided dataset + execTime: + Network prediction execution time + ''' assert callable(netFunc), "The first argument is not callable, it should be a network function." assert os.path.exists(dataPath), f"Provided path: {dataPath} does not exist." diff --git a/Chapter09/sampleNetwork.py b/Chapter09/sampleNetwork.py index e659b5f..76c1bed 100644 --- a/Chapter09/sampleNetwork.py +++ b/Chapter09/sampleNetwork.py @@ -5,9 +5,8 @@ Created on Tue Mar 28 14:31:20 2023 @author: Pawel Pieta, papi@dtu.dk """ import numpy as np -from collections.abc import Iterable -def netFunc(images: np.array) -> Iterable[int]: +def netFunc(images): ''' Loads a feed forward neural network, and predicts the labels of each image in an array. @@ -16,8 +15,8 @@ def netFunc(images: np.array) -> Iterable[int]: images : numpy array An array with grayscale images loaded with "skimage.io.imread(fname,as_gray=True)", float datatype with pixel values between 0 and 1. The array has \n - dimension N x H x W x C where N is the number of images, H is the height of the images \n - , W is the width of the images and C is the number of channels in each image. + dimension N x H x W x C where N is the number of images, H is the height of the images, \n + W is the width of the images and C is the number of channels in each image. Returns ------- -- GitLab