Skip to content
Snippets Groups Projects
ex4_1_7.py 1.87 KiB
Newer Older
  • Learn to ignore specific revisions
  • bjje's avatar
    bjje committed
    # exercise 4.1.7
    
    import importlib_resources
    import numpy as np
    import matplotlib.pyplot as plt 
    from scipy.io import loadmat
    
    # Locate the zipdata.mat file shipped inside the dtuimldmtools package
    filename = importlib_resources.files("dtuimldmtools").joinpath("data/zipdata.mat")

    # Digits to include in analysis (to include all, n = range(10) )
    n = [1]

    # Number of digits to generate from normal distributions
    ngen = 10

    # Read the Matlab file into a dict and pull out the training matrix:
    # column 0 is used as the class label, the remaining columns as features
    traindata = loadmat(filename)["traindata"]
    y = traindata[:, 0]
    X = traindata[:, 1:]
    N, M = X.shape
    C = len(n)

    # Keep only the observations whose label is one of the selected digits
    class_mask = np.isin(y, list(n))
    X = X[class_mask, :]
    y = y[class_mask]
    N = X.shape[0]

    # Per-feature mean and sample standard deviation (ddof=1), and the
    # sample covariance matrix with features stored along the columns
    mu = np.mean(X, axis=0)
    s = np.std(X, axis=0, ddof=1)
    S = np.cov(X, rowvar=False, ddof=1)
    
    # Generate ngen samples from independent 1-D normal distributions, one
    # per feature. Standard-normal noise of shape (ngen, M) is scaled by the
    # per-feature standard deviation s and shifted by the per-feature mean mu;
    # broadcasting applies s and mu to every row, so no explicit loop is
    # needed. M (the number of feature columns in X) replaces a hard-coded
    # 256 so the sample width always matches mu and s.
    Xgen = np.random.randn(ngen, M) * s + mu

    # Plot the generated samples as 16x16 images on a two-row grid
    plt.figure()
    n_cols = int(np.ceil(ngen / 2.0))
    for k in range(ngen):
        plt.subplot(2, n_cols, k + 1)
        img = np.reshape(Xgen[k, :], (16, 16))
        plt.imshow(img, cmap=plt.cm.gray_r)
        plt.xticks([])
        plt.yticks([])
        # The title is attached to the second subplot only
        if k == 1:
            plt.title("Digits: 1-D Normal")
    
    
    # Draw ngen samples from the multivariate normal distribution with
    # mean mu and covariance S.
    # Note: when a single class is investigated, NumPy may emit
    # """RuntimeWarning: covariance is not positive-semidefinite."""
    # That is troublesome in general, but here it is due to numerical
    # imprecision.
    Xmvgen = np.random.multivariate_normal(mu, S, ngen)

    # Show every generated sample as a 16x16 image on a two-row grid
    plt.figure()
    grid_cols = int(np.ceil(ngen / 2.0))
    for idx, sample in enumerate(Xmvgen):
        plt.subplot(2, grid_cols, idx + 1)
        picture = sample.reshape((16, 16))
        plt.imshow(picture, cmap=plt.cm.gray_r)
        plt.xticks([])
        plt.yticks([])
        # The title is attached to the second subplot only
        if idx == 1:
            plt.title("Digits: Multivariate Normal")

    plt.show()