jaxmat.tensors.utils module
- safe_fun(fun, x, norm=None, eps=1e-16)
Apply a function safely, avoiding evaluation at or near zero.
The input x is replaced by a small positive sentinel eps whenever norm(x) <= eps before calling fun. The final result is then masked to return zero in that case. This sentinel strategy ensures that fun is always evaluated at a numerically safe point, so that gradients through fun remain finite under automatic differentiation.
This is consistent with safe_sqrt(): safe_fun(jnp.sqrt, x) produces the same values and gradients as safe_sqrt(x).
- Parameters:
fun (callable) – Scalar or array function to apply safely.
x (array_like) – Input value.
norm (callable, optional) – Scalar-valued function of x used to test proximity to zero. Defaults to the identity (i.e. x itself is the magnitude).
eps (float, optional) – Threshold at or below which norm(x) is treated as zero. Defaults to 1e-16.
- Returns:
fun(x) where norm(x) > eps, otherwise 0.
- Return type:
jax.Array
Notes
The key property is that the sentinel-substituted input eps (not 0) is passed to fun in the masked branch. This prevents jax.grad from encountering undefined derivatives (e.g. 1 / (2 sqrt(0)) for fun = jnp.sqrt).
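The sentinel strategy described above can be sketched as follows. This is a minimal illustration written from the description in this docstring, not the jaxmat source; the names safe_fun_sketch and naive_masked_sqrt are hypothetical. It contrasts masking only the output, which still differentiates jnp.sqrt at zero, with substituting the input first.

```python
import jax
import jax.numpy as jnp


def safe_fun_sketch(fun, x, norm=None, eps=1e-16):
    """Hypothetical re-implementation of the sentinel strategy (not the jaxmat source)."""
    # Magnitude used for the proximity test; the identity when no norm is given.
    magnitude = x if norm is None else norm(x)
    is_safe = magnitude > eps
    # First where: replace unsafe inputs by the sentinel eps before calling fun.
    x_safe = jnp.where(is_safe, x, eps)
    # Second where: mask the result back to zero for unsafe inputs.
    return jnp.where(is_safe, fun(x_safe), 0.0)


def naive_masked_sqrt(x):
    # Masks only the output: jnp.sqrt is still differentiated at x = 0,
    # so the backward pass multiplies a zero cotangent by an infinite derivative.
    return jnp.where(x > 0.0, jnp.sqrt(x), 0.0)


grad_naive = jax.grad(naive_masked_sqrt)(0.0)  # NaN (0 * inf)
grad_safe = jax.grad(lambda x: safe_fun_sketch(jnp.sqrt, x))(0.0)  # finite
value_safe = safe_fun_sketch(jnp.sqrt, 4.0)  # ordinary sqrt away from zero
```

Both jnp.where calls are needed: the inner one keeps fun away from the singular point, and the outer one restores the intended value of zero.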
- safe_sqrt(x, eps=1e-16)
Computes a numerically safe square root.
Ensures the argument to the square root is greater than eps to avoid taking the square root of zero or negative values, which could cause instability or NaNs.
- Parameters:
x (array-like) – Input array or tensor.
eps (float, optional) – Minimum threshold for x before taking the square root. Defaults to 1e-16.
- Returns:
The square root of x for x > eps, otherwise eps.
- Return type:
array-like
- safe_norm(x, eps=1e-16, **kwargs)
Wrapper around optax.safe_norm that computes a numerically stable norm.
This function prevents numerical instability when computing vector norms for small magnitudes by internally applying a stability threshold.
- Parameters:
x (array-like) – Input vector or tensor.
eps (float, optional) – Small constant added for numerical stability. Defaults to 1e-16.
**kwargs – Additional arguments passed to optax.safe_norm.
- Returns:
The numerically stable norm of x.
- Return type:
array-like
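To illustrate why a stability threshold matters for norms, the sketch below compares a plain Euclidean norm with a thresholded one at the zero vector. It is written in pure JAX and is not the optax or jaxmat implementation; clamped_norm_sketch is a hypothetical helper, and it assumes that eps acts as a lower bound on the returned norm.

```python
import jax
import jax.numpy as jnp


def naive_norm(x):
    # Plain Euclidean norm: sqrt is differentiated at 0 for a zero vector.
    return jnp.sqrt(jnp.sum(x**2))


def clamped_norm_sketch(x, eps=1e-16):
    """Hypothetical helper: Euclidean norm bounded below by eps (an assumption)."""
    squared = jnp.sum(x**2)
    is_safe = squared > eps**2
    # Evaluate sqrt at 1.0 (an arbitrary safe value) when the norm is too small.
    safe_squared = jnp.where(is_safe, squared, 1.0)
    return jnp.where(is_safe, jnp.sqrt(safe_squared), eps)


zero_vector = jnp.zeros(3)
grad_naive_norm = jax.grad(naive_norm)(zero_vector)  # every entry is NaN
grad_clamped = jax.grad(clamped_norm_sketch)(zero_vector)  # every entry is 0
norm_34 = clamped_norm_sketch(jnp.array([3.0, 4.0]))  # unaffected away from zero
```

Away from the threshold the two functions agree, so the thresholding only changes behaviour where the plain norm is not differentiable.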
- FischerBurmeister(x, y)
Computes the scalar Fischer-Burmeister function.
The Fischer-Burmeister function is defined as:
\[\Phi(x, y) = x + y - \sqrt{x^2 + y^2}\]
and is commonly used in complementarity problem formulations to provide a semi-smooth reformulation of the complementarity conditions
\[x \geq 0, \quad y \geq 0, \quad xy = 0.\]
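As a quick check of the reformulation, the sketch below evaluates the formula directly at a few points. The function fischer_burmeister_sketch is a hypothetical stand-in written from the definition above, not the jaxmat implementation, which may treat the square root differently near the origin. The function vanishes exactly when the complementarity conditions hold and is nonzero otherwise.

```python
import jax.numpy as jnp


def fischer_burmeister_sketch(x, y):
    # Phi(x, y) = x + y - sqrt(x^2 + y^2), written directly from the formula.
    return x + y - jnp.sqrt(x**2 + y**2)


# Complementary pair (x = 0, y >= 0): Phi vanishes.
phi_complementary = fischer_burmeister_sketch(0.0, 3.0)
# Both strictly positive (xy != 0): Phi is positive.
phi_both_positive = fischer_burmeister_sketch(1.0, 1.0)
# Negative argument (x < 0): Phi is negative.
phi_negative = fischer_burmeister_sketch(-1.0, 0.0)
```

Solving \(\Phi(x, y) = 0\) therefore replaces the three inequality and equality conditions with a single equation, at the cost of non-differentiability at the origin.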