.. _unconstrained_optim:

Unconstrained optimization
==========================

This section is concerned with problems of the form

.. math::

    \min_{x} f(x, \theta)

where :math:`f(x, \theta)` is a function that is differentiable (almost
everywhere), :math:`x` are the parameters with respect to which the function is
minimized, and :math:`\theta` are optional extra arguments.

Defining an objective function
------------------------------

An objective function must always take as its first argument the variables with
respect to which it is minimized. It can also accept extra arguments.

The following illustrates how to express the ridge regression objective::

  import jax.numpy as jnp

  def ridge_reg_objective(params, l2reg, X, y):
    # Mean squared error plus an L2 penalty on the parameters.
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

The model parameters ``params`` correspond to :math:`x` while ``l2reg``, ``X``
and ``y`` correspond to the extra arguments :math:`\theta` in the mathematical
notation above.
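
Because the objective is an ordinary JAX function, it can be evaluated and
differentiated before it is handed to a solver. The following is a minimal
sketch, assuming a small hand-written dataset (the values of ``X``, ``y``,
``params`` and ``l2reg`` are illustrative and not part of this guide). It
compares the gradient computed by ``jax.grad`` with one derived by hand::

  import jax
  import jax.numpy as jnp

  def ridge_reg_objective(params, l2reg, X, y):
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

  # Illustrative toy data (assumed values).
  X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  y = jnp.array([1.0, 2.0, 3.0])
  params = jnp.array([0.5, -0.5])
  l2reg = 0.1

  # Scalar objective value at ``params``.
  value = ridge_reg_objective(params, l2reg, X, y)

  # jax.grad differentiates with respect to the first argument by default,
  # matching the convention that the minimized variables come first.
  grad_auto = jax.grad(ridge_reg_objective)(params, l2reg, X, y)

  # Hand-derived gradient: (2 / n) * X^T (X params - y) + l2reg * params.
  n = X.shape[0]
  grad_manual = 2.0 / n * X.T @ (X @ params - y) + l2reg * params

If the two gradients agree up to floating-point error, the objective is
correctly set up for differentiation with respect to ``params``.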

Solvers
-------

.. autosummary::
  :toctree: _autosummary

    jaxopt.BFGS
    jaxopt.GradientDescent
    jaxopt.LBFGS
    jaxopt.ScipyMinimize
    jaxopt.NonlinearCG

Instantiating and running the solver
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Continuing the ridge regression example above, the L-BFGS solver can be
instantiated and run as follows::

  import jaxopt

  # init_params, l2reg, X and y are assumed to be defined beforehand.
  solver = jaxopt.LBFGS(fun=ridge_reg_objective, maxiter=500)
  res = solver.run(init_params, l2reg=l2reg, X=X, y=y)

  # Alternatively, we could have used one of these solvers as well:
  # solver = jaxopt.GradientDescent(fun=ridge_reg_objective, maxiter=500)
  # solver = jaxopt.ScipyMinimize(fun=ridge_reg_objective, method="L-BFGS-B", maxiter=500)
  # solver = jaxopt.NonlinearCG(fun=ridge_reg_objective, method="polak-ribiere", maxiter=500)
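
For a self-contained run, the sketch below generates a synthetic regression
problem and compares the L-BFGS solution with the closed-form ridge solution.
The data sizes, random seed and ``l2reg`` value are assumptions chosen for
illustration only::

  import jax
  import jax.numpy as jnp
  import jaxopt

  def ridge_reg_objective(params, l2reg, X, y):
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

  # Synthetic data (assumed sizes and seed).
  key_X, key_w, key_noise = jax.random.split(jax.random.PRNGKey(0), 3)
  n_samples, n_features = 50, 5
  X = jax.random.normal(key_X, (n_samples, n_features))
  w_true = jax.random.normal(key_w, (n_features,))
  y = X @ w_true + 0.1 * jax.random.normal(key_noise, (n_samples,))
  l2reg = 0.1

  # Run the solver; extra arguments of the objective are passed to run().
  init_params = jnp.zeros(n_features)
  solver = jaxopt.LBFGS(fun=ridge_reg_objective, maxiter=500)
  res = solver.run(init_params, l2reg=l2reg, X=X, y=y)

  # Closed-form minimizer of this objective, from setting its gradient to zero:
  # (2 / n * X^T X + l2reg * I) w = 2 / n * X^T y.
  A = 2.0 / n_samples * X.T @ X + l2reg * jnp.eye(n_features)
  b = 2.0 / n_samples * X.T @ y
  w_closed_form = jnp.linalg.solve(A, b)

  # Largest coordinate-wise gap between the iterative and exact solutions.
  max_abs_diff = jnp.max(jnp.abs(res.params - w_closed_form))

Here ``res.params`` holds the solution returned by the solver. Since the solver
stops once its error falls below a tolerance, ``max_abs_diff`` is expected to
be small but not exactly zero.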

Unpacking results
~~~~~~~~~~~~~~~~~

Note that ``res`` has the form ``NamedTuple(params, state)``, where ``params``
is the approximate solution found by the solver and ``state`` contains
solver-specific information about convergence.

Because ``res`` is a ``NamedTuple``, we can unpack it as::

  params, state = res
  print(params, state)

Alternatively, we can access the attributes directly::

  print(res.params, res.state)
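
The fields available on ``state`` depend on the solver. As a hedged sketch,
the example below assumes that the state exposes ``iter_num`` and ``error``
and that the solver exposes its ``tol`` setting, which should be checked
against the installed jaxopt version. The toy data is illustrative::

  import jax.numpy as jnp
  import jaxopt

  def ridge_reg_objective(params, l2reg, X, y):
    residuals = jnp.dot(X, params) - y
    return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)

  # Illustrative toy data (assumed values).
  X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  y = jnp.array([1.0, 2.0, 3.0])

  solver = jaxopt.LBFGS(fun=ridge_reg_objective, maxiter=500)
  params, state = solver.run(jnp.zeros(2), l2reg=0.1, X=X, y=y)

  # Assumed state fields: number of iterations run and the final
  # optimality error used by the stopping criterion.
  num_iters = int(state.iter_num)
  final_error = float(state.error)

  # The run is treated as converged if the error reached the tolerance
  # rather than the solver stopping at maxiter.
  converged = final_error <= solver.tol
  print(f"iterations: {num_iters}, error: {final_error:.2e}, converged: {converged}")

Checking ``converged`` this way distinguishes a run that met the tolerance from
one that simply exhausted ``maxiter``.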