jaxopt.Bisection
- class jaxopt.Bisection(optimality_fun, lower, upper, maxiter=30, tol=1e-05, check_bracket=True, verbose=False, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')
One-dimensional root finding using bisection.
- Parameters
optimality_fun (Callable)
lower (float)
upper (float)
maxiter (int)
tol (float)
check_bracket (bool)
verbose (bool)
implicit_diff_solve (Optional[Callable])
has_aux (bool)
jit (Union[str, bool])
unroll (Union[str, bool])
- optimality_fun
a function optimality_fun(x, *args, **kwargs), where x is a one-dimensional (scalar) variable. The function should have opposite signs when evaluated at lower and at upper.
- Type
Callable
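Because optimality_fun may take extra arguments, the sign condition has to hold for the specific arguments that are later passed to run. The plain-Python sketch below checks it for two values of an illustrative parameter c; brackets_root is a hypothetical helper, not part of jaxopt.

```python
def optimality_fun(x, c):
    # Illustrative function whose root is sqrt(c) for c >= 0.
    return x ** 2 - c


def brackets_root(f, lower, upper, *args, **kwargs):
    """Hypothetical helper: True when f(lower) and f(upper) have opposite signs."""
    return f(lower, *args, **kwargs) * f(upper, *args, **kwargs) < 0


print(brackets_root(optimality_fun, 0.0, 4.0, 2.0))   # True: values are -2.0 and 14.0
print(brackets_root(optimality_fun, 0.0, 4.0, 20.0))  # False: values are -20.0 and -4.0
```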
- lower
the lower end of the bracketing interval.
- Type
float
- upper
the upper end of the bracketing interval.
- Type
float
- maxiter
maximum number of iterations.
- Type
int
- tol
stopping tolerance on the optimality error (the absolute value of optimality_fun at the current iterate).
- Type
float
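To make maxiter and tol concrete, here is a plain-Python bisection loop. It assumes the stopping test is |optimality_fun(x)| <= tol and that the bracket must have strictly opposite signs; it is an illustrative sketch, not jaxopt's implementation, and bisect is a hypothetical name.

```python
def bisect(f, lower, upper, maxiter=30, tol=1e-5):
    """Plain-Python bisection sketch; assumes the stop test |f(mid)| <= tol."""
    f_low = f(lower)
    if not f_low * f(upper) < 0:
        raise ValueError("f(lower) and f(upper) must have opposite signs")
    low, high = lower, upper
    mid = 0.5 * (low + high)
    for _ in range(maxiter):
        mid = 0.5 * (low + high)
        value = f(mid)
        if abs(value) <= tol:
            break  # optimality error is within tolerance
        if (value > 0) == (f_low > 0):
            # Same sign as at `low`: the root lies in [mid, high].
            low, f_low = mid, value
        else:
            # Sign change between `low` and `mid`: the root lies in [low, mid].
            high = mid
    return mid  # last evaluated midpoint (also returned when maxiter is hit)


root = bisect(lambda x: x ** 2 - 2, 0.0, 2.0)  # close to sqrt(2)
```

Each iteration halves the bracket, so after k iterations its width is (upper - lower) / 2**k; maxiter therefore bounds the achievable accuracy even when tol is never met.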
- check_bracket
whether to check correctness of the bracketing interval. If True, the method run cannot be jitted.
- Type
bool
- implicit_diff_solve
the linear system solver to use for implicit differentiation.
- Type
Optional[Callable]
- verbose
whether to print the error at every iteration. Warning: verbose=True automatically disables jit.
- Type
bool
- jit
whether to JIT-compile the bisection loop (default: “auto”).
- Type
Union[str, bool]
- unroll
whether to unroll the bisection loop (default: “auto”).
- Type
Union[str, bool]
- __init__(optimality_fun, lower, upper, maxiter=30, tol=1e-05, check_bracket=True, verbose=False, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')
- Return type
None
Methods
__init__(optimality_fun, lower, upper[, ...])
attribute_names()
attribute_values()
init_state([init_params]): Initialize the solver state.
l2_optimality_error(params, *args, **kwargs): Computes the L2 optimality error.
run([init_params]): Runs the optimization loop.
update(params, state, *args, **kwargs): Performs one iteration of the bisection solver.
Attributes
has_aux
- init_state(init_params=None, *args, **kwargs)
Initialize the solver state.
- Parameters
init_params (Optional[Any]) – ignored; we use 0.5 * (state.high + state.low) instead.
*args – additional positional arguments to be passed to optimality_fun.
**kwargs – additional keyword arguments to be passed to optimality_fun.
- Return type
BisectionState
- Returns
state
- l2_optimality_error(params, *args, **kwargs)
Computes the L2 optimality error.
- run(init_params=None, *args, **kwargs)
Runs the optimization loop.
- Parameters
init_params (Optional[Any]) – pytree containing the initial parameters.
*args – additional positional arguments to be passed to the update method.
**kwargs – additional keyword arguments to be passed to the update method.
- Return type
OptStep
- Returns
(params, state)