diff --git a/Chapter09/netEval.py b/Chapter09/netEval.py new file mode 100644 index 0000000000000000000000000000000000000000..b0a934f84c1173bc20ad3f2a6058567a80840e88 --- /dev/null +++ b/Chapter09/netEval.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Mar 28 14:30:09 2023 + +@author: Pawel Pieta, papi@dtu.dk +""" + +import numpy as np +import matplotlib.pyplot as plt +import os +import glob +import skimage.io +import time +from collections.abc import Callable + +from sampleNetwork import netFunc + +def netEval(netFunc: Callable[np.array], dataPath: str, targetPath: str) -> tuple[float,float]: + """ + Evaluates the accuracy of a classification model using provided dataset. + + Parameters + ---------- + netFunc - function of the network that takes an image and outputs a predicted class,\n + dataPath - path to a folder with data,\n + targetPath - path to a file with target labels (either .txt or .npy). + """ + + assert callable(netFunc), "The first argument is not callable, it should be a network function." + assert os.path.exists(dataPath), f"Provided path: {dataPath} does not exist." + assert os.path.exists(targetPath), f"Provided path: {targetPath} does not exist." + + ext = os.path.splitext(targetPath)[-1] + assert ext == '.txt' or ext == '.npy', f"Target path extension file {ext} is not supported, use .txt or .npy" + + if ext == '.txt': + with open(targetPath) as f: + targetList = np.array(f.readlines()).astype(int) + else: + targetList = np.load(targetPath) + + + # Read in the images + imgList = glob.glob(dataPath+'/*.png') + assert imgList, f"No .png images found in folder {targetPath}" + imgsArr = np.array([skimage.io.imread(fname) for fname in imgList]) + + # Execute network + t0 = time.time() + predList = netFunc(imgsArr) + t1 = time.time() + + # Calculate accuracy and execution time + accuracy = np.sum(np.equal(predList,targetList))/len(targetList) + execTime = t1-t0 + + return accuracy, execTime + + + +if __name__ == "__main__": + + targetPath = 'C:/DTU_local/AdvancedImageAnalysis/Data/bugNIST2D/train_targets.txt' + dataPath = 'C:/DTU_local/AdvancedImageAnalysis/Data/bugNIST2D/train' + + accuracy, execTime = netEval(netFunc, dataPath, targetPath) + + print(f"Achieved accuracy: {np.round(accuracy,4)}") + print(f"Network execution time: {np.round(execTime,4)}s") \ No newline at end of file diff --git a/Chapter09/sampleNetwork.py b/Chapter09/sampleNetwork.py new file mode 100644 index 0000000000000000000000000000000000000000..e659b5fc762c1241f0a401a1d7abcdfe6d2dc4e2 --- /dev/null +++ b/Chapter09/sampleNetwork.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Mar 28 14:31:20 2023 + +@author: Pawel Pieta, papi@dtu.dk +""" +import numpy as np +from collections.abc import Iterable + +def netFunc(images: np.array) -> Iterable[int]: + ''' + Loads a feed forward neural network, and predicts the labels of each image in an array. + + Parameters + ---------- + images : numpy array + An array with grayscale images loaded with "skimage.io.imread(fname,as_gray=True)", + float datatype with pixel values between 0 and 1. The array has \n + dimension N x H x W x C where N is the number of images, H is the height of the images \n + , W is the width of the images and C is the number of channels in each image. + + Returns + ------- + class_indices : Iterable[int] + An iterable (e.g., a list) of the class indices representing the predictions of the network. The valid indices range from 0 to 11. + + ''' + + # Specify here the main path to where you keep your model weights, (in case we need to modify it) + # try to make the path relative to this code file location, + # e.g "../weights/v1/" instead of "C:/AdvImg/week10/weights/v1" + MODEL_PATH = "../weights/v1/" + + + # Change this to return a list of indices predicted by your model + class_indices = np.random.randint(0,12, size=images.shape[0]) + + return class_indices + +