jaxmat.tensors package#
Submodules#
- jaxmat.tensors.generic_tensors module
- jaxmat.tensors.linear_algebra module
- jaxmat.tensors.mappings module
- jaxmat.tensors.symmetry_classes module
- jaxmat.tensors.tensor_utils module
- jaxmat.tensors.utils module
Module contents#
jaxmat.tensors — tensor algebra for solid mechanics.
- class Tensor#
Bases: Module
Empty marker base class for all jaxmat tensor objects.
Provides a single type to test against with isinstance(x, Tensor), without coupling to any specific rank or symmetry class.
Both rank-2 classes (Tensor2, SymmetricTensor2) and rank-4 classes (Tensor4, SymmetricTensor4, and the symmetry-reduced subclasses) inherit from this marker.
- class SymmetricTensor2#
Bases: Tensor2
Symmetric second-rank tensor in 3-D.
Stored as a 6-component Kelvin-Mandel array
\[\{T\} = [T_{11},\, T_{22},\, T_{33},\, \sqrt{2}\,T_{12},\, \sqrt{2}\,T_{13},\, \sqrt{2}\,T_{23}]\]
The \(\sqrt{2}\) scaling makes the Kelvin basis orthonormal, so that the dot product of two Kelvin vectors equals the double contraction of the corresponding tensors: \(\{T\} \cdot \{S\} = \mathbf{T} : \mathbf{S}\).
- Parameters:
tensor (array_like, shape (..., 3, 3), optional) – Dense symmetric tensor.
array (array_like, shape (..., 6), optional) – Pre-built Kelvin array.
Notes
The @ operator returns Tensor2 because the product of two symmetric matrices is not symmetric in general.
double_contract() operates directly on the Kelvin arrays (dot product), with no dense intermediate.
Batch dimensions and jacfwd compatibility: equinox may set _array directly to the tangent module during forward-mode AD, bypassing __init__. The tensor property uses _raw_array() to unwrap the value safely in that case.
- property T: SymmetricTensor2#
Transpose: returns self since \(\mathbf{T} = \mathbf{T}^{\mathsf{T}}\).
- Return type:
SymmetricTensor2
- base_array_shape = (6,)#
- double_contract(other)[source]#
Double contraction \(\mathbf{T} : \mathbf{S} = T_{ij} S_{ij}\).
When other is also a SymmetricTensor2, the contraction reduces to a Kelvin dot product \(\{T\} \cdot \{S\}\) with no dense intermediate.
- Parameters:
other (SymmetricTensor2 or array_like, shape (..., 3, 3))
- Returns:
Scalar (or batch of scalars).
- Return type:
jax.Array
- property tensor: Array#
Dense symmetric (3, 3) tensor.
Reconstructed by scattering the Kelvin components into the upper triangle and symmetrising. _raw_array() is applied to _array first to handle the case where equinox sets _array to the tangent module object during jacfwd (bypassing __init__).
- Returns:
Shape (..., 3, 3).
- Return type:
jax.Array
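The Kelvin-Mandel storage described for SymmetricTensor2 can be checked numerically with a small standalone sketch. It uses plain NumPy rather than jaxmat, and the helper names `to_kelvin` and `from_kelvin` are hypothetical stand-ins, not part of the library.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)

def to_kelvin(T):
    """Map a symmetric (3, 3) array to [T11, T22, T33, sqrt2*T12, sqrt2*T13, sqrt2*T23]."""
    return np.array([T[0, 0], T[1, 1], T[2, 2],
                     SQRT2 * T[0, 1], SQRT2 * T[0, 2], SQRT2 * T[1, 2]])

def from_kelvin(k):
    """Inverse map: rebuild the dense symmetric (3, 3) array."""
    t12, t13, t23 = k[3] / SQRT2, k[4] / SQRT2, k[5] / SQRT2
    return np.array([[k[0], t12, t13],
                     [t12, k[1], t23],
                     [t13, t23, k[2]]])

T = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [3.0, 5.0, 6.0]])
S = np.array([[0.5, -1.0, 0.0], [-1.0, 2.0, 1.5], [0.0, 1.5, -3.0]])

# The Kelvin dot product reproduces the double contraction T_ij S_ij.
kelvin_dot = to_kelvin(T) @ to_kelvin(S)
dense_contraction = np.sum(T * S)
```

Without the \(\sqrt{2}\) factors the off-diagonal terms would be counted once instead of twice, and the two numbers would differ.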
- class Tensor2#
Bases: Tensor
Full (non-symmetric) second-rank tensor in 3-D.
Stored as a 9-component Kelvin array in the ordering \([T_{11}, T_{22}, T_{33}, T_{12}, T_{21}, T_{13}, T_{31}, T_{23}, T_{32}]\).
- Parameters:
tensor (array_like, shape (..., 3, 3), optional) – Dense tensor representation.
array (array_like, shape (..., 9), optional) – Pre-built Kelvin array. Passing another Tensor2 is accepted; its _array field is used directly.
Notes
At most one of tensor or array may be provided. If neither is given, the tensor is initialized to zero.
Batch dimensions precede the storage dimension: a batch of \(N\) tensors has array.shape == (N, 9) and tensor.shape == (N, 3, 3).
The @ operator performs dense matrix composition \((\mathbf{T} \cdot \mathbf{S})_{ik} = T_{ij} S_{jk}\) and always returns a Tensor2, regardless of the symmetry of the operands. Use double_contract() for \(\mathbf{T}:\mathbf{S}\) and matvec() for \(\mathbf{T} \cdot \mathbf{v}\).
- property T: Tensor2#
Transpose \(T^{\mathsf{T}}_{ij} = T_{ji}\).
Computed via a single gather on the Kelvin array; no dense intermediate is constructed.
- Return type:
Tensor2
- property array: Array#
9-component Kelvin array.
- Returns:
Shape (..., 9).
- Return type:
jax.Array
- array_rank = 1#
- base_array_shape = (9,)#
- base_tensor_shape = (3, 3)#
- property batch_shape: tuple#
Leading batch dimensions (...).
- property det: Array#
Determinant \(\det(\mathbf{T})\).
- Return type:
jax.Array
- dim = 3#
- double_contract(other)[source]#
Double contraction \(\mathbf{T} : \mathbf{S} = T_{ij} S_{ij}\).
- Parameters:
other (Tensor2 or array_like, shape (..., 3, 3))
- Returns:
Scalar (or batch of scalars).
- Return type:
jax.Array
- matvec(v)[source]#
Matrix-vector product \((\mathbf{T} \cdot \mathbf{v})_i = T_{ij} v_j\).
- Parameters:
v (array_like, shape (..., 3))
- Returns:
Shape (..., 3).
- Return type:
jax.Array
- rank = 2#
- rotate(R)[source]#
Rotate the tensor: \(\mathbf{R} \mathbf{T} \mathbf{R}^{\mathsf{T}}\).
- Parameters:
R (array_like, shape (3, 3)) – Orthogonal rotation matrix.
- Return type:
- property shape: tuple#
Shape of the underlying Kelvin array (..., 9).
- property skw: Tensor2#
Skew-symmetric part \((\mathbf{T} - \mathbf{T}^{\mathsf{T}}) / 2\).
- Return type:
Tensor2
- property sym: SymmetricTensor2#
Symmetric part \((\mathbf{T} + \mathbf{T}^{\mathsf{T}}) / 2\).
- Return type:
SymmetricTensor2
- property tensor: Array#
Dense tensor representation.
Reconstructed via a single gather on _array followed by a reshape, with no scatter and no zero allocation.
- Returns:
Shape (..., 3, 3).
- Return type:
jax.Array
- property tensor_shape: tuple#
Shape of the dense tensor (..., 3, 3).
- property tr: Array#
Trace \(\mathrm{tr}(\mathbf{T}) = T_{ii}\).
- Return type:
jax.Array
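The 9-component ordering and the gather-based transpose described for Tensor2 can be illustrated with index arrays. This is a NumPy sketch under the stated ordering only. The names `ORDER`, `to_flat`, `to_dense` and `transpose_flat` are hypothetical and do not mirror jaxmat internals.

```python
import numpy as np

# (row, col) of each stored component: [T11, T22, T33, T12, T21, T13, T31, T23, T32]
ORDER = [(0, 0), (1, 1), (2, 2), (0, 1), (1, 0), (0, 2), (2, 0), (1, 2), (2, 1)]
ROWS = np.array([i for i, _ in ORDER])
COLS = np.array([j for _, j in ORDER])

# Position of every dense entry (i, j) inside the flat 9-array.
DENSE_TO_FLAT = np.empty((3, 3), dtype=int)
for pos, (i, j) in enumerate(ORDER):
    DENSE_TO_FLAT[i, j] = pos

# Transposition swaps each (ij, ji) pair and leaves the diagonal untouched.
TRANSPOSE_PERM = np.array([0, 1, 2, 4, 3, 6, 5, 8, 7])

def to_flat(T):
    """Dense (..., 3, 3) -> flat (..., 9) by a single gather."""
    return T[..., ROWS, COLS]

def to_dense(a):
    """Flat (..., 9) -> dense (..., 3, 3) by a single gather."""
    return a[..., DENSE_TO_FLAT]

def transpose_flat(a):
    """Transpose directly on the flat array, without building the dense form."""
    return a[..., TRANSPOSE_PERM]

T_batch = np.arange(18.0).reshape(2, 3, 3)   # batch of N = 2 tensors
flat = to_flat(T_batch)                      # shape (2, 9)
```

Batch dimensions pass through untouched because every gather acts on the trailing axes only.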
- class SymmetricTensor4#
Bases: _AbstractTensor4
Fourth-rank tensor with minor symmetries \(C_{ijkl} = C_{jikl} = C_{ijlk} = C_{jilk}\).
Stored as a \((6 \times 6)\) Kelvin-Mandel matrix. The Kelvin scaling ensures that the double contraction \(\mathbb{C} : \boldsymbol{\varepsilon}\) is a plain matrix-vector product on the Kelvin arrays.
- Parameters:
tensor (array_like, shape (..., 3, 3, 3, 3), optional) – Dense tensor with minor symmetries.
array (array_like, shape (..., 6, 6), optional) – Pre-built Kelvin matrix.
Notes
The @ operator performs Kelvin-space products:
@ SymmetricTensor2 → \((6,6) \cdot (6,)\) → SymmetricTensor2.
@ SymmetricTensor4 → \((6,6) \cdot (6,6)\) → SymmetricTensor4.
Use to_symmetric() on _AbstractTensor4 subclasses to materialise them into this form before mixed-class operations.
- base_array_shape = (6, 6)#
- classmethod identity()[source]#
Symmetric fourth-rank identity \(\mathbb{I}^s_{ijkl} = \tfrac{1}{2}(\delta_{ik}\delta_{jl} + \delta_{il}\delta_{jk})\).
- Return type:
- property tensor: Array#
Dense (3, 3, 3, 3) tensor with minor symmetries enforced.
Reconstructed via a 2-D gather on _array / _S4_W.
- Returns:
Shape (..., 3, 3, 3, 3).
- Return type:
jax.Array
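The statement that \(\mathbb{C} : \boldsymbol{\varepsilon}\) becomes a plain matrix-vector product in Kelvin-Mandel form can be verified with a short NumPy sketch. The helpers `sym2_to_kelvin` and `sym4_to_kelvin` and the material constants are illustrative assumptions. The component ordering follows the one given for SymmetricTensor2.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)
PAIRS = [(0, 0), (1, 1), (2, 2), (0, 1), (0, 2), (1, 2)]   # Kelvin index -> (i, j)
WEIGHTS = np.array([1.0, 1.0, 1.0, SQRT2, SQRT2, SQRT2])

def sym2_to_kelvin(T):
    """Symmetric (3, 3) array -> Kelvin-Mandel 6-vector."""
    return np.array([w * T[i, j] for w, (i, j) in zip(WEIGHTS, PAIRS)])

def sym4_to_kelvin(C):
    """Minor-symmetric (3, 3, 3, 3) array -> Kelvin-Mandel (6, 6) matrix."""
    M = np.empty((6, 6))
    for a, (i, j) in enumerate(PAIRS):
        for b, (k, l) in enumerate(PAIRS):
            M[a, b] = WEIGHTS[a] * WEIGHTS[b] * C[i, j, k, l]
    return M

# Isotropic stiffness built from Lame constants (illustrative values).
lam, mu = 100.0, 40.0
d = np.eye(3)
C = (lam * np.einsum("ij,kl->ijkl", d, d)
     + mu * (np.einsum("ik,jl->ijkl", d, d) + np.einsum("il,jk->ijkl", d, d)))

eps = np.array([[0.010, 0.002, 0.000],
                [0.002, -0.003, 0.001],
                [0.000, 0.001, 0.005]])

sigma_dense = np.einsum("ijkl,kl->ij", C, eps)            # dense double contraction
sigma_kelvin = sym4_to_kelvin(C) @ sym2_to_kelvin(eps)    # (6, 6) @ (6,)
```

The product of the two weights in `sym4_to_kelvin` is what lets each off-diagonal strain component, stored once, stand for both of its symmetric dense entries.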
- class Tensor4#
Bases: _AbstractTensor4
Full (non-minor-symmetric) fourth-rank tensor in 3-D.
Stored as a \((9 \times 9)\) Kelvin matrix corresponding to the full double-index ordering \([T_{11}, T_{22}, T_{33}, T_{12}, T_{21}, T_{13}, T_{31}, T_{23}, T_{32}]\).
- Parameters:
tensor (array_like, shape (..., 3, 3, 3, 3), optional) – Dense tensor representation.
array (array_like, shape (..., 9, 9), optional) – Pre-built Kelvin matrix.
Notes
The @ operator denotes double contraction:
- base_array_shape = (9, 9)#
- classmethod identity()[source]#
Fourth-rank identity \(\mathbb{I}_{ijkl} = \delta_{ik}\delta_{jl}\).
- Return type:
- property tensor: Array#
Dense (3, 3, 3, 3) tensor via a 2-D gather on the Kelvin matrix.
- Returns:
Shape (..., 3, 3, 3, 3).
- Return type:
jax.Array
- class CubicTensor4#
Bases: AbstractStructuredTensor4
Cubic-symmetric fourth-rank tensor.
Represented in the basis of three mutually orthogonal projectors:
\[\mathbb{C} = 3\kappa\,\mathbb{J} + 2\mu_a\,\mathbb{K}_a + 2\mu_b\,\mathbb{K}_b\]
where \(\mathbb{J}\) is the volumetric projector, \(\mathbb{K}_a\) projects onto the diagonal deviatoric part, and \(\mathbb{K}_b\) projects onto the off-diagonal shear part.
- Parameters:
coeffs (array_like, shape (..., 3), optional) – Direct basis coefficients \([3\kappa, 2\mu_a, 2\mu_b]\).
kappa (float or array_like, optional) – Cubic bulk modulus \(\kappa\).
mua (float or array_like, optional) – Diagonal deviatoric modulus \(\mu_a\).
mub (float or array_like, optional) – Off-diagonal shear modulus \(\mu_b\).
Notes
When \(\mu_a = \mu_b = \mu\) the tensor reduces to the isotropic case \(\mathbb{C} = 3\kappa\mathbb{J} + 2\mu\mathbb{K}\). Coefficients are stored as \(\{3\kappa, 2\mu_a, 2\mu_b\}\).
- J = SymmetricTensor4(shape=(6, 6))#
- Ka = SymmetricTensor4(shape=(6, 6))#
- Kb = SymmetricTensor4(shape=(6, 6))#
- property inv: CubicTensor4#
Inverse \(\mathbb{C}^{-1}\).
Because the three projectors are mutually orthogonal, the inverse is coefficient-wise: \(c_\alpha^{-1}\) for each \(\alpha\).
- Return type:
CubicTensor4
- classmethod project(C)[source]#
Project a SymmetricTensor4 onto the cubic subspace.
- Parameters:
C (SymmetricTensor4)
- Return type:
- class IsotropicTensor4#
Bases: AbstractStructuredTensor4
Isotropic fourth-rank tensor.
Parameterised by bulk and shear moduli and expressed in the orthogonal projector basis \((\mathbb{J}, \mathbb{K})\):
\[\mathbb{C} = 3\kappa\,\mathbb{J} + 2\mu\,\mathbb{K}\]
- Parameters:
coeffs (array_like, shape (..., 2), optional) – Direct basis coefficients \([3\kappa, 2\mu]\).
kappa (float or array_like, optional) – Bulk modulus \(\kappa\).
mu (float or array_like, optional) – Shear modulus \(\mu\).
Notes
Exactly one of coeffs or the pair (kappa, mu) must be provided. Coefficients are stored as \(\{3\kappa, 2\mu\}\) so that the projector expansion \(c_1\mathbb{J}+c_2\mathbb{K}\) uses the standard elasticity identities \(\mathbb{J}:\mathbf{A} = \tfrac{1}{3}\operatorname{tr}(\mathbf{A})\mathbf{I}\) and \(\mathbb{K}:\mathbf{A} = \operatorname{dev}(\mathbf{A})\).
Batched use (one tensor per material point) is supported by passing array-valued kappa and mu of shape (...,).
- J = SymmetricTensor4(shape=(6, 6))#
- K = SymmetricTensor4(shape=(6, 6))#
- property inv: IsotropicTensor4#
Inverse \(\mathbb{C}^{-1}\).
Because \(\mathbb{J}\) and \(\mathbb{K}\) are orthogonal projectors, the inverse is obtained by inverting each coefficient: \(\mathbb{C}^{-1} = \tfrac{1}{3\kappa}\mathbb{J} + \tfrac{1}{2\mu}\mathbb{K}\).
- Return type:
IsotropicTensor4
- property kappa: Array#
Bulk modulus \(\kappa = c_1 / 3\).
- property mu: Array#
Shear modulus \(\mu = c_2 / 2\).
- classmethod project(C)[source]#
Project a SymmetricTensor4 onto the isotropic subspace.
Extracts the isotropic part by computing \(\kappa = \tfrac{1}{3}\mathbb{C}::\mathbb{J}\) and \(\mu = \tfrac{1}{10}\mathbb{C}::\mathbb{K}\).
- Parameters:
C (SymmetricTensor4)
- Return type:
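The coefficient-wise inverse and the projection formulas for IsotropicTensor4 can be reproduced on the Kelvin-Mandel (6, 6) matrices with plain NumPy. This is a sketch, not the jaxmat implementation, and the moduli are illustrative. It relies on the Kelvin-Mandel map being an isometry, so the quadruple contraction \(\mathbb{C}::\mathbb{P}\) equals the Frobenius inner product of the two matrices.

```python
import numpy as np

# Kelvin-Mandel matrices of the isotropic projectors.
one = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])   # Kelvin vector of the identity tensor
J = np.outer(one, one) / 3.0                     # volumetric projector
K = np.eye(6) - J                                # deviatoric projector (eye(6) is I^s)

kappa, mu = 160.0, 80.0                          # illustrative moduli
C = 3.0 * kappa * J + 2.0 * mu * K

# Inverse: invert each coefficient, keep the projectors.
C_inv = J / (3.0 * kappa) + K / (2.0 * mu)

# Projection: J::J = 1 and K::K = 5, hence the factors 1/3 and 1/10.
kappa_proj = np.sum(C * J) / 3.0
mu_proj = np.sum(C * K) / 10.0
```

The factor 1/10 comes from \(\mathbb{K}::\mathbb{K} = 5\), the dimension of the deviatoric subspace, combined with the stored coefficient \(2\mu\).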
- class TransverseIsotropicTensor4#
Bases: AbstractStructuredTensor4
Transversely isotropic fourth-rank tensor in the Walpole basis.
Defined with respect to a unit symmetry axis \(\hat{\mathbf{a}}\) and expanded in the six Walpole basis tensors \(\mathbb{E}_1, \mathbb{E}_2, \mathbb{E}_3, \mathbb{E}_4, \mathbb{F}, \mathbb{G}\):
\[\mathbb{C} = \sum_{\alpha=1}^6 c_\alpha \mathbb{P}_\alpha\]
- Parameters:
axis (array_like, shape (3,)) – Unit symmetry axis \(\hat{\mathbf{a}}\). Need not be pre-normalised.
coeffs (array_like, shape (..., 6)) – Walpole basis coefficients \([c_1, c_2, c_3, c_4, c_5, c_6]\).
Notes
Inversion uses the \(2\times 2\) block structure of \(\mathbb{E}_1\ldots\mathbb{E}_4\) and scalar inversion for \(\mathbb{F}\) and \(\mathbb{G}\).
- property inv: TransverseIsotropicTensor4#
Inverse operator within the transverse-isotropic subspace.
The sub-block \([c_1, c_3; c_4, c_2]\) corresponding to \(\mathbb{E}_1\ldots\mathbb{E}_4\) is inverted as a \(2\times 2\) matrix, while \(c_5\) and \(c_6\) (for \(\mathbb{F}\) and \(\mathbb{G}\)) are inverted individually.
- Return type:
TransverseIsotropicTensor4
- classmethod project(axis, C)[source]#
Project a SymmetricTensor4 onto the transverse-isotropic subspace.
- Parameters:
axis (array_like, shape (3,)) – Unit symmetry axis \(\hat{\mathbf{a}}\).
C (SymmetricTensor4)
- Return type:
- axis: jax.Array#
- cubic_projectors()#
Construct cubic-symmetry fourth-rank projectors.
- Returns:
J (SymmetricTensor4) – Volumetric projector.
Ka (SymmetricTensor4) – Diagonal deviatoric projector (cubic anisotropic part of the diagonal).
Kb (SymmetricTensor4) – Off-diagonal shear projector.
Notes
The three projectors are mutually orthogonal and partition the identity: \(\mathbb{J}+\mathbb{K}_a+\mathbb{K}_b = \mathbb{I}^s\).
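A minimal NumPy sketch of the three cubic projectors in Kelvin-Mandel form, assuming the cube axes coincide with the coordinate axes. The explicit matrices are an illustration consistent with the properties stated above, not a copy of the jaxmat code, and the moduli are made-up values.

```python
import numpy as np

one = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
J = np.outer(one, one) / 3.0                          # volumetric part
Ka = np.diag([1.0, 1.0, 1.0, 0.0, 0.0, 0.0]) - J      # diagonal deviatoric part
Kb = np.diag([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])          # off-diagonal shear part

projectors = [J, Ka, Kb]

# Cubic stiffness and its coefficient-wise inverse (illustrative moduli).
kappa, mua, mub = 170.0, 50.0, 75.0
coeffs = np.array([3.0 * kappa, 2.0 * mua, 2.0 * mub])
C = sum(c * P for c, P in zip(coeffs, projectors))
C_inv = sum(P / c for c, P in zip(coeffs, projectors))

# Ranks of the three subspaces: 1 + 2 + 3 = 6.
ranks = [int(round(np.trace(P))) for P in projectors]
```

Setting `mua == mub` makes `Ka + Kb` act as the single deviatoric projector, which recovers the isotropic case.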
- isotropic_projectors()#
Construct isotropic fourth-rank projectors.
- Returns:
J (SymmetricTensor4) – Volumetric projector \(\mathbb{J}_{ijkl} = \frac{1}{3}\delta_{ij}\delta_{kl}\).
K (SymmetricTensor4) – Deviatoric projector \(\mathbb{K} = \mathbb{I}^s - \mathbb{J}\).
Notes
The projectors satisfy \(\mathbb{J}:\mathbb{J}=\mathbb{J}\), \(\mathbb{K}:\mathbb{K}=\mathbb{K}\), \(\mathbb{J}:\mathbb{K}=0\), and \(\mathbb{J}+\mathbb{K}=\mathbb{I}^s\).
- transverse_isotropic_projectors(axis)#
Construct transverse-isotropic (Walpole) fourth-rank projectors.
- Parameters:
axis (array_like, shape (3,)) – Unit symmetry axis \(\hat{\mathbf{a}}\).
- Returns:
E1, E2, E3, E4, F, G – Six Walpole basis tensors spanning the transverse-isotropic subspace. E1 and E2 are projectors; E3, E4 are their cross terms; F and G are the remaining orthogonal complements.
- Return type:
Notes
The Walpole basis is not orthogonal in the usual sense: composition follows a \(2\times 2\) block rule for E1..E4 and scalar inversion for F and G. See TransverseIsotropicTensor4 for the inversion formula.
- axl(A)[source]#
Axial vector of a skew-symmetric tensor.
The axial vector \(\mathbf{w}\) associated with \(\mathbf{W} = \operatorname{skw}(\mathbf{A})\) satisfies \(\mathbf{W}\mathbf{v} = \mathbf{w} \times \mathbf{v}\) and is given by \(w_i = -\tfrac{1}{2}\,\varepsilon_{ijk}\,W_{jk}\).
- Parameters:
A (Tensor2)
- Returns:
Shape (..., 3).
- Return type:
jax.Array
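The defining property \(\mathbf{W}\mathbf{v} = \mathbf{w} \times \mathbf{v}\) and the sign in \(w_i = -\tfrac{1}{2}\varepsilon_{ijk}W_{jk}\) can be checked directly. The sketch uses NumPy on dense (3, 3) arrays. `axl_sketch` is a hypothetical stand-in that takes a plain array, not a Tensor2.

```python
import numpy as np

def levi_civita():
    """Permutation symbol eps_ijk as a (3, 3, 3) array."""
    eps = np.zeros((3, 3, 3))
    eps[0, 1, 2] = eps[1, 2, 0] = eps[2, 0, 1] = 1.0
    eps[0, 2, 1] = eps[2, 1, 0] = eps[1, 0, 2] = -1.0
    return eps

def axl_sketch(A):
    """Axial vector of skw(A): w_i = -1/2 eps_ijk W_jk."""
    W = 0.5 * (A - A.T)
    return -0.5 * np.einsum("ijk,jk->i", levi_civita(), W)

A = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0],
              [7.0, 8.0, 10.0]])
W = 0.5 * (A - A.T)
w = axl_sketch(A)
v = np.array([0.3, -1.2, 2.0])

lhs = W @ v            # action of the skew part
rhs = np.cross(w, v)   # cross product with the axial vector
```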
- dev(A)[source]#
Deviatoric part \(\mathbf{A} - \tfrac{1}{3}\operatorname{tr}(\mathbf{A})\,\mathbf{I}\).
- Parameters:
A (Tensor2 or SymmetricTensor2)
- Return type:
- norm(A)[source]#
Frobenius norm \(\|\mathbf{A}\| = \sqrt{\mathbf{A}:\mathbf{A}}\).
For SymmetricTensor2 operands the double contraction is evaluated as a Kelvin dot product with no dense intermediate.
- Parameters:
A (Tensor2 or SymmetricTensor2)
- Returns:
Scalar (or batch of scalars).
- Return type:
jax.Array
- polar(F, mode='RU')[source]#
Polar decomposition \(\mathbf{F} = \mathbf{R}\mathbf{U}\) or \(\mathbf{F} = \mathbf{V}\mathbf{R}\).
- Parameters:
F (Tensor2) – Deformation gradient.
mode ({"RU", "VR"}, optional) – Selects the right polar decomposition ("RU", default) or the left polar decomposition ("VR").
- Returns:
(R, U) for mode="RU", where R is a Tensor2 (rotation) and U a SymmetricTensor2 (right stretch), or (V, R) for mode="VR", where V is a SymmetricTensor2 (left stretch).
- Return type:
tuple
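Both decompositions can be sketched on dense arrays with NumPy. The eigendecomposition route used here is one common way to obtain the stretch and is an assumption; the documentation does not say how polar() computes it. `polar_sketch` and `sym_sqrt` are hypothetical names.

```python
import numpy as np

def sym_sqrt(C):
    """Square root of a symmetric positive definite matrix via its eigenpairs."""
    lam, n = np.linalg.eigh(C)
    return (n * np.sqrt(lam)) @ n.T

def polar_sketch(F, mode="RU"):
    """Return (R, U) with F = R U, or (V, R) with F = V R."""
    if mode == "RU":
        U = sym_sqrt(F.T @ F)            # right stretch
        return F @ np.linalg.inv(U), U
    if mode == "VR":
        V = sym_sqrt(F @ F.T)            # left stretch
        return V, np.linalg.inv(V) @ F
    raise ValueError("mode must be 'RU' or 'VR'")

# Deformation gradient with positive determinant (illustrative values).
F = np.array([[1.2, 0.3, 0.0],
              [-0.1, 0.9, 0.2],
              [0.0, 0.1, 1.1]])

R, U = polar_sketch(F, mode="RU")
V, R_left = polar_sketch(F, mode="VR")
```

Both modes share the same rotation, and the two stretches are related by \(\mathbf{V} = \mathbf{R}\mathbf{U}\mathbf{R}^{\mathsf{T}}\).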
- stretch_tensor(F)[source]#
Right stretch tensor \(\mathbf{U} = (\mathbf{F}^{\mathsf{T}}\mathbf{F})^{1/2}\).
Convenience wrapper around polar().
- Parameters:
F (Tensor2) – Deformation gradient.
- Return type:
- sym(A)[source]#
Symmetric part \((\mathbf{A} + \mathbf{A}^{\mathsf{T}}) / 2\).
- Parameters:
A (Tensor2)
- Return type:
- tr(A)[source]#
Trace \(\operatorname{tr}(\mathbf{A}) = A_{ii}\).
- Parameters:
A (Tensor2)
- Returns:
Scalar (or batch of scalars).
- Return type:
jax.Array
- vol(A)[source]#
Volumetric (spherical) part \(\tfrac{1}{3}\operatorname{tr}(\mathbf{A})\,\mathbf{I}\).
Complement of dev(): vol(A) + dev(A) == A for symmetric A.
- Parameters:
A (Tensor2 or SymmetricTensor2)
- Return type:
- von_mises(sig)[source]#
Von Mises equivalent stress.
\[\sigma_\text{VM} = \sqrt{\tfrac{3}{2}\,\mathbf{s}:\mathbf{s}}, \qquad \mathbf{s} = \operatorname{dev}(\boldsymbol{\sigma})\]
- Parameters:
sig (Tensor2 or SymmetricTensor2) – Cauchy stress tensor.
- Returns:
Scalar (or batch of scalars).
- Return type:
jax.Array
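Two textbook stress states make the \(\tfrac{3}{2}\) factor in the von Mises formula concrete: uniaxial tension \(\sigma\) gives \(\sigma_\text{VM} = |\sigma|\), and pure shear \(\tau\) gives \(\sqrt{3}\,|\tau|\). The sketch uses NumPy on dense arrays, and the `*_sketch` functions are hypothetical stand-ins for vol(), dev() and von_mises().

```python
import numpy as np

def vol_sketch(A):
    """Spherical part (1/3) tr(A) I."""
    return np.trace(A) / 3.0 * np.eye(3)

def dev_sketch(A):
    """Deviatoric part A - vol(A)."""
    return A - vol_sketch(A)

def von_mises_sketch(sig):
    """sqrt(3/2 s:s) with s = dev(sig)."""
    s = dev_sketch(sig)
    return np.sqrt(1.5 * np.sum(s * s))

sig_uniaxial = np.diag([250.0, 0.0, 0.0])     # uniaxial tension of 250
tau = 100.0
sig_shear = np.array([[0.0, tau, 0.0],
                      [tau, 0.0, 0.0],
                      [0.0, 0.0, 0.0]])       # pure shear in the 1-2 plane

vm_uniaxial = von_mises_sketch(sig_uniaxial)  # equals 250
vm_shear = von_mises_sketch(sig_shear)        # equals sqrt(3) * 100
```

Adding any hydrostatic stress \(p\,\mathbf{I}\) leaves both values unchanged, since only the deviator enters.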
- det33(A)[source]#
Determinant \(\det(\mathbf{A})\) of a \(3 \times 3\) matrix.
Evaluated via the explicit Sarrus formula, avoiding jnp.linalg.det overhead for the fixed-size case.
- Parameters:
A (array_like, shape (3, 3))
- Returns:
Scalar determinant.
- Return type:
jax.Array
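Written out, the rule of Sarrus is three positive and three negative triple products. The following NumPy sketch spells them out and compares against `np.linalg.det`. The name `det33_sketch` is hypothetical.

```python
import numpy as np

def det33_sketch(A):
    """Determinant of a (3, 3) array by the rule of Sarrus."""
    return (A[0, 0] * A[1, 1] * A[2, 2]
            + A[0, 1] * A[1, 2] * A[2, 0]
            + A[0, 2] * A[1, 0] * A[2, 1]
            - A[0, 2] * A[1, 1] * A[2, 0]
            - A[0, 1] * A[1, 0] * A[2, 2]
            - A[0, 0] * A[1, 2] * A[2, 1])

A = np.array([[2.0, -1.0, 0.5],
              [1.0, 3.0, -2.0],
              [0.0, 4.0, 1.0]])

det_sarrus = det33_sketch(A)
det_reference = np.linalg.det(A)
```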
- expm(A)[source]#
Matrix exponential \(\exp(\mathbf{A})\) of a symmetric \(3 \times 3\) matrix.
Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.
- Parameters:
A (array_like, shape (3, 3))
- Returns:
Shape (3, 3).
- Return type:
jax.Array
- inv_sqrtm(A)[source]#
Inverse square root \(\mathbf{A}^{-1/2}\) of a symmetric positive definite \(3 \times 3\) matrix.
Computed jointly with sqrtm() via the same closed-form expression, so both are available at identical cost.
- Parameters:
A (array_like, shape (3, 3)) – Symmetric positive definite matrix.
- Returns:
Shape (3, 3).
- Return type:
jax.Array
References
Simo, J. C., & Hughes, T. J. R. (1998). Computational Inelasticity. Springer. p. 244.
- isotropic_function(fun, A)[source]#
Isotropic matrix function \(f(\mathbf{A})\) of a symmetric \(3 \times 3\) matrix.
Evaluates the spectral decomposition \(f(\mathbf{A}) = \sum_{i=1}^{3} f(\lambda_i)\,\mathbf{n}_i \otimes \mathbf{n}_i\) where \(\lambda_i\) are the eigenvalues and \(\mathbf{n}_i\) the corresponding eigenvectors of \(\mathbf{A}\).
- Parameters:
fun (callable) – Scalar function \(f : \mathbb{R} \to \mathbb{R}\) applied to each eigenvalue.
A (array_like, shape (3, 3)) – Real symmetric matrix.
- Returns:
Shape (3, 3).
- Return type:
jax.Array
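The spectral formula translates almost literally into code. This NumPy sketch uses `np.linalg.eigh` for the eigenpairs, which is an assumption about method rather than a description of the jaxmat internals, and shows how the exponential, logarithm and powers follow from one helper with a hypothetical name.

```python
import numpy as np

def isotropic_function_sketch(fun, A):
    """f(A) = sum_i f(lambda_i) n_i (x) n_i for a real symmetric (3, 3) array."""
    lam, n = np.linalg.eigh(A)       # columns of n are the eigenvectors
    return (n * fun(lam)) @ n.T      # scales column i by f(lambda_i)

# Symmetric positive definite test matrix (illustrative values).
A = np.array([[2.0, 0.5, 0.0],
              [0.5, 1.5, 0.3],
              [0.0, 0.3, 1.0]])

log_A = isotropic_function_sketch(np.log, A)
exp_log_A = isotropic_function_sketch(np.exp, log_A)        # recovers A
sqrt_A = isotropic_function_sketch(np.sqrt, A)
cube_A = isotropic_function_sketch(lambda x: x**3, A)       # equals A @ A @ A
```

The logarithm and fractional powers need positive eigenvalues, which is why the corresponding functions ask for a symmetric positive definite argument.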
- logm(A)[source]#
Matrix logarithm \(\log(\mathbf{A})\) of a symmetric positive definite \(3 \times 3\) matrix.
Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.
- Parameters:
A (array_like, shape (3, 3)) – Symmetric positive definite matrix.
- Returns:
Shape (3, 3).
- Return type:
jax.Array
- main_invariants(A)[source]#
Main (trace-power) invariants \((J_1, J_2, J_3)\) of a \(3 \times 3\) matrix.
\[J_k = \operatorname{tr}(\mathbf{A}^k), \quad k = 1, 2, 3\]
- Parameters:
A (array_like, shape (3, 3))
- Returns:
J1, J2, J3 – Three scalar invariants.
- Return type:
jax.Array
- powm(A, m)[source]#
Matrix power \(\mathbf{A}^m\) of a symmetric \(3 \times 3\) matrix.
Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.
- Parameters:
A (array_like, shape (3, 3))
m (float) – Exponent.
- Returns:
Shape (3, 3).
- Return type:
jax.Array
- pq_invariants(sig)[source]#
Hydrostatic pressure \(p\) and deviatoric equivalent stress \(q\).
Commonly used in soil mechanics and pressure-sensitive plasticity.
\[p = -\tfrac{1}{3}\operatorname{tr}(\boldsymbol{\sigma}), \qquad q = \sqrt{\tfrac{3}{2}\,\mathbf{s}:\mathbf{s}}\]
where \(\mathbf{s} = \operatorname{dev}(\boldsymbol{\sigma})\).
- Parameters:
sig (array_like, shape (3, 3)) – Cauchy stress tensor.
- Returns:
p (jax.Array) – Mean pressure (positive in compression).
q (jax.Array) – Von Mises equivalent stress.
- Return type:
tuple[Array, Array]
- principal_invariants(A)[source]#
Principal invariants \((I_1, I_2, I_3)\) of a \(3 \times 3\) matrix \(\mathbf{A}\).
\[I_1 = \operatorname{tr}(\mathbf{A}), \quad I_2 = \tfrac{1}{2}\bigl(\operatorname{tr}(\mathbf{A})^2 - \operatorname{tr}(\mathbf{A}^2)\bigr), \quad I_3 = \det(\mathbf{A})\]
- Parameters:
A (array_like, shape (3, 3))
- Returns:
I1, I2, I3 – Three scalar invariants.
- Return type:
jax.Array
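The principal invariants and the trace-power invariants of main_invariants() carry the same information. The NumPy sketch below computes both sets and checks the Newton identities \(I_2 = \tfrac{1}{2}(J_1^2 - J_2)\) and \(I_3 = \tfrac{1}{6}(J_1^3 - 3J_1J_2 + 2J_3)\) that link them. The function names are hypothetical.

```python
import numpy as np

def main_invariants_sketch(A):
    """J_k = tr(A^k) for k = 1, 2, 3."""
    A2 = A @ A
    return np.trace(A), np.trace(A2), np.trace(A2 @ A)

def principal_invariants_sketch(A):
    """I1 = tr(A), I2 = (tr(A)^2 - tr(A^2)) / 2, I3 = det(A)."""
    I1 = np.trace(A)
    I2 = 0.5 * (I1**2 - np.trace(A @ A))
    I3 = np.linalg.det(A)
    return I1, I2, I3

A = np.array([[2.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 4.0]])

J1, J2, J3 = main_invariants_sketch(A)
I1, I2, I3 = principal_invariants_sketch(A)

# For a symmetric matrix the principal invariants are the elementary
# symmetric polynomials of the eigenvalues.
lam = np.linalg.eigvalsh(A)
```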
- sqrtm(A)[source]#
Matrix square root \(\mathbf{A}^{1/2}\) of a symmetric positive definite \(3 \times 3\) matrix.
Uses the closed-form expression of Simo & Hughes (1998) based on the principal invariants of \(\mathbf{A}^{1/2}\). Accepts any object that implements __jax_array__ (e.g. a SymmetricTensor2).
- Parameters:
A (array_like, shape (3, 3)) – Symmetric positive definite matrix.
- Returns:
Shape (3, 3).
- Return type:
jax.Array
References
Simo, J. C., & Hughes, T. J. R. (1998). Computational Inelasticity. Springer. p. 244.
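A closed form of this kind follows from the Cayley-Hamilton theorem applied to \(\mathbf{U} = \mathbf{A}^{1/2}\). With \(i_1, i_2, i_3\) the principal invariants of \(\mathbf{U}\), one gets \(\mathbf{U} = \bigl[i_1 i_3\,\mathbf{I} + (i_1^2 - i_2)\,\mathbf{A} - \mathbf{A}^2\bigr] / (i_1 i_2 - i_3)\) and \(\mathbf{U}^{-1} = (\mathbf{A} - i_1\mathbf{U} + i_2\mathbf{I}) / i_3\). The NumPy sketch below evaluates both. Taking the invariants from `np.linalg.eigvalsh` is an assumption made for brevity, and the exact expression used in jaxmat may be arranged differently.

```python
import numpy as np

def sqrtm_and_inverse_sketch(A):
    """Return (A^(1/2), A^(-1/2)) for a symmetric positive definite (3, 3) array."""
    # Eigenvalues of U = A^(1/2) are the square roots of those of A.
    lam = np.sqrt(np.linalg.eigvalsh(A))
    i1 = lam.sum()
    i2 = lam[0] * lam[1] + lam[1] * lam[2] + lam[0] * lam[2]
    i3 = lam.prod()
    I = np.eye(3)
    # i1*i2 - i3 = (l1 + l2)(l2 + l3)(l1 + l3) > 0 for positive eigenvalues.
    U = (i1 * i3 * I + (i1**2 - i2) * A - A @ A) / (i1 * i2 - i3)
    U_inv = (A - i1 * U + i2 * I) / i3
    return U, U_inv

A = np.array([[2.0, 0.5, 0.0],
              [0.5, 1.5, 0.3],
              [0.0, 0.3, 1.0]])

U, U_inv = sqrtm_and_inverse_sketch(A)
```

Once the three invariants are known, the inverse costs only one extra linear combination, which is consistent with the remark that sqrtm() and inv_sqrtm() are available at identical cost.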
- safe_fun(fun, x, norm=None, eps=1e-16)[source]#
Apply a function safely, avoiding evaluation at or near zero.
The input x is replaced by a small positive sentinel eps whenever norm(x) <= eps before calling fun. The final result is then masked to return zero in that case. This sentinel strategy ensures that fun is always evaluated at a numerically safe point, so that gradients through fun remain finite under automatic differentiation.
This is consistent with safe_sqrt(): safe_fun(jnp.sqrt, x) produces the same values and gradients as safe_sqrt(x).
- Parameters:
fun (callable) – Scalar or array function to apply safely.
x (array_like) – Input value.
norm (callable, optional) – Scalar-valued function of x used to test proximity to zero. Defaults to the identity (i.e. x itself is the magnitude).
eps (float, optional) – Threshold below which x is considered zero. Defaults to 1e-16.
- Returns:
fun(x) where norm(x) > eps, otherwise 0.
- Return type:
jax.Array
Notes
The key property is that the sentinel-substituted input eps (not 0) is passed to fun in the masked branch. This prevents jax.grad from encountering undefined derivatives (e.g. 1 / (2 sqrt(0)) for fun = jnp.sqrt).
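The substitute-then-mask pattern can be sketched with two `where` calls. The version below uses NumPy so that it runs without JAX. `safe_fun_sketch` is a hypothetical re-implementation of the behaviour described above, not the library function, and NumPy cannot differentiate, so the last line only evaluates by hand the derivative that automatic differentiation would see in the masked branch.

```python
import numpy as np

def safe_fun_sketch(fun, x, norm=None, eps=1e-16):
    """Evaluate fun away from zero, then mask the result to 0 where x is tiny."""
    x = np.asarray(x, dtype=float)
    magnitude = x if norm is None else norm(x)
    is_small = magnitude <= eps
    x_safe = np.where(is_small, eps, x)            # first where: sentinel input
    return np.where(is_small, 0.0, fun(x_safe))    # second where: masked output

# Element-wise use: sqrt never sees the exact zero.
values = safe_fun_sketch(np.sqrt, np.array([0.0, 4.0]))

# Vector use with an explicit norm: direction of a zero vector is set to zero.
def unit(v):
    return v / np.linalg.norm(v)

direction_zero = safe_fun_sketch(unit, np.zeros(3), norm=np.linalg.norm)
direction = safe_fun_sketch(unit, np.array([3.0, 0.0, 4.0]), norm=np.linalg.norm)

# d/dx sqrt(x) evaluated at the sentinel is large but finite, unlike at x = 0.
sqrt_slope_at_sentinel = 0.5 / np.sqrt(1e-16)
```

Masking only the output would not be enough under automatic differentiation, because the derivative of the unselected branch is still evaluated at the original input.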
- safe_norm(x, eps=1e-16, **kwargs)[source]#
Wrapper around optax.safe_norm that computes a numerically stable norm.
This function prevents numerical instability when computing vector norms for small magnitudes by internally applying a stability threshold.
- Parameters:
x (array-like) – Input vector or tensor.
eps (float, optional) – Small constant added for numerical stability. Defaults to 1e-16.
**kwargs – Additional arguments passed to optax.safe_norm.
- Returns:
The numerically stable norm of x.
- Return type:
array-like
- safe_sqrt(x, eps=1e-16)[source]#
Computes a numerically safe square root.
Ensures the argument to the square root is greater than eps to avoid taking the square root of zero or negative values, which could cause instability or NaNs.
- Parameters:
x (array-like) – Input array or tensor.
eps (float, optional) – Minimum threshold for x before taking the square root. Defaults to 1e-16.
- Returns:
The square root of x for x > eps, otherwise eps.
- Return type:
array-like