Note
Click here to download the full example code
Non-negative matrix factorization (NMF) using alternating minimization.
from absl import app
from absl import flags
import jax.numpy as jnp
from jaxopt import BlockCoordinateDescent
from jaxopt import objective
from jaxopt import prox
import numpy as onp
from sklearn import datasets
# Command-line flags. `penalty` is compared against "l2" and "l1" in nnreg;
# `gamma` is handed to the solver as the prox hyper-parameter.
flags.DEFINE_string(
    name="penalty",
    default="l2",
    help="Regularization type.",
)
flags.DEFINE_float(
    name="gamma",
    default=1.0,
    help="Regularization strength.",
)
FLAGS = flags.FLAGS
def nnreg(U, V_init, X, maxiter=150):
    """Solve a regularized non-negative regression for V with U held fixed.

    The smooth part of the objective is ``objective.least_squares`` evaluated
    with parameters ``V^T`` on the data pair ``(U, X)``. The non-smooth part
    is selected by ``FLAGS.penalty`` and scaled by ``FLAGS.gamma``::

        "l2":  min_{V >= 0} least_squares(V^T; U, X) + 0.5 * gamma * ||V||^2_2
        "l1":  min_{V >= 0} least_squares(V^T; U, X) + gamma * ||V||_1

    Args:
      U: the factor kept fixed; passed as the first element of ``data``.
      V_init: starting point for V; it is transposed before being handed to
        the solver.
      X: the regression target; passed as the second element of ``data``.
      maxiter: iteration budget given to ``BlockCoordinateDescent``.

    Returns:
      The transpose of the solver's final parameters, i.e. the approximate
      solution V.

    Raises:
      ValueError: if ``FLAGS.penalty`` is neither "l2" nor "l1".
    """
    # Map each supported penalty name to the matching proximal operator.
    prox_by_penalty = {
        "l2": prox.prox_non_negative_ridge,
        "l1": prox.prox_non_negative_lasso,
    }
    block_prox = prox_by_penalty.get(FLAGS.penalty)
    if block_prox is None:
        raise ValueError("Invalid penalty.")

    solver = BlockCoordinateDescent(
        fun=objective.least_squares,
        block_prox=block_prox,
        maxiter=maxiter,
    )
    # The solver works on V^T, so transpose on the way in and on the way out.
    result = solver.run(
        init_params=V_init.T,
        hyperparams_prox=FLAGS.gamma,
        data=(U, X),
    )
    return result.params.T
def reconstruction_error(U, V, X):
    """Return half the mean squared residual of ``U V^T`` against ``X``.

    No regularization term is included.
    """
    residual = jnp.dot(U, V.T) - X
    mean_squared = jnp.mean(residual ** 2)
    return 0.5 * mean_squared
def nmf(U_init, V_init, X, maxiter=10, inner_maxiter=150):
    """NMF by alternating minimization.

    Starting from ``(U_init, V_init)``, alternately updates V with U fixed and
    U with V fixed. Each update is one call to ``nnreg``, so the penalty
    ("l2" or "l1") and its strength come from the command-line flags read
    there. After every update the unregularized error from
    ``reconstruction_error`` (0.5 * mean((U V^T - X) ** 2)) is printed.

    Args:
      U_init: initial value for the factor U.
      V_init: initial value for the factor V.
      X: the matrix to approximate by ``U V^T``.
      maxiter: number of outer alternating steps (one V update followed by
        one U update per step).
      inner_maxiter: iteration budget passed to ``nnreg`` for each
        sub-problem. Defaults to 150, the value previously hard-coded here.

    Returns:
      A tuple ``(U, V)`` holding the factors after the last step. With
      ``maxiter <= 0`` the initial factors are returned unchanged.
    """
    U, V = U_init, V_init

    error = reconstruction_error(U, V, X)
    print(f"STEP: 0; Error: {error:.3f}")
    print()

    for step in range(1, maxiter + 1):
        print(f"STEP: {step}")

        # Update V with U fixed.
        V = nnreg(U, V, X, maxiter=inner_maxiter)
        error = reconstruction_error(U, V, X)
        print(f"Error: {error:.3f} (V update)")

        # Update U with V fixed: the same regression on the transposed
        # problem, with the roles of the two factors swapped.
        U = nnreg(V, U, X.T, maxiter=inner_maxiter)
        error = reconstruction_error(U, V, X)
        print(f"Error: {error:.3f} (U update)")
        print()

    # Hand the learned factors back instead of discarding them.
    return U, V
def main(argv):
    """Load the data, draw random initial factors and run ``nmf``."""
    del argv  # Unused; configuration is read from FLAGS.

    # Prepare data.
    # NOTE(review): load_boston was removed from scikit-learn in release 1.2;
    # confirm the pinned scikit-learn version still provides it.
    features, _ = datasets.load_boston(return_X_y=True)
    # sqrt(x ** 2) is |x|, so every entry of the data matrix is non-negative.
    features = jnp.sqrt(features ** 2)

    n_samples, n_features = features.shape
    n_components = 10

    # Fixed seed so the random initial factors are the same on every run.
    rng = onp.random.RandomState(0)
    U_init = jnp.array(rng.rand(n_samples, n_components))
    V_init = jnp.array(rng.rand(n_features, n_components))

    # Run the algorithm.
    print("penalty:", FLAGS.penalty)
    print("gamma", FLAGS.gamma)
    print()

    nmf(U_init, V_init, features, maxiter=30)


if __name__ == "__main__":
    app.run(main)
Total running time of the script: ( 0 minutes 0.000 seconds)