Skip to content
Snippets Groups Projects
Commit c737d7f3 authored by papi's avatar papi
Browse files

BugNIST2D challenge submission helpers

parent 5b4d8b6e
No related merge requests found
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 28 14:30:09 2023
@author: Pawel Pieta, papi@dtu.dk
"""
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
import skimage.io
import time
from collections.abc import Callable
from sampleNetwork import netFunc
def netEval(netFunc: Callable[np.array], dataPath: str, targetPath: str) -> tuple[float,float]:
"""
Evaluates the accuracy of a classification model using provided dataset.
Parameters
----------
netFunc - function of the network that takes an image and outputs a predicted class,\n
dataPath - path to a folder with data,\n
targetPath - path to a file with target labels (either .txt or .npy).
"""
assert callable(netFunc), "The first argument is not callable, it should be a network function."
assert os.path.exists(dataPath), f"Provided path: {dataPath} does not exist."
assert os.path.exists(targetPath), f"Provided path: {targetPath} does not exist."
ext = os.path.splitext(targetPath)[-1]
assert ext == '.txt' or ext == '.npy', f"Target path extension file {ext} is not supported, use .txt or .npy"
if ext == '.txt':
with open(targetPath) as f:
targetList = np.array(f.readlines()).astype(int)
else:
targetList = np.load(targetPath)
# Read in the images
imgList = glob.glob(dataPath+'/*.png')
assert imgList, f"No .png images found in folder {targetPath}"
imgsArr = np.array([skimage.io.imread(fname) for fname in imgList])
# Execute network
t0 = time.time()
predList = netFunc(imgsArr)
t1 = time.time()
# Calculate accuracy and execution time
accuracy = np.sum(np.equal(predList,targetList))/len(targetList)
execTime = t1-t0
return accuracy, execTime
if __name__ == "__main__":
targetPath = 'C:/DTU_local/AdvancedImageAnalysis/Data/bugNIST2D/train_targets.txt'
dataPath = 'C:/DTU_local/AdvancedImageAnalysis/Data/bugNIST2D/train'
accuracy, execTime = netEval(netFunc, dataPath, targetPath)
print(f"Achieved accuracy: {np.round(accuracy,4)}")
print(f"Network execution time: {np.round(execTime,4)}s")
\ No newline at end of file
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 28 14:31:20 2023
@author: Pawel Pieta, papi@dtu.dk
"""
import numpy as np
from collections.abc import Iterable
def netFunc(images: np.array) -> Iterable[int]:
'''
Loads a feed forward neural network, and predicts the labels of each image in an array.
Parameters
----------
images : numpy array
An array with grayscale images loaded with "skimage.io.imread(fname,as_gray=True)",
float datatype with pixel values between 0 and 1. The array has \n
dimension N x H x W x C where N is the number of images, H is the height of the images \n
, W is the width of the images and C is the number of channels in each image.
Returns
-------
class_indices : Iterable[int]
An iterable (e.g., a list) of the class indices representing the predictions of the network. The valid indices range from 0 to 11.
'''
# Specify here the main path to where you keep your model weights, (in case we need to modify it)
# try to make the path relative to this code file location,
# e.g "../weights/v1/" instead of "C:/AdvImg/week10/weights/v1"
MODEL_PATH = "../weights/v1/"
# Change this to return a list of indices predicted by your model
class_indices = np.random.randint(0,12, size=images.shape[0])
return class_indices
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment