Stochastic optimization
This section is concerned with problems of the form

\[ \min_{x} \mathbb{E}_D[f(x, \theta, D)] \]
where \(f(x, \theta, D)\) is differentiable (almost everywhere), \(x\) are the parameters with respect to which the function is minimized, \(\theta\) are optional fixed extra arguments and \(D\) is a random variable (typically a mini-batch).
Defining an objective function
Objective functions must take a data argument corresponding to \(D\) above.
Example:
import jax.numpy as jnp

def ridge_reg_objective(params, l2reg, data):
  X, y = data
  residuals = jnp.dot(X, params) - y
  return jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)
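The objective can then be evaluated or differentiated like any JAX function. A minimal sketch on toy data (the shapes and values below are illustrative, not part of the API):

import jax
import jax.numpy as jnp

X = jnp.ones((100, 5))   # toy design matrix: 100 samples, 5 features
y = jnp.zeros(100)       # toy targets
init_params = jnp.zeros(5)

value = ridge_reg_objective(init_params, 0.1, (X, y))
grads = jax.grad(ridge_reg_objective)(init_params, 0.1, (X, y))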
Data iterator
Sampling realizations of the random variable \(D\) can be done using an iterator.
Example:
def data_iterator():
  for _ in range(n_iter):
    # Sample a mini-batch of indices without replacement.
    perm = rng.permutation(n_samples)[:batch_size]
    yield (X[perm], y[perm])
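This iterator relies on a few names defined outside of it. A minimal setup, assuming the toy data above (rng, n_samples, batch_size and n_iter are placeholders, not JAXopt names):

import numpy as np

rng = np.random.RandomState(0)          # NumPy RNG used for permutations
n_samples, batch_size = X.shape[0], 32  # dataset size and mini-batch size
n_iter = 1000                           # number of mini-batches to yield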
Solvers
| jaxopt.ArmijoSGD | SGD with Armijo line search. |
| jaxopt.OptaxSolver | Optax solver. |
| jaxopt.PolyakSGD | SGD with Polyak step size. |
Optax solvers
Optax solvers can be used in JAXopt via OptaxSolver. Here's an example with Adam:
import optax
from jaxopt import OptaxSolver

learning_rate = 1e-2  # placeholder value
opt = optax.adam(learning_rate)
solver = OptaxSolver(opt=opt, fun=ridge_reg_objective, maxiter=1000)
See common optimizers in the optax documentation for a list of available stochastic solvers.
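Any optimizer from that list can be swapped in the same way. For instance, with SGD plus momentum (the hyperparameter values here are placeholders):

import optax
from jaxopt import OptaxSolver

opt = optax.sgd(learning_rate=1e-2, momentum=0.9)
solver = OptaxSolver(opt=opt, fun=ridge_reg_objective, maxiter=1000)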
Adaptive solvers
Adaptive solvers dynamically update the step size at each iteration. One example is PolyakSGD, a solver that computes step sizes adaptively from function values. Another is ArmijoSGD, a solver that uses an Armijo line search.
The convergence guarantees of these two algorithms require the interpolation hypothesis: the global optimum over \(D\) must also be a global optimum for any finite sample of \(D\). This typically holds for overparametrized models (e.g., neural networks) in classification tasks with separable classes, and in regression tasks without noise.
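Both solvers follow the same construction pattern as OptaxSolver. A minimal sketch (only fun and maxiter are set here; see the API reference for the remaining options):

from jaxopt import ArmijoSGD, PolyakSGD

solver = PolyakSGD(fun=ridge_reg_objective, maxiter=1000)
# or, with an Armijo line search:
solver = ArmijoSGD(fun=ridge_reg_objective, maxiter=1000)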
Run iterator vs. manual loop
The following:
iterator = data_iterator()
params, state = solver.run_iterator(init_params, iterator, l2reg=l2reg)
is equivalent to:
iterator = data_iterator()
state = solver.init_state(init_params, l2reg=l2reg)
params = init_params
for _ in range(maxiter):
  data = next(iterator)
  params, state = solver.update(params, state, l2reg=l2reg, data=data)
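The manual loop is handy when custom logic has to run between updates, such as logging or early stopping. A sketch, assuming the solver state exposes an error field, as JAXopt solver states typically do:

iterator = data_iterator()
state = solver.init_state(init_params, l2reg=l2reg)
params = init_params
for _ in range(maxiter):
  data = next(iterator)
  params, state = solver.update(params, state, l2reg=l2reg, data=data)
  # Assumed: state.error holds the solver's stopping-criterion value.
  if state.error <= 1e-3:
    break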