ngclearn.utils.density package
Submodules
ngclearn.utils.density.gmm module
- class ngclearn.utils.density.gmm.GMM(k, max_iter=5, assume_diag_cov=False, init_kmeans=True)
Bases: object
Implements a Gaussian mixture model (GMM), also known as a mixture of Gaussians (MoG). Parameters are fit with the Expectation-Maximization (EM) algorithm, and by default each component multivariate Gaussian uses a full covariance matrix (see assume_diag_cov).
Note that this is a JAX wrapper model that delegates learning to the scikit-learn (sklearn) implementation. The sampling process has been rewritten to use GPU-accelerated matrix computations.
- Parameters:
k – the number of components/latent variables within this GMM
max_iter – the maximum number of EM iterations to fit parameters to data (Default = 5)
assume_diag_cov – if True, assumes a diagonal covariance for each component (Default = False)
init_kmeans – if True, first uses the K-Means algorithm to initialize the component Gaussians of this GMM (Default = True)
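Because learning is delegated to sklearn, the constructor options plausibly correspond to arguments of scikit-learn's GaussianMixture estimator. The sketch below shows one such mapping. It is an assumption, not ngclearn's actual code: the helper make_sklearn_gmm is hypothetical, and mapping init_kmeans=False to init_params="random" is a guess.

```python
from sklearn.mixture import GaussianMixture


def make_sklearn_gmm(k, max_iter=5, assume_diag_cov=False, init_kmeans=True):
    """Hypothetical helper: build an sklearn GaussianMixture from GMM-style options."""
    return GaussianMixture(
        n_components=k,  # number of mixture components / latent variables
        # Assumed mapping: diagonal vs. full covariance matrix per component.
        covariance_type="diag" if assume_diag_cov else "full",
        max_iter=max_iter,  # cap on the number of EM iterations
        # Assumed mapping: K-Means initialization vs. random initialization.
        init_params="kmeans" if init_kmeans else "random",
    )
```

For example, make_sklearn_gmm(3, assume_diag_cov=True) yields an unfitted estimator with three diagonal-covariance components, initialized by K-Means and limited to five EM iterations.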
- fit(data)
Runs the full (EM) fitting process of this GMM.
- Parameters:
data – the dataset to fit this GMM to
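A minimal sketch of what fitting might involve follows. It assumes that scikit-learn's GaussianMixture performs the EM updates and that the learned mixing weights, means and full covariances are then copied into JAX arrays so that sampling can run on an accelerator. The function fit_and_export and its seed argument are hypothetical and are not part of ngclearn.

```python
import numpy as np
import jax.numpy as jnp
from sklearn.mixture import GaussianMixture


def fit_and_export(data, k, max_iter=5, seed=0):
    """Hypothetical sketch: fit a GMM with sklearn's EM, then move parameters to JAX."""
    gmm = GaussianMixture(
        n_components=k,
        covariance_type="full",  # full covariance per component
        max_iter=max_iter,
        init_params="kmeans",    # K-Means initialization
        random_state=seed,
    )
    gmm.fit(np.asarray(data))  # EM runs on the CPU inside sklearn
    # Copy the learned parameters into JAX arrays for device-side sampling.
    weights = jnp.asarray(gmm.weights_)     # (k,)       mixing proportions
    means = jnp.asarray(gmm.means_)         # (k, d)     component means
    covs = jnp.asarray(gmm.covariances_)    # (k, d, d)  component covariances
    return weights, means, covs
```

Splitting the work this way keeps the iterative, data-dependent EM loop in sklearn, while only the small parameter arrays cross over into JAX.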
- sample(n_s, mode_i=-1, samples_modes_evenly=False)
Efficiently draws samples from the current (fitted) GMM.
- Parameters:
n_s – the number of samples to draw from this GMM
mode_i – if >= 0, draws samples only from the component with this index, ignoring the Categorical prior over latent variables/components (Default = -1)
samples_modes_evenly – if True, ignores the Categorical prior over latent variables/components and draws an approximately equal number of samples from each component (Default = False)
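The sampling options above can be sketched in JAX as ancestral sampling. First, a component index z is chosen for each sample. Then each sample is drawn as x = mu_z + L_z eps, where L_z is the Cholesky factor of component z's covariance and eps is standard normal noise. This is an illustrative sketch, not ngclearn's implementation. The function sample_gmm is hypothetical, and it takes the mixture parameters as explicit arrays rather than reading them from a GMM object. The round-robin assignment used for samples_modes_evenly is also an assumption about what "approximately equal" means.

```python
import jax
import jax.numpy as jnp


def sample_gmm(key, weights, means, covs, n_s, mode_i=-1, samples_modes_evenly=False):
    """Hypothetical sketch of batched GMM sampling in JAX; returns (samples, component ids)."""
    k, d = means.shape
    key_z, key_eps = jax.random.split(key)
    if mode_i >= 0:
        if mode_i >= k:
            raise ValueError(f"mode_i={mode_i} out of range for {k} components")
        # Draw every sample from the single requested component.
        z = jnp.full((n_s,), mode_i, dtype=jnp.int32)
    elif samples_modes_evenly:
        # Round-robin assignment: per-component counts differ by at most one.
        z = jnp.arange(n_s) % k
    else:
        # Sample component ids from the Categorical prior given by the mixing weights.
        z = jax.random.categorical(key_z, jnp.log(weights), shape=(n_s,))
    # Batched Cholesky factors L_j with Sigma_j = L_j L_j^T, shape (k, d, d).
    chol = jnp.linalg.cholesky(covs)
    eps = jax.random.normal(key_eps, (n_s, d))  # standard normal noise
    # x = mu_z + L_z eps for all samples at once (gather + batched matvec).
    x = means[z] + jnp.einsum("nij,nj->ni", chol[z], eps)
    return x, z
```

Computing all n_s samples with one gather and one batched einsum, rather than looping over components, lets a sampler of this kind run as dense matrix operations on a GPU.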