jaxopt.Bisection

class jaxopt.Bisection(optimality_fun, lower, upper, maxiter=30, tol=1e-05, check_bracket=True, verbose=False, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')

One-dimensional root finding using bisection.

Parameters
  • optimality_fun (Callable)

  • lower (float)

  • upper (float)

  • maxiter (int)

  • tol (float)

  • check_bracket (bool)

  • verbose (bool)

  • implicit_diff_solve (Optional[Callable])

  • has_aux (bool)

  • jit (Union[str, bool])

  • unroll (Union[str, bool])

optimality_fun

a function optimality_fun(x, *args, **kwargs) where x is a scalar (one-dimensional) variable. The function must have opposite signs when evaluated at lower and at upper.

Type

Callable

lower

the lower end of the bracketing interval.

Type

float

upper

the upper end of the bracketing interval.

Type

float

maxiter

maximum number of iterations.

Type

int
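Each step halves the bracket, so maxiter bounds how narrow the bracket can become. The relation below is the standard bisection bound, written with this page's names. Here low_k and high_k are the bracket ends after k steps, and epsilon is a target width. The solver may stop before maxiter once tol is satisfied.

```latex
\text{high}_k - \text{low}_k = \frac{\text{upper} - \text{lower}}{2^{k}},
\qquad
k \ge \log_2 \frac{\text{upper} - \text{lower}}{\varepsilon}
\;\Longrightarrow\; \text{high}_k - \text{low}_k \le \varepsilon .
```

For instance, with lower = 1, upper = 2 and the default maxiter = 30, the bracket width after 30 steps would be 2^-30 ≈ 9.3e-10.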

tol

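tolerance on the error |optimality_fun(x)| at the current midpoint. Iteration stops once this error is at most tol, or after maxiter iterations.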

Type

float

check_bracket

whether to check that the bracketing interval is valid, i.e., that optimality_fun has opposite signs at lower and upper. If True, run cannot be jitted.

Type

bool

implicit_diff_solve

the linear system solver to use for implicit differentiation.

Type

Optional[Callable]

verbose

whether to print the error at every iteration. Warning: verbose=True automatically disables jit.

Type

bool

jit

whether to JIT-compile the bisection loop (default: “auto”).

Type

Union[str, bool]

unroll

whether to unroll the bisection loop (default: “auto”).

Type

Union[str, bool]

__init__(optimality_fun, lower, upper, maxiter=30, tol=1e-05, check_bracket=True, verbose=False, implicit_diff_solve=None, has_aux=False, jit='auto', unroll='auto')

Return type

None

Methods

__init__(optimality_fun, lower, upper[, ...])

attribute_names()

attribute_values()

init_state([init_params])

Initialize the solver state.

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

run([init_params])

Runs the optimization loop.

update(params, state, *args, **kwargs)

Performs one iteration of the bisection solver.

Attributes

check_bracket

has_aux

implicit_diff_solve

jit

maxiter

tol

unroll

verbose

optimality_fun

lower

upper

init_state(init_params=None, *args, **kwargs)

Initialize the solver state.

Parameters
  • init_params (Optional[Any]) – ignored; the initial bracket is taken from lower and upper.

  • *args – additional positional arguments to be passed to optimality_fun.

  • **kwargs – additional keyword arguments to be passed to optimality_fun.

Return type

BisectionState

Returns

state

l2_optimality_error(params, *args, **kwargs)

Computes the L2 optimality error.

run(init_params=None, *args, **kwargs)

Runs the optimization loop.

Parameters
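  • init_params (Optional[Any]) – pytree containing the initial parameters. For bisection this can be left as None, since init_state and update ignore it and work from the bracketing interval.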

  • *args – additional positional arguments to be passed to the update method.

  • **kwargs – additional keyword arguments to be passed to the update method.

Return type

OptStep

Returns

(params, state)
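To make the loop concrete, here is a hypothetical pure-Python sketch of a bisection run. It is not jaxopt's source. It checks the bracket and then repeatedly evaluates the midpoint, keeping the half that still contains a sign change. It stops once |f(midpoint)| <= tol or after maxiter steps. The stopping rule on |f| is an assumption made for this sketch.

```python
import math


def bisect(f, lower, upper, maxiter=30, tol=1e-5):
    """Hypothetical pure-Python bisection; returns (root estimate, iterations)."""
    f_lower, f_upper = f(lower), f(upper)
    if f_lower * f_upper > 0:
        raise ValueError("f(lower) and f(upper) must have opposite signs")
    # +1 if f goes from negative to positive across the bracket, -1 otherwise.
    sign = 1.0 if f_lower < 0 else -1.0
    low, high = lower, upper
    x = 0.5 * (low + high)
    for it in range(1, maxiter + 1):
        x = 0.5 * (low + high)  # midpoint of the current bracket
        fx = f(x)
        if sign * fx > 0:
            high = x  # x lies past the root: keep the left half
        else:
            low = x   # x lies before the root: keep the right half
        if abs(fx) <= tol:
            return x, it
    return x, maxiter


x, n = bisect(lambda t: t ** 3 - t - 2, 1.0, 2.0)
y, _ = bisect(math.cos, 0.0, 3.0)  # decreasing function, root at pi / 2
print(x, n, y)
```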

update(params, state, *args, **kwargs)

Performs one iteration of the bisection solver.

Parameters
  • params – ignored; the midpoint 0.5 * (state.high + state.low) is used instead.

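  • state (NamedTuple) – named tuple containing the solver state.

  • *args – additional positional arguments to be passed to optimality_fun.

  • **kwargs – additional keyword arguments to be passed to optimality_fun.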

Return type

OptStep

Returns

(params, state)