Source code for jaxmat.tensors.utils
import jax.numpy as jnp
import optax
[docs]
def safe_fun(fun, x, norm=None, eps=1e-16):
    r"""
    Apply a function safely, avoiding evaluation at or near zero.

    The input ``x`` is replaced by a small positive sentinel ``eps`` whenever
    ``norm(x) <= eps`` before calling ``fun``. The final result is then
    masked to return exactly zero in that case. This sentinel strategy
    ensures that ``fun`` is never evaluated at the (near-)zero input itself,
    which keeps gradients through ``fun`` finite under automatic
    differentiation as long as ``fun`` is differentiable at ``eps``.

    Parameters
    ----------
    fun : callable
        Scalar or array function to apply safely.
    x : array_like
        Input value.
    norm : callable, optional
        Scalar-valued function of ``x`` used to test proximity to zero.
        Defaults to the identity (i.e. ``x`` itself is the magnitude), in
        which case negative entries are also treated as zero.
    eps : float, optional
        Threshold below which ``x`` is considered zero. Defaults to ``1e-16``.

    Returns
    -------
    jax.Array
        ``fun(x)`` where ``norm(x) > eps``, otherwise ``0``.

    Notes
    -----
    The key property is that the *sentinel-substituted* input ``eps`` (not
    ``0``) is passed to ``fun`` in the masked branch. This prevents
    ``jax.grad`` from encountering undefined derivatives (e.g.
    ``1 / (2 sqrt(0))`` for ``fun = jnp.sqrt``).

    The masked value is a literal zero rather than ``0 * fun(eps)``, so the
    result stays ``0`` even when ``fun(eps)`` itself overflows to ``inf`` or
    evaluates to ``nan``.

    Relation to :func:`safe_sqrt`: ``safe_fun(jnp.sqrt, x)`` has the same
    gradients as ``safe_sqrt(x)`` and the same values for ``x > eps``, but in
    the masked region ``safe_fun`` returns ``0`` whereas ``safe_sqrt`` returns
    ``eps``.
    """
    # Magnitude used for the zero test: x itself unless a norm is supplied.
    magnitude = x if norm is None else norm(x)
    is_nonzero = magnitude > eps

    # Use eps as sentinel (not 0) so fun is always evaluated at a safe point.
    safe_x = jnp.where(is_nonzero, x, eps)

    # Evaluate fun once and reuse the result for the final selection.
    value = fun(safe_x)

    # Select a literal 0 in the masked branch: 0 * value would turn an
    # inf/nan produced at the sentinel into nan.
    return jnp.where(is_nonzero, value, 0)
[docs]
def safe_sqrt(x, eps=1e-16):
    """
    Compute a numerically safe square root.

    Entries of ``x`` that do not exceed ``eps`` are swapped for ``eps``
    before the square root is taken, so ``jnp.sqrt`` is never evaluated at
    zero or at a negative value (which could yield NaNs or infinite
    gradients). Those entries are then reported as ``eps`` in the output.

    Parameters
    ----------
    x : array-like
        Input array or tensor.
    eps : float, optional
        Threshold that ``x`` must exceed for its square root to be returned.
        Defaults to ``1e-16``.

    Returns
    -------
    array-like
        ``sqrt(x)`` where ``x > eps``, otherwise ``eps``.
    """
    above_threshold = x > eps
    # Substitute eps first so sqrt only ever sees strictly positive input.
    root = jnp.sqrt(jnp.where(above_threshold, x, eps))
    return jnp.where(above_threshold, root, eps)
[docs]
def safe_norm(x, eps=1e-16, **kwargs):
    """
    Compute a numerically stable norm by delegating to ``optax.safe_norm``.

    This is a thin wrapper: ``x`` and ``eps`` are forwarded as the first two
    positional arguments of ``optax.safe_norm`` and any extra keyword
    arguments are passed through unchanged.

    Parameters
    ----------
    x : array-like
        Input vector or tensor.
    eps : float, optional
        Stability threshold handed to ``optax.safe_norm`` as its second
        positional argument. Defaults to ``1e-16``.
    **kwargs:
        Additional arguments passed to ``optax.safe_norm``.

    Returns
    -------
    array-like
        The value returned by ``optax.safe_norm`` for ``x``.
    """
    stable_norm = optax.safe_norm(x, eps, **kwargs)
    return stable_norm
[docs]
def FischerBurmeister(x, y):
    r"""
    Evaluate the scalar Fischer-Burmeister function.

    The Fischer-Burmeister function is defined as:

    $$\Phi(x, y) = x + y - \sqrt{x^2 + y^2}$$

    It is commonly used in complementarity problem formulations, where it
    gives a semi-smooth reformulation of the complementarity conditions

    $$x \geq 0, y \geq 0, xy = 0$$.

    The square root is evaluated through :func:`safe_sqrt` rather than a
    plain square root.

    Parameters
    ----------
    x, y : array-like
        The two arguments of the complementarity pair.

    Returns
    -------
    array-like
        The value ``x + y - safe_sqrt(x**2 + y**2)``.
    """
    radius = safe_sqrt(x**2 + y**2)
    return x + y - radius