Skip to content
Snippets Groups Projects
Select Git revision
  • 8a0ea7e5db8696a466e01d0b93d996e528afe8e8
  • main default protected
2 results

netEval.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    netEval.py 2.27 KiB
    # -*- coding: utf-8 -*-
    """
    Created on Tue Mar 28 14:30:09 2023
    
    @author: Pawel Pieta, papi@dtu.dk
    """
    
    import numpy as np
    import matplotlib.pyplot as plt
    import os
    import glob
    import skimage.io
    import time
    
    from sampleNetwork import netFunc
    
    def netEval(netFunc, dataPath, targetPath):
        '''
        Evaluates the accuracy of a classification model using provided dataset.

        All .png images in ``dataPath`` are loaded, passed to ``netFunc`` in a
        single call, and the returned predictions are compared element-wise
        with the target labels.

        Images are paired with labels by position: the image files are sorted
        by file name and the i-th image is compared with the i-th label.
        NOTE(review): this assumes the label file is ordered like the sorted
        file names - confirm against how the targets were generated.

        Parameters
        ----------
        netFunc : function
            Function of the network that takes the array of all images and
            outputs the predicted classes
        dataPath: string
            Path to a folder with data (.png images)
        targetPath: string
            Path to a file with target labels (either .txt or .npy). A .txt
            file holds one integer label per line; blank lines are ignored.

        Returns
        -------
        accuracy: float
            Accuracy of the network on provided dataset (fraction of
            predictions equal to their target label)
        execTime: float
            Network prediction execution time in seconds (only the call to
            ``netFunc`` is timed, not the image loading)

        Raises
        ------
        AssertionError
            If ``netFunc`` is not callable, a path does not exist, the target
            file extension is unsupported, no .png images are found, or the
            number of images differs from the number of labels.
        '''

        assert callable(netFunc), "The first argument is not callable, it should be a network function."
        assert os.path.exists(dataPath), f"Provided path: {dataPath} does not exist."
        assert os.path.exists(targetPath), f"Provided path: {targetPath} does not exist."

        ext = os.path.splitext(targetPath)[-1]
        assert ext == '.txt' or ext == '.npy', f"Target path extension file {ext} is not supported, use .txt or .npy"

        if ext == '.txt':
            with open(targetPath) as f:
                # Skip blank lines (e.g. a trailing newline at the end of the
                # file) so they do not break the integer conversion.
                targetList = np.array([int(line) for line in f if line.strip()], dtype=int)
        else:
            targetList = np.load(targetPath)

        # Read in the images.
        # glob returns files in a file-system dependent order; sort them so
        # the image/label pairing is deterministic across platforms.
        imgList = sorted(glob.glob(os.path.join(dataPath, '*.png')))
        assert imgList, f"No .png images found in folder {dataPath}"
        # Fail early, before running the network, when the labels cannot be
        # matched one-to-one with the images.
        assert len(imgList) == len(targetList), (
            f"Number of images ({len(imgList)}) does not match "
            f"number of targets ({len(targetList)})."
        )
        imgsArr = np.array([skimage.io.imread(fname) for fname in imgList])

        # Execute network; perf_counter is monotonic and intended for
        # measuring short intervals.
        t0 = time.perf_counter()
        predList = netFunc(imgsArr)
        t1 = time.perf_counter()

        # Calculate accuracy and execution time
        accuracy = np.sum(np.equal(predList, targetList)) / len(targetList)
        execTime = t1 - t0

        return accuracy, execTime
            
        
    
    if __name__ == "__main__":

        # Example run: evaluate the sample network on the training split.
        # The folder of images and the label file are hard-coded local paths.
        dataPath = 'C:/DTU_local/AdvancedImageAnalysis/Data/bugNIST2D/train'
        targetPath = 'C:/DTU_local/AdvancedImageAnalysis/Data/bugNIST2D/train_targets.txt'

        # netEval returns the pair (accuracy, execution time).
        evalResult = netEval(netFunc, dataPath, targetPath)
        accuracy, execTime = evalResult

        # Report both values rounded to four decimals.
        print(f"Achieved accuracy: {np.round(accuracy,4)}")
        print(f"Network execution time: {np.round(execTime,4)}s")