.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/constrained/nmf.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_constrained_nmf.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_constrained_nmf.py:


Non-negative matrix factorization (NMF) using alternating minimization.
=======================================================================

.. GENERATED FROM PYTHON SOURCE LINES 19-124

.. code-block:: default


    from absl import app
    from absl import flags

    import jax.numpy as jnp

    from jaxopt import BlockCoordinateDescent
    from jaxopt import objective
    from jaxopt import prox

    import numpy as onp

    from sklearn import datasets


    flags.DEFINE_string("penalty", "l2", "Regularization type.")
    flags.DEFINE_float("gamma", 1.0, "Regularization strength.")
    FLAGS = flags.FLAGS


    def nnreg(U, V_init, X, maxiter=150):
        """Regularized non-negative regression.

        We solve::

          min_{V >= 0} mean((U V^T - X) ** 2) + 0.5 * gamma * ||V||^2_2

        or::

          min_{V >= 0} mean((U V^T - X) ** 2) + gamma * ||V||_1

        depending on ``FLAGS.penalty`` ("l2" or "l1"), with
        ``gamma = FLAGS.gamma``.

        Args:
          U: fixed factor, shape ``(n_rows, n_components)``.
          V_init: starting point for the factor being solved for,
            shape ``(n_cols, n_components)``.
          X: data matrix, shape ``(n_rows, n_cols)``.
          maxiter: maximum number of block coordinate descent iterations.

        Returns:
          Approximate solution V, with the same shape as ``V_init``.

        Raises:
          ValueError: if ``FLAGS.penalty`` is neither "l2" nor "l1".
        """
        if FLAGS.penalty == "l2":
            block_prox = prox.prox_non_negative_ridge
        elif FLAGS.penalty == "l1":
            block_prox = prox.prox_non_negative_lasso
        else:
            raise ValueError("Invalid penalty.")

        bcd = BlockCoordinateDescent(fun=objective.least_squares,
                                     block_prox=block_prox,
                                     maxiter=maxiter)
        # The solver's parameters are W = V^T (paired with data=(U, X)), so
        # transpose on the way in and back on the way out.
        sol = bcd.run(init_params=V_init.T,
                      hyperparams_prox=FLAGS.gamma,
                      data=(U, X))
        return sol.params.T  # approximate solution V


    def reconstruction_error(U, V, X):
        """Computes the (unregularized) reconstruction error.

        Returns ``0.5 * mean((U V^T - X) ** 2)``; no penalty term is included.
        """
        UV = jnp.dot(U, V.T)
        return 0.5 * jnp.mean((UV - X) ** 2)


    def nmf(U_init, V_init, X, maxiter=10):
        """NMF by alternating minimization.

        We solve::

          min_{U >= 0, V >= 0} mean((U V^T - X) ** 2)
                               + 0.5 * gamma * (||U||^2_2 + ||V||^2_2)

        or::

          min_{U >= 0, V >= 0} mean((U V^T - X) ** 2)
                               + gamma * (||U||_1 + ||V||_1)

        by alternately solving the regularized non-negative regression in V
        (with U fixed) and in U (with V fixed). The reconstruction error is
        printed after every half-step.

        Args:
          U_init: initial left factor, shape ``(n_samples, n_components)``.
          V_init: initial right factor, shape ``(n_features, n_components)``.
          X: data matrix, shape ``(n_samples, n_features)``; expected to be
            non-negative for a meaningful NMF.
          maxiter: number of outer alternation steps (each updates V, then U).

        Returns:
          Tuple ``(U, V)`` of the final factors, so that ``U V^T``
          approximates ``X``.
        """
        U, V = U_init, V_init
        error = reconstruction_error(U, V, X)
        print(f"STEP: 0; Error: {error:.3f}")
        print()

        for step in range(1, maxiter + 1):
            print(f"STEP: {step}")

            # V-update with U fixed: fit X ~ U V^T.
            V = nnreg(U, V, X, maxiter=150)
            error = reconstruction_error(U, V, X)
            print(f"Error: {error:.3f} (V update)")

            # U-update with V fixed: the same problem transposed, X^T ~ V U^T.
            U = nnreg(V, U, X.T, maxiter=150)
            error = reconstruction_error(U, V, X)
            print(f"Error: {error:.3f} (U update)")
            print()

        return U, V


    def main(argv):
        del argv  # Unused.

        # Prepare data. ``load_boston`` was removed in scikit-learn 1.2, so the
        # diabetes dataset is used instead. Its features are mean-centered
        # (they contain negative values), so take absolute values to obtain
        # the non-negative matrix that NMF requires.
        X, _ = datasets.load_diabetes(return_X_y=True)
        X = jnp.abs(X)

        n_samples, n_features = X.shape
        n_components = 10

        rng = onp.random.RandomState(0)
        U = jnp.array(rng.rand(n_samples, n_components))
        V = jnp.array(rng.rand(n_features, n_components))

        # Run the algorithm.
        print("penalty:", FLAGS.penalty)
        print("gamma", FLAGS.gamma)
        print()
        nmf(U, V, X, maxiter=30)


    if __name__ == "__main__":
        app.run(main)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_auto_examples_constrained_nmf.py:


.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: nmf.py <nmf.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: nmf.ipynb <nmf.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_