from skimage.transform import SimilarityTransform, warp
import numpy as np
import os

def warpAndSave(driftEstimates, dataset, output_dir, dataset_type, interpolation_order=5, double_view_debug=False):
    """Translate every image of *dataset* by its cumulative drift correction and save it.

    The per-image correction is the cumulative sum (along axis 0) of
    *driftEstimates*, with row 0 forced to zero because the first image is the
    reference. Presumably each row is the drift of an image relative to the
    previous one (TODO confirm against the estimator). Each image is placed on
    a canvas padded so that no shifted image is cropped. It is then translated
    with ``skimage.transform.warp`` and passed to
    ``dataset.saveTransformedImage``.

    Parameters
    ----------
    driftEstimates : array_like, shape (N, 2)
        Per-image drift estimates. Column 0 is the x (column) component and
        column 1 the y (row) component, matching skimage's (x, y) translation
        order. The caller's array is not modified.
    dataset : object
        Image stack providing ``len()``, ``shape`` as (N, rows, cols), ``dtype``,
        ``dataset[k]`` returning image ``k``, ``initializeDatasetOutput`` and
        ``saveTransformedImage``.
    output_dir : str
        Forwarded to ``dataset.initializeDatasetOutput``.
    dataset_type :
        Forwarded to ``initializeDatasetOutput`` and ``saveTransformedImage``.
    interpolation_order : int, optional
        Spline interpolation order passed to ``warp`` (skimage accepts 0-5).
        Default 5.
    double_view_debug : bool, optional
        If True, the output canvas is twice as wide. The uncorrected image is
        drawn in the right half next to the corrected one.
    """
    # Work on a float copy so the caller's drift array is left untouched.
    drift = np.array(driftEstimates, dtype=float)
    drift[0, :] = 0  # correcting wrt. first image, which is not drifted wrt. anything
    cumulative_correction = np.cumsum(drift, axis=0)

    # warp() uses the transform as an inverse map: output(p) = input(p + t).
    # Image content therefore moves by -t, so pad the canvas by the largest
    # negative (min*) and positive (max*) shift along each axis.
    shift = -cumulative_correction
    minx = -int(np.floor(np.min(shift[:, 0])))
    maxx = int(np.ceil(np.max(shift[:, 0])))
    miny = -int(np.floor(np.min(shift[:, 1])))
    maxy = int(np.ceil(np.max(shift[:, 1])))

    n_images = len(dataset)
    shape = dataset.shape

    h = shape[1] + miny + maxy
    w = shape[2] + minx + maxx
    if double_view_debug:
        w *= 2  # right half holds the uncorrected image for comparison

    dataset.initializeDatasetOutput(output_dir, (h, w), dataset_type)

    out_dtype = np.dtype(dataset.dtype)
    is_integer = np.issubdtype(out_dtype, np.integer)
    if is_integer:
        limits = np.iinfo(out_dtype)

    for k in range(n_images):
        image = dataset[k]

        canvas = np.zeros((h, w), dtype=out_dtype)
        canvas[miny:miny + shape[1], minx:minx + shape[2]] = image

        tform = SimilarityTransform(scale=1, rotation=0, translation=cumulative_correction[k, :])
        # Without preserve_range, skimage rescales integer images to [0, 1].
        # The cast back to the integer dtype would then truncate them to ~0.
        result = warp(canvas, tform, order=interpolation_order, preserve_range=True)

        if double_view_debug:
            result[miny:miny + shape[1], (minx + w // 2):(minx + shape[2] + w // 2)] = image

        if is_integer:
            # High-order splines can overshoot the input range. Round and clip
            # so the cast neither truncates nor wraps around.
            result = np.clip(np.rint(result), limits.min, limits.max)

        dataset.saveTransformedImage(k, result.astype(out_dtype), dataset_type)