diff --git a/Chapter09/netEval.py b/Chapter09/netEval.py index b0a934f84c1173bc20ad3f2a6058567a80840e88..6f3af21e5a0269494926f1caf521e5f006c30166 100644 --- a/Chapter09/netEval.py +++ b/Chapter09/netEval.py @@ -11,20 +11,29 @@ import os import glob import skimage.io import time -from collections.abc import Callable from sampleNetwork import netFunc -def netEval(netFunc: Callable[np.array], dataPath: str, targetPath: str) -> tuple[float,float]: - """ +def netEval(netFunc, dataPath, targetPath): + ''' Evaluates the accuracy of a classification model using provided dataset. Parameters ---------- - netFunc - function of the network that takes an image and outputs a predicted class,\n - dataPath - path to a folder with data,\n - targetPath - path to a file with target labels (either .txt or .npy). - """ + netFunc : function + Function of the network that takes an image and outputs a predicted class + dataPath: string + Path to a folder with data + targetPath: string + Path to a file with target labels (either .txt or .npy) + + Returns + ------- + accuracy: float + Accuracy of the network on provided dataset + execTime: + Network prediction execution time + ''' assert callable(netFunc), "The first argument is not callable, it should be a network function." assert os.path.exists(dataPath), f"Provided path: {dataPath} does not exist." diff --git a/Chapter09/sampleNetwork.py b/Chapter09/sampleNetwork.py index e659b5fc762c1241f0a401a1d7abcdfe6d2dc4e2..76c1bedb79f4e07b7ec6b29cc371a3bed5d48c3d 100644 --- a/Chapter09/sampleNetwork.py +++ b/Chapter09/sampleNetwork.py @@ -5,9 +5,8 @@ Created on Tue Mar 28 14:31:20 2023 @author: Pawel Pieta, papi@dtu.dk """ import numpy as np -from collections.abc import Iterable -def netFunc(images: np.array) -> Iterable[int]: +def netFunc(images): ''' Loads a feed forward neural network, and predicts the labels of each image in an array. @@ -16,8 +15,8 @@ def netFunc(images: np.array) -> Iterable[int]: images : numpy array An array with grayscale images loaded with "skimage.io.imread(fname,as_gray=True)", float datatype with pixel values between 0 and 1. The array has \n - dimension N x H x W x C where N is the number of images, H is the height of the images \n - , W is the width of the images and C is the number of channels in each image. + dimension N x H x W x C where N is the number of images, H is the height of the images, \n + W is the width of the images and C is the number of channels in each image. Returns -------