Source code for ngclearn.utils.matrix_utils

import jax.numpy as jnp

def decompose_to_mps(W, bond_dim=16):
    """
    Decomposes a dense matrix W into two MPS cores using a truncated SVD.

    The singular values are split evenly (via their square roots) between the
    two cores, so that ``core1[0] @ core2[:, :, 0]`` is the best rank-``k``
    approximation of ``W`` and is exact whenever ``k`` equals the rank of ``W``.

    Args:
        W: The dense matrix to decompose of shape (in_dim, out_dim).

        bond_dim: The requested internal rank/bond-dimension of the MPS
            compression; must be >= 1. The effective rank used is
            ``k = min(bond_dim, in_dim, out_dim)``.

    Returns:
        A tuple containing:
            core1: First tensor core of shape (1, in_dim, k).
            core2: Second tensor core of shape (k, out_dim, 1).

    Raises:
        ValueError: If ``W`` is not 2-dimensional or ``bond_dim`` is < 1.
    """
    W = jnp.asarray(W)
    if W.ndim != 2:
        raise ValueError(
            f"W must be a 2D matrix of shape (in_dim, out_dim), got shape {W.shape}"
        )
    # A negative bond_dim would otherwise act as a negative slice bound and
    # silently drop trailing singular components rather than fail loudly.
    if bond_dim < 1:
        raise ValueError(f"bond_dim must be >= 1, got {bond_dim}")

    in_dim, out_dim = W.shape
    # Reduced SVD: S holds min(in_dim, out_dim) singular values, descending.
    U, S, Vh = jnp.linalg.svd(W, full_matrices=False)

    # The bond dimension cannot exceed the number of singular values available.
    k = min(bond_dim, S.shape[0])
    s_sqrt = jnp.sqrt(S[:k])

    # Scale the columns of U and the rows of Vh by sqrt(S) so each core carries
    # half of the singular-value weight.
    core1 = (U[:, :k] * s_sqrt).reshape(1, in_dim, k)
    core2 = (s_sqrt[:, None] * Vh[:k, :]).reshape(k, out_dim, 1)
    return core1, core2