"""Non-negative matrix factorization (NMF) using alternating minimization."""

from absl import app
from absl import flags

import jax.numpy as jnp

from jaxopt import BlockCoordinateDescent
from jaxopt import objective
from jaxopt import prox

import numpy as onp

from sklearn import datasets


# Command-line flags (absl) selecting the regularizer and its strength.
# "penalty" is validated in nnreg, which accepts "l2" or "l1" and raises
# ValueError for anything else; "gamma" is forwarded to the prox operator there.
flags.DEFINE_string("penalty", "l2", "Regularization type.")
flags.DEFINE_float("gamma", 1.0, "Regularization strength.")
# Global flag registry; values are read as FLAGS.penalty / FLAGS.gamma.
FLAGS = flags.FLAGS


def nnreg(U, V_init, X, maxiter=150):
  """Solves a regularized non-negative regression problem for V.

  With U and X held fixed, block coordinate descent is run on
  ``objective.least_squares`` combined with a non-negative prox operator
  selected by the ``penalty`` flag. As stated by the original author, the
  problem solved is::

    min_{V >= 0} mean((U V^T - X) ** 2) + 0.5 * gamma * ||V||^2_2   ("l2")

  or::

    min_{V >= 0} mean((U V^T - X) ** 2) + gamma * ||V||_1           ("l1")

  where gamma is the ``gamma`` flag. NOTE(review): the exact constant in front
  of the data-fit term is whatever ``objective.least_squares`` uses -- confirm
  against jaxopt before relying on the scaling above.

  Args:
    U: fixed factor, handed to the solver together with X as ``data``.
    V_init: starting point for V; its transpose initializes the solver.
    X: target matrix, handed to the solver together with U as ``data``.
    maxiter: ``maxiter`` setting passed to ``BlockCoordinateDescent``.

  Returns:
    The approximate solution V, i.e. the transpose of the solver's parameters.

  Raises:
    ValueError: if the ``penalty`` flag is neither "l2" nor "l1".
  """
  penalty = FLAGS.penalty

  # Reject unknown penalties up front, then pick the matching prox operator.
  if penalty not in ("l2", "l1"):
    raise ValueError("Invalid penalty.")
  if penalty == "l2":
    chosen_prox = prox.prox_non_negative_ridge
  else:
    chosen_prox = prox.prox_non_negative_lasso

  solver = BlockCoordinateDescent(
      fun=objective.least_squares,
      block_prox=chosen_prox,
      maxiter=maxiter,
  )
  # The solver works on the transposed factor, so transpose on the way in
  # and again on the way out.
  result = solver.run(
      init_params=V_init.T,
      hyperparams_prox=FLAGS.gamma,
      data=(U, X),
  )
  return result.params.T


def reconstruction_error(U, V, X):
  """Returns half the mean squared residual of ``U V^T`` against ``X``.

  No regularization term is included.

  Args:
    U: left factor.
    V: right factor; its transpose is multiplied by U.
    X: matrix being approximated.

  Returns:
    ``0.5 * mean((U V^T - X) ** 2)``.
  """
  residual = jnp.dot(U, V.T) - X
  return 0.5 * jnp.mean(jnp.square(residual))


def nmf(U_init, V_init, X, maxiter=10, inner_maxiter=150):
  """NMF by alternating minimization.

  Starting from ``U_init`` and ``V_init``, each outer step first updates V with
  U fixed and then updates U with V fixed, both via ``nnreg``. As stated by the
  original author, the overall problem is

    min_{U >= 0, V>= 0} ||U V^T - X||^2 + 0.5 * gamma * (||U||^2_2 + ||V||^2_2)

  or

    min_{U >= 0, V>= 0} ||U V^T - X||^2 + gamma * (||U||_1 + ||V||_1)

  depending on the ``penalty`` flag. The exact scaling of the data-fit term in
  each sub-problem is defined by ``nnreg``.

  The reconstruction error is printed before the first step and after every
  V and U update.

  Args:
    U_init: initial left factor.
    V_init: initial right factor.
    X: matrix to factorize as ``U V^T``.
    maxiter: number of outer alternating steps.
    inner_maxiter: ``maxiter`` forwarded to each ``nnreg`` call.

  Returns:
    A tuple ``(U, V)`` with the factors after the last step.
  """
  U, V = U_init, V_init

  error = reconstruction_error(U, V, X)
  print(f"STEP: 0; Error: {error:.3f}")
  print()

  for step in range(1, maxiter + 1):
    print(f"STEP: {step}")

    # Update V with U held fixed.
    V = nnreg(U, V, X, maxiter=inner_maxiter)
    error = reconstruction_error(U, V, X)
    print(f"Error: {error:.3f} (V update)")

    # Update U with V held fixed: the same regression on the transposed
    # problem, where V plays the role of the fixed factor and X.T the target.
    U = nnreg(V, U, X.T, maxiter=inner_maxiter)
    error = reconstruction_error(U, V, X)
    print(f"Error: {error:.3f} (U update)")
    print()

  # Hand the learned factors back to the caller instead of discarding them.
  return U, V


def main(argv):
  """Runs the NMF example on a scikit-learn dataset.

  Loads a bundled dataset, makes it non-negative, draws random initial factors
  with a fixed seed, prints the flag settings and runs ``nmf``.

  Args:
    argv: positional command-line arguments from absl; unused.
  """
  del argv

  # Prepare data.
  # load_boston was deprecated in scikit-learn 1.0 and removed in 1.2, where
  # accessing it raises ImportError. Use it when available so older
  # environments keep their behavior, otherwise fall back to the bundled
  # digits dataset, which needs no download.
  try:
    load_dataset = datasets.load_boston
  except (AttributeError, ImportError):
    load_dataset = datasets.load_digits
  X, _ = load_dataset(return_X_y=True)
  # Elementwise absolute value, so the matrix to factorize is non-negative.
  X = jnp.sqrt(X ** 2)

  n_samples = X.shape[0]
  n_features = X.shape[1]
  n_components = 10

  # Fixed seed so the random initial factors are reproducible.
  rng = onp.random.RandomState(0)
  U = jnp.array(rng.rand(n_samples, n_components))
  V = jnp.array(rng.rand(n_features, n_components))

  # Run the algorithm.
  print("penalty:", FLAGS.penalty)
  print("gamma", FLAGS.gamma)
  print()

  nmf(U, V, X, maxiter=30)

# Script entry point: absl's app.run handles command-line flag parsing and
# then invokes main.
if __name__ == "__main__":
  app.run(main)

# Total running time of the script: (0 minutes 0.000 seconds)

# Gallery generated by Sphinx-Gallery