Skip to content
Snippets Groups Projects
Commit c737d7f3 authored by papi's avatar papi
Browse files

BugNIST2D challenge submission helpers

parent 5b4d8b6e
Branches
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 28 14:30:09 2023
@author: Pawel Pieta, papi@dtu.dk
"""
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import skimage.io
import time
from collections.abc import Callable
from sampleNetwork import netFunc
def netEval(netFunc: Callable[np.array], dataPath: str, targetPath: str) -> tuple[float,float]:
"""
Evaluates the accuracy of a classification model using provided dataset.
Parameters
----------
netFunc - function of the network that takes an image and outputs a predicted class,\n
dataPath - path to a folder with data,\n
targetPath - path to a file with target labels (either .txt or .npy).
"""
assert callable(netFunc), "The first argument is not callable, it should be a network function."
assert os.path.exists(dataPath), f"Provided path: {dataPath} does not exist."
assert os.path.exists(targetPath), f"Provided path: {targetPath} does not exist."
ext = os.path.splitext(targetPath)[-1]
assert ext == '.txt' or ext == '.npy', f"Target path extension file {ext} is not supported, use .txt or .npy"
if ext == '.txt':
with open(targetPath) as f:
targetList = np.array(f.readlines()).astype(int)
else:
targetList = np.load(targetPath)
# Read in the images
imgList = glob.glob(dataPath+'/*.png')
assert imgList, f"No .png images found in folder {targetPath}"
imgsArr = np.array([skimage.io.imread(fname) for fname in imgList])
# Execute network
t0 = time.time()
predList = netFunc(imgsArr)
t1 = time.time()
# Calculate accuracy and execution time
accuracy = np.sum(np.equal(predList,targetList))/len(targetList)
execTime = t1-t0
return accuracy, execTime
if __name__ == "__main__":
targetPath = 'C:/DTU_local/AdvancedImageAnalysis/Data/bugNIST2D/train_targets.txt'
dataPath = 'C:/DTU_local/AdvancedImageAnalysis/Data/bugNIST2D/train'
accuracy, execTime = netEval(netFunc, dataPath, targetPath)
print(f"Achieved accuracy: {np.round(accuracy,4)}")
print(f"Network execution time: {np.round(execTime,4)}s")
\ No newline at end of file
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 28 14:31:20 2023
@author: Pawel Pieta, papi@dtu.dk
"""
import numpy as np
from collections.abc import Iterable
def netFunc(images: np.array) -> Iterable[int]:
'''
Loads a feed forward neural network, and predicts the labels of each image in an array.
Parameters
----------
images : numpy array
An array with grayscale images loaded with "skimage.io.imread(fname,as_gray=True)",
float datatype with pixel values between 0 and 1. The array has \n
dimension N x H x W x C where N is the number of images, H is the height of the images \n
, W is the width of the images and C is the number of channels in each image.
Returns
-------
class_indices : Iterable[int]
An iterable (e.g., a list) of the class indices representing the predictions of the network. The valid indices range from 0 to 11.
'''
# Specify here the main path to where you keep your model weights, (in case we need to modify it)
# try to make the path relative to this code file location,
# e.g "../weights/v1/" instead of "C:/AdvImg/week10/weights/v1"
MODEL_PATH = "../weights/v1/"
# Change this to return a list of indices predicted by your model
class_indices = np.random.randint(0,12, size=images.shape[0])
return class_indices
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment