-
Save prachiisc/a5f193212f90c8f82e84b5fe9dcf85d4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
# Plot a 2-D Gaussian distribution: its pdf surface, a filled contour map, and randomly generated samples.
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib import cm | |
from mpl_toolkits.mplot3d import Axes3D | |
# The 2-D density is defined over the variables X and Y, sampled on an
# N x N grid spanning [-2, 2] along each axis.
N = 40
X, Y = np.meshgrid(np.linspace(-2, 2, N), np.linspace(-2, 2, N))

# Distribution parameters: zero mean, unit variances, off-diagonal -0.8
# (with unit variances this is also the correlation coefficient).
mu = np.array([0., 0.])
Sigma = np.array([[1., -0.8],
                  [-0.8, 1.]])

# Stack the two coordinate grids along a new trailing axis so that
# pos[i, j] == (X[i, j], Y[i, j]); the resulting shape is (N, N, 2).
pos = np.dstack((X, Y))
def multivariate_gaussian(pos, mu, Sigma):
    """Return the multivariate Gaussian density evaluated on array *pos*.

    Parameters
    ----------
    pos : array_like, shape (..., n)
        Points at which to evaluate the density; the last axis holds the
        n coordinates of each point (e.g. a grid of shape (rows, cols, 2)).
    mu : array_like, shape (n,)
        Mean vector.
    Sigma : array_like, shape (n, n)
        Covariance matrix; must be non-singular.

    Returns
    -------
    ndarray of shape ``pos.shape[:-1]`` (a numpy scalar when *pos* is 1-D)
        The density p(x) at every point.

    Raises
    ------
    numpy.linalg.LinAlgError
        If *Sigma* is singular.
    """
    # Accept plain lists/tuples as well as ndarrays.
    pos = np.asarray(pos)
    mu = np.asarray(mu)
    Sigma = np.asarray(Sigma)

    n = mu.shape[0]
    diff = pos - mu

    # Normalising constant sqrt((2*pi)^n * |Sigma|). Named so it does not
    # shadow the module-level grid size N.
    norm_const = np.sqrt((2 * np.pi) ** n * np.linalg.det(Sigma))

    # Solve Sigma @ y = (x - mu) for every point rather than forming
    # Sigma^-1 explicitly: same result, better numerical conditioning.
    # The trailing length-1 axis turns each right-hand side into a column
    # so solve() broadcasts the single (n, n) Sigma over all leading axes.
    solved = np.linalg.solve(Sigma, diff[..., np.newaxis])[..., 0]

    # Quadratic form (x-mu)^T Sigma^-1 (x-mu), one value per point.
    fac = np.einsum('...k,...k->...', diff, solved)
    return np.exp(-fac / 2) / norm_const
# The distribution on the variables X, Y packed into pos.
Z = multivariate_gaussian(pos, mu, Sigma)

# One figure with three vertically stacked panels.
fig = plt.figure()

# Panel 1: the density rendered as a 3-D surface.
ax1 = fig.add_subplot(3, 1, 1, projection='3d')
ax1.plot_surface(X, Y, Z, rstride=3, cstride=3, linewidth=1,
                 antialiased=True, cmap=cm.viridis)
ax1.view_init(55, 90)
ax1.set_xticks([])
ax1.set_yticks([])
ax1.set_zticks([])
ax1.set_title('pdf of Gaussian Distribution')
ax1.set_xlabel(r'$x_1$')
ax1.set_ylabel(r'$x_2$')
ax1.set_zlabel(r'$p(x)$')

# Panel 2: filled contours drawn on the z=0 plane and viewed from directly
# above (elevation 90), so the 3-D axes read as a 2-D contour map.
ax2 = fig.add_subplot(3, 1, 2, projection='3d')
ax2.contourf(X, Y, Z, zdir='z', offset=0, cmap=cm.viridis)
ax2.view_init(90, 90)
ax2.grid(False)
ax2.set_xticks([])
ax2.set_yticks([])
ax2.set_zticks([])
ax2.set_title(r'$Contour$')
ax2.set_xlabel(r'$x_1$')
ax2.set_ylabel(r'$x_2$')

# Panel 3: 1000 random draws from the same distribution.
ax3 = fig.add_subplot(3, 1, 3)
x, y = np.random.multivariate_normal(mu, Sigma, 1000).T
ax3.plot(x, y, 'x')
# Mark the distribution's actual mean instead of a hard-coded origin, so the
# marker stays correct if mu is changed.
ax3.plot(mu[0], mu[1], 'x', c='k')
ax3.set_title('Generated samples')
ax3.set_xlabel(r'$x_1$')
ax3.set_ylabel(r'$x_2$')

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment