Skip to content
Snippets Groups Projects
Commit e0b98dd8 authored by monj's avatar monj
Browse files

Added function to plot grid of clusters

parent 825fb065
No related branches found
No related tags found
No related merge requests found
......@@ -11,7 +11,7 @@ import numpy as np
import local_features as lf
import matplotlib.pyplot as plt
from math import ceil
#%% General functions
......@@ -83,7 +83,8 @@ def make_output_dirs(directory,subdirectories = False):
# os.makedirs(directory + disease + '/', exist_ok = True)
#Reads images starting with base_name from the subdirectories of the input directory
#Reads images starting with base_name from the subdirectories of the input directory.
#Option for reading scaled down versions and in bnw or colour
def read_max_imgs(dir_condition, base_name, sc_fac = 1, colour_mode = 'colour'):
'monj@dtu.dk'
......@@ -103,11 +104,16 @@ def read_max_imgs(dir_condition, base_name, sc_fac = 1, colour_mode = 'colour'):
#Option to load in bnw or colour
if colour_mode == 'bnw':
img = io.imread(frame_path, as_gray = True).astype('uint8')
frame_img_list += [skimage.transform.rescale(img, sc_fac, preserve_range = True).astype('uint8')]
if sc_fac ==1:
frame_img_list += [img]
else:
frame_img_list += [skimage.transform.rescale(img, sc_fac, preserve_range = True).astype('uint8')]
else:
img = io.imread(frame_path).astype('uint8')
#print(img.dtype)
frame_img_list += [skimage.transform.rescale(img, sc_fac, preserve_range = True, multichannel=True).astype('uint8')]
if sc_fac == 1:
frame_img_list += [img]
else:
frame_img_list += [skimage.transform.rescale(img, sc_fac, preserve_range = True, multichannel=True).astype('uint8')]
max_img_list += [frame_img_list]
#print(frame_img_list[0].dtype)
......@@ -363,8 +369,35 @@ def ndim2col_pad(A, BSZ, stepsize=1, norm=False):
tmp = np.squeeze(tmp)
return ndim2col(tmp,BSZ,stepsize,norm)
#%% Functions for visualisation of learnt features
def plot_grid_cluster_centers(cluster_centers, cluster_order, patch_size, colour_mode = 'colour', occurrence = ''):
#grid dimensions
size_x = round(len(cluster_order)**(1/2))
size_y = ceil(len(cluster_order)/size_x)
#figure format
overhead = 1
w, h = plt.figaspect(size_x/size_y)
fig, axs = plt.subplots(size_x,size_y, figsize=(1.3*w,1.3*h*(1+overhead/2)), sharex=True, sharey=True)
#print('Grid size: ', grid_size[1], grid_size[2], 'Figure size: ', w, h)
ax_list = axs.ravel()
for ind, cluster in enumerate(cluster_order):
#print(ind)
if colour_mode == 'bnw': #in bnw + colour give the clusters a uniform colour
cluster_centre = np.reshape(cluster_centers[cluster,:],(patch_size,patch_size))
ax_list[ind].imshow(cluster_centre.astype('uint8'),cmap='gray')
else:
cluster_centre = np.transpose((np.reshape(cluster_centers[cluster,:],(3,patch_size,patch_size))),(1,2,0))
ax_list[ind].imshow(cluster_centre.astype('uint8'))
if occurrence !='':
ax_list[ind].set_title(round(occurrence[ind],2))
else:
ax_list[ind].set_title(cluster)
plt.setp(axs, xticks=[], yticks=[])
# def plot_mapsAndimages(dir_condition, directory_list, map_img, max_img_list, base_name = 'frame', r = 1024, c = 1024):
# nr_list = 0
# for directory in directory_list:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment