Unconstrained optimization
This section is concerned with problems of the form

\(\min_{x} f(x, \theta)\)

where \(f(x, \theta)\) is differentiable (almost everywhere), \(x\) are the parameters with respect to which the function is minimized, and \(\theta\) are optional extra arguments.
Defining an objective function
Objective functions must always take as their first argument the variables with respect to which the function is minimized. They can also take extra arguments.
The following illustrates how to express the ridge regression objective:
import jax.numpy as jnp

def ridge_reg_objective(params, l2reg, X, y):
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)
The model parameters params correspond to \(x\), while l2reg, X and y correspond to the extra arguments \(\theta\) in the mathematical notation above.
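As a concrete illustration, the objective can be evaluated and differentiated on synthetic data. The shapes and values below (and the names init_params, l2reg, maxiter reused later) are arbitrary placeholders chosen for this sketch, not part of the original example:

import jax
import jax.numpy as jnp

# Hypothetical problem setup; shapes and values are placeholders.
key_X, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_X, (100, 5))
y = jax.random.normal(key_y, (100,))
init_params = jnp.zeros(5)
l2reg = 0.1
maxiter = 100

# The first argument is the one being minimized; the rest are passed through.
value = ridge_reg_objective(init_params, l2reg, X, y)
grad = jax.grad(ridge_reg_objective)(init_params, l2reg, X, y)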
Solvers
| jaxopt.BFGS | BFGS solver. |
| jaxopt.GradientDescent | Gradient Descent solver. |
| jaxopt.LBFGS | LBFGS solver. |
| jaxopt.ScipyMinimize | scipy.optimize.minimize wrapper. |
| jaxopt.NonlinearCG | Nonlinear conjugate gradient solver. |
Instantiating and running the solver
Continuing the ridge regression example above, an LBFGS solver can be instantiated and run as follows:
import jaxopt

solver = jaxopt.LBFGS(fun=ridge_reg_objective, maxiter=maxiter)
res = solver.run(init_params, l2reg=l2reg, X=X, y=y)
# Alternatively, we could have used one of these solvers as well:
# solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500)
# solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500)
# solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, method="polak-ribiere", maxiter=500)
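Because solvers implemented purely in JAX (such as LBFGS or GradientDescent, but not the SciPy wrapper) perform only JAX operations, their run method can also be jit-compiled. A minimal sketch, assuming the solver and variables defined above:

# Jit-compile the run call (not applicable to jaxopt.ScipyMinimize,
# which calls out to SciPy under the hood).
jitted_run = jax.jit(solver.run)
res = jitted_run(init_params, l2reg=l2reg, X=X, y=y)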
Unpacking results
Note that res has the form NamedTuple(params, state), where params is the approximate solution found by the solver and state contains solver-specific information about convergence.
Because res is a NamedTuple, we can unpack it as:
params, state = res
print(params, state)
Alternatively, we can also access attributes directly:
print(res.params, res.state)
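The exact contents of state depend on the solver. As a sketch, assuming the LBFGS solver above, two fields of its state that are commonly inspected are iter_num and error:

# Convergence diagnostics from the solver state (field names assumed for
# the LBFGS state; other solvers expose different fields).
print(res.state.iter_num)  # number of iterations performed
print(res.state.error)     # residual error used by the stopping criterion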