Skip to content
Snippets Groups Projects
Commit 8a0ea7e5 authored by papi's avatar papi
Browse files

Type hints removed for better Python version compatibility

parent c737d7f3
No related branches found
No related tags found
No related merge requests found
...@@ -11,20 +11,29 @@ import os ...@@ -11,20 +11,29 @@ import os
import glob import glob
import skimage.io import skimage.io
import time import time
from collections.abc import Callable
from sampleNetwork import netFunc from sampleNetwork import netFunc
def netEval(netFunc: Callable[np.array], dataPath: str, targetPath: str) -> tuple[float,float]: def netEval(netFunc, dataPath, targetPath):
""" '''
Evaluates the accuracy of a classification model using provided dataset. Evaluates the accuracy of a classification model using provided dataset.
Parameters Parameters
---------- ----------
netFunc - function of the network that takes an image and outputs a predicted class,\n netFunc : function
dataPath - path to a folder with data,\n Function of the network that takes an image and outputs a predicted class
targetPath - path to a file with target labels (either .txt or .npy). dataPath: string
""" Path to a folder with data
targetPath: string
Path to a file with target labels (either .txt or .npy)
Returns
-------
accuracy: float
Accuracy of the network on provided dataset
execTime:
Network prediction execution time
'''
assert callable(netFunc), "The first argument is not callable, it should be a network function." assert callable(netFunc), "The first argument is not callable, it should be a network function."
assert os.path.exists(dataPath), f"Provided path: {dataPath} does not exist." assert os.path.exists(dataPath), f"Provided path: {dataPath} does not exist."
......
...@@ -5,9 +5,8 @@ Created on Tue Mar 28 14:31:20 2023 ...@@ -5,9 +5,8 @@ Created on Tue Mar 28 14:31:20 2023
@author: Pawel Pieta, papi@dtu.dk @author: Pawel Pieta, papi@dtu.dk
""" """
import numpy as np import numpy as np
from collections.abc import Iterable
def netFunc(images: np.array) -> Iterable[int]: def netFunc(images):
''' '''
Loads a feed forward neural network, and predicts the labels of each image in an array. Loads a feed forward neural network, and predicts the labels of each image in an array.
...@@ -16,8 +15,8 @@ def netFunc(images: np.array) -> Iterable[int]: ...@@ -16,8 +15,8 @@ def netFunc(images: np.array) -> Iterable[int]:
images : numpy array images : numpy array
An array with grayscale images loaded with "skimage.io.imread(fname,as_gray=True)", An array with grayscale images loaded with "skimage.io.imread(fname,as_gray=True)",
float datatype with pixel values between 0 and 1. The array has \n float datatype with pixel values between 0 and 1. The array has \n
dimension N x H x W x C where N is the number of images, H is the height of the images \n dimension N x H x W x C where N is the number of images, H is the height of the images, \n
, W is the width of the images and C is the number of channels in each image. W is the width of the images and C is the number of channels in each image.
Returns Returns
------- -------
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment