jaxmat.tensors package

Module contents

jaxmat.tensors — tensor algebra for solid mechanics.

class Tensor

Bases: Module

Empty marker base class for all jaxmat tensor objects.

Provides a single type to test against with isinstance(x, Tensor) without coupling to any specific rank or symmetry class.

Both rank-2 classes (Tensor2, SymmetricTensor2) and rank-4 classes (Tensor4, SymmetricTensor4, and the symmetry-reduced subclasses) inherit from this marker.

class SymmetricTensor2

Bases: Tensor2

Symmetric second-rank tensor in 3-D.

Stored as a 6-component Kelvin-Mandel array:

\[\{T\} = [T_{11},\, T_{22},\, T_{33},\, \sqrt{2}\,T_{12},\, \sqrt{2}\,T_{13},\, \sqrt{2}\,T_{23}]\]

The \(\sqrt{2}\) scaling makes the Kelvin basis orthonormal so that the dot product of two Kelvin vectors equals the double contraction of the corresponding tensors: \(\{T\} \cdot \{S\} = \mathbf{T} : \mathbf{S}\).
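The sketch below implements this mapping in plain `jax.numpy`. It relies only on the component ordering and \(\sqrt{2}\) weights given above. The helpers `to_kelvin` and `from_kelvin` are hypothetical illustrations, not jaxmat functions. The final lines check \(\{T\} \cdot \{S\} = \mathbf{T} : \mathbf{S}\) on a concrete pair of symmetric matrices.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for the checks

# (row, col) index pairs in the Kelvin-Mandel ordering [11, 22, 33, 12, 13, 23], 0-based.
ROWS = jnp.array([0, 1, 2, 0, 0, 1])
COLS = jnp.array([0, 1, 2, 1, 2, 2])
# sqrt(2) weights on the three shear components make the basis orthonormal.
W = jnp.array([1.0, 1.0, 1.0, 2.0**0.5, 2.0**0.5, 2.0**0.5])


def to_kelvin(T):
    """Hypothetical helper: dense symmetric (..., 3, 3) -> Kelvin-Mandel (..., 6)."""
    return T[..., ROWS, COLS] * W


def from_kelvin(t):
    """Hypothetical helper: Kelvin-Mandel (..., 6) -> dense symmetric (..., 3, 3)."""
    comps = t / W
    T = jnp.zeros(t.shape[:-1] + (3, 3), dtype=t.dtype)
    T = T.at[..., ROWS, COLS].set(comps)      # upper triangle and diagonal
    return T.at[..., COLS, ROWS].set(comps)   # mirror into the lower triangle


A = jnp.array([[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [3.0, 5.0, 6.0]])
B = jnp.array([[0.5, -1.0, 0.0], [-1.0, 2.0, 1.5], [0.0, 1.5, -3.0]])

kelvin_dot = jnp.dot(to_kelvin(A), to_kelvin(B))  # {A} . {B}
dense_ddot = jnp.sum(A * B)                       # A : B = A_ij B_ij
```

Both helpers broadcast over leading batch dimensions, so a stack of shape (N, 3, 3) maps to (N, 6).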

Parameters:
  • tensor (array_like, shape (..., 3, 3), optional) – Dense symmetric tensor.

  • array (array_like, shape (..., 6), optional) – Pre-built Kelvin array.

Notes

The @ operator returns Tensor2 because the product of two symmetric matrices is not symmetric in general.

double_contract() operates directly on the Kelvin arrays (dot product), with no dense intermediate.

Batch dimensions and jacfwd compatibility: as for Tensor2, batch dimensions precede the storage dimension (array.shape == (N, 6) for N tensors). During forward-mode AD, equinox may set _array directly to the tangent module, bypassing __init__. In that case, the tensor property calls _raw_array() to unwrap the value safely.

property T: SymmetricTensor2

Transpose — returns self since \(\mathbf{T} = \mathbf{T}^{\mathsf{T}}\).

Return type:

SymmetricTensor2

base_array_shape = (6,)
double_contract(other)

Double contraction \(\mathbf{T} : \mathbf{S} = T_{ij} S_{ij}\).

When other is also a SymmetricTensor2, the contraction reduces to a Kelvin dot product \(\{T\} \cdot \{S\}\) with no dense intermediate.

Parameters:

other (SymmetricTensor2 or array_like, shape (..., 3, 3))

Returns:

Scalar (or batch of scalars).

Return type:

jax.Array

property tensor: Array

Dense symmetric (3, 3) tensor.

Reconstructed by scattering the Kelvin components into the upper triangle and symmetrising. _raw_array() is applied to _array first to handle the case where equinox sets _array to the tangent module object during jacfwd (bypassing __init__).

Returns:

Shape (..., 3, 3).

Return type:

jax.Array

class Tensor2

Bases: Tensor

Full (non-symmetric) second-rank tensor in 3-D.

Stored as a 9-component Kelvin array in the ordering \([T_{11}, T_{22}, T_{33}, T_{12}, T_{21}, T_{13}, T_{31}, T_{23}, T_{32}]\).

Parameters:
  • tensor (array_like, shape (..., 3, 3), optional) – Dense tensor representation.

  • array (array_like, shape (..., 9), optional) – Pre-built Kelvin array. Passing another Tensor2 is accepted; its _array field is used directly.

Notes

At most one of tensor or array may be provided. If neither is given, the tensor is initialized to zero.

Batch dimensions precede the storage dimension: a batch of \(N\) tensors has array.shape == (N, 9) and tensor.shape == (N, 3, 3).

The @ operator performs dense matrix composition \((\mathbf{T} \cdot \mathbf{S})_{ik} = T_{ij} S_{jk}\) and always returns a Tensor2, regardless of the symmetry of the operands. Use double_contract() for \(\mathbf{T}:\mathbf{S}\) and matvec() for \(\mathbf{T} \cdot \mathbf{v}\).
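The sketch below writes the three index conventions with `jnp.einsum` on plain (3, 3) arrays. It is a dense reference for checking results. It does not describe how jaxmat evaluates these products internally.

```python
import jax.numpy as jnp

T = jnp.array([[1.0, 2.0, 0.0], [0.0, 1.0, 3.0], [4.0, 0.0, 1.0]])
S = jnp.array([[2.0, 0.0, 1.0], [1.0, 1.0, 0.0], [0.0, 2.0, 1.0]])
v = jnp.array([1.0, -1.0, 2.0])

compose = jnp.einsum("ij,jk->ik", T, S)  # (T . S)_ik = T_ij S_jk   (the @ operator)
ddot = jnp.einsum("ij,ij->", T, S)       # T : S = T_ij S_ij        (double_contract)
mv = jnp.einsum("ij,j->i", T, v)         # (T . v)_i = T_ij v_j     (matvec)
```

Even when T and S are both symmetric, the composed product is generally not symmetric. This is why @ on two SymmetricTensor2 operands returns a Tensor2.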

property T: Tensor2

Transpose \(T^{\mathsf{T}}_{ij} = T_{ji}\).

Computed via a single gather on the Kelvin array; no dense intermediate is constructed.

Return type:

Tensor2
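The gather-based storage, transpose and dense reconstruction described for this class can be sketched with three fixed index arrays. The names `STORE`, `UNSTORE` and `TRANSPOSE` and the three helper functions are hypothetical. The sketch assumes only two things: the 9-component ordering stated above and row-major flattening of the dense tensor. jaxmat's internal index tables may differ.

```python
import jax.numpy as jnp

# Row-major flat indices of [T11, T22, T33, T12, T21, T13, T31, T23, T32].
STORE = jnp.array([0, 4, 8, 1, 3, 2, 6, 5, 7])
# Inverse permutation: where each dense entry lives in the 9-array.
UNSTORE = jnp.argsort(STORE)
# Transposition swaps the pairs (T12, T21), (T13, T31), (T23, T32).
TRANSPOSE = jnp.array([0, 1, 2, 4, 3, 6, 5, 8, 7])


def to_array(T):
    """Dense (..., 3, 3) -> 9-component array (..., 9)."""
    return T.reshape(T.shape[:-2] + (9,))[..., STORE]


def to_tensor(a):
    """9-component array -> dense tensor: one gather, one reshape, no scatter."""
    return a[..., UNSTORE].reshape(a.shape[:-1] + (3, 3))


def transpose(a):
    """Transpose as a single gather on the 9-component array."""
    return a[..., TRANSPOSE]


T = jnp.arange(9.0).reshape(3, 3)
a = to_array(T)
```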

property array: Array

9-component Kelvin array.

Returns:

Shape (..., 9).

Return type:

jax.Array

array_rank = 1
base_array_shape = (9,)
base_tensor_shape = (3, 3)
property batch_shape: tuple

Leading batch dimensions (...).

property det: Array

Determinant \(\det(\mathbf{T})\).

Return type:

jax.Array

dim = 3
double_contract(other)

Double contraction \(\mathbf{T} : \mathbf{S} = T_{ij} S_{ij}\).

Parameters:

other (Tensor2 or array_like, shape (..., 3, 3))

Returns:

Scalar (or batch of scalars).

Return type:

jax.Array

classmethod identity()

Second-rank identity \(\delta_{ij}\).

Return type:

Tensor2

property inv: Tensor2

Inverse \(\mathbf{T}^{-1}\).

Return type:

Tensor2

matvec(v)

Matrix-vector product \((\mathbf{T} \cdot \mathbf{v})_i = T_{ij} v_j\).

Parameters:

v (array_like, shape (..., 3))

Returns:

Shape (..., 3).

Return type:

jax.Array

rank = 2
rotate(R)

Rotate the tensor: \(\mathbf{R} \mathbf{T} \mathbf{R}^{\mathsf{T}}\).

Parameters:

R (array_like, shape (3, 3)) – Orthogonal rotation matrix.

Return type:

Tensor2

property shape: tuple

Shape of the underlying Kelvin array (..., 9).

property skw: Tensor2

Skew-symmetric part \((\mathbf{T} - \mathbf{T}^{\mathsf{T}}) / 2\).

Return type:

Tensor2

property sym: SymmetricTensor2

Symmetric part \((\mathbf{T} + \mathbf{T}^{\mathsf{T}}) / 2\).

Return type:

SymmetricTensor2

property tensor: Array

Dense tensor representation.

Reconstructed via a single gather on _array followed by a reshape — no scatter, no zero allocation.

Returns:

Shape (..., 3, 3).

Return type:

jax.Array

property tensor_shape: tuple

Shape of the dense tensor (..., 3, 3).

property tr: Array

Trace \(\mathrm{tr}(\mathbf{T}) = T_{ii}\).

Return type:

jax.Array

class SymmetricTensor4

Bases: _AbstractTensor4

Fourth-rank tensor with minor symmetries \(C_{ijkl} = C_{jikl} = C_{ijlk} = C_{jilk}\).

Stored as a \((6 \times 6)\) Kelvin-Mandel matrix. The Kelvin scaling ensures that the double contraction \(\mathbb{C} : \boldsymbol{\varepsilon}\) is a plain matrix-vector product on the Kelvin arrays.
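The following sketch checks this statement numerically. It assumes the same Kelvin-Mandel ordering as SymmetricTensor2. It also assumes a weight \(w_I w_J\) on entry \((I, J)\) of the matrix, with \(w = [1, 1, 1, \sqrt{2}, \sqrt{2}, \sqrt{2}]\). `kelvin2` and `kelvin4` are hypothetical helpers written for this illustration, and the Lamé constants are arbitrary.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

ROWS = jnp.array([0, 1, 2, 0, 0, 1])
COLS = jnp.array([0, 1, 2, 1, 2, 2])
W = jnp.array([1.0, 1.0, 1.0, 2.0**0.5, 2.0**0.5, 2.0**0.5])


def kelvin2(eps):
    """Symmetric (3, 3) -> Kelvin-Mandel vector (6,)."""
    return eps[ROWS, COLS] * W


def kelvin4(C):
    """Minor-symmetric (3, 3, 3, 3) -> Kelvin-Mandel matrix (6, 6): w_I w_J C_ijkl."""
    return C[ROWS[:, None], COLS[:, None], ROWS[None, :], COLS[None, :]] * jnp.outer(W, W)


d = jnp.eye(3)
Is = 0.5 * (jnp.einsum("ik,jl->ijkl", d, d) + jnp.einsum("il,jk->ijkl", d, d))
lam, mu = 1.0, 0.5  # illustrative Lame constants
C = lam * jnp.einsum("ij,kl->ijkl", d, d) + 2.0 * mu * Is  # isotropic stiffness

eps = jnp.array([[1.0, 0.2, 0.0], [0.2, -0.5, 0.3], [0.0, 0.3, 0.1]])
sig_dense = jnp.einsum("ijkl,kl->ij", C, eps)  # C : eps on dense arrays
sig_kelvin = kelvin4(C) @ kelvin2(eps)         # plain matrix-vector product
```

In this layout, \(\mathbb{I}^s\) maps to the (6 × 6) identity matrix. As a result, inverting \(\mathbb{C}\) on symmetric tensors reduces to an ordinary (6 × 6) matrix inverse.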

Parameters:
  • tensor (array_like, shape (..., 3, 3, 3, 3), optional) – Dense tensor with minor symmetries.

  • array (array_like, shape (..., 6, 6), optional) – Pre-built Kelvin matrix.

Notes

The @ operator performs Kelvin-space products, i.e. ordinary matrix (or matrix-vector) products on the Kelvin arrays.

Use to_symmetric() on _AbstractTensor4 subclasses to materialise them into this form before mixed-class operations.

base_array_shape = (6, 6)
classmethod identity()

Symmetric fourth-rank identity \(\mathbb{I}^s_{ijkl} = \tfrac{1}{2}(\delta_{ik}\delta_{jl} + \delta_{il}\delta_{jk})\).

Return type:

SymmetricTensor4

property tensor: Array

Dense (3, 3, 3, 3) tensor with minor symmetries enforced.

Reconstructed via a 2-D gather on _array / _S4_W.

Returns:

Shape (..., 3, 3, 3, 3).

Return type:

jax.Array

to_symmetric()

Return self — already the fully materialised symmetric form.

Return type:

SymmetricTensor4

class Tensor4

Bases: _AbstractTensor4

Full (non-minor-symmetric) fourth-rank tensor in 3-D.

Stored as a \((9 \times 9)\) Kelvin matrix corresponding to the full double-index ordering \([T_{11}, T_{22}, T_{33}, T_{12}, T_{21}, T_{13}, T_{31}, T_{23}, T_{32}]\).

Parameters:
  • tensor (array_like, shape (..., 3, 3, 3, 3), optional) – Dense tensor representation.

  • array (array_like, shape (..., 9, 9), optional) – Pre-built Kelvin matrix.

Notes

The @ operator denotes double contraction.

base_array_shape = (9, 9)
classmethod identity()

Fourth-rank identity \(\mathbb{I}_{ijkl} = \delta_{ik}\delta_{jl}\).

Return type:

Tensor4

property tensor: Array

Dense (3, 3, 3, 3) tensor via a 2-D gather on the Kelvin matrix.

Returns:

Shape (..., 3, 3, 3, 3).

Return type:

jax.Array

class CubicTensor4

Bases: AbstractStructuredTensor4

Cubic-symmetric fourth-rank tensor.

Represented in the basis of three mutually orthogonal projectors:

\[\mathbb{C} = 3\kappa\,\mathbb{J} + 2\mu_a\,\mathbb{K}_a + 2\mu_b\,\mathbb{K}_b\]

where \(\mathbb{J}\) is the volumetric projector, \(\mathbb{K}_a\) projects onto the diagonal deviatoric part, and \(\mathbb{K}_b\) projects onto the off-diagonal shear part.

Parameters:
  • coeffs (array_like, shape (..., 3), optional) – Direct basis coefficients \([3\kappa, 2\mu_a, 2\mu_b]\).

  • kappa (float or array_like, optional) – Cubic bulk modulus \(\kappa\).

  • mua (float or array_like, optional) – Diagonal deviatoric modulus \(\mu_a\).

  • mub (float or array_like, optional) – Off-diagonal shear modulus \(\mu_b\).

Notes

When \(\mu_a = \mu_b = \mu\) the tensor reduces to the isotropic case \(\mathbb{C} = 3\kappa\mathbb{J} + 2\mu\mathbb{K}\). Coefficients are stored as \(\{3\kappa, 2\mu_a, 2\mu_b\}\).

J = SymmetricTensor4(shape=(6, 6))
Ka = SymmetricTensor4(shape=(6, 6))
Kb = SymmetricTensor4(shape=(6, 6))
property inv: CubicTensor4

Inverse \(\mathbb{C}^{-1}\).

Because the three projectors are mutually orthogonal, the inverse is coefficient-wise: \(c_\alpha^{-1}\) for each \(\alpha\).

Return type:

CubicTensor4

classmethod project(C)

Project a SymmetricTensor4 onto the cubic subspace.

Parameters:

C (SymmetricTensor4)

Return type:

CubicTensor4

class IsotropicTensor4

Bases: AbstractStructuredTensor4

Isotropic fourth-rank tensor.

Parameterised by bulk and shear moduli and expressed in the orthogonal projector basis \((\mathbb{J}, \mathbb{K})\):

\[\mathbb{C} = 3\kappa\,\mathbb{J} + 2\mu\,\mathbb{K}\]

Parameters:
  • coeffs (array_like, shape (..., 2), optional) – Direct basis coefficients \([3\kappa, 2\mu]\).

  • kappa (float or array_like, optional) – Bulk modulus \(\kappa\).

  • mu (float or array_like, optional) – Shear modulus \(\mu\).

Notes

Exactly one of coeffs or the pair (kappa, mu) must be provided. Coefficients are stored as \(\{3\kappa, 2\mu\}\). With the projector identities \(\mathbb{J}:\mathbf{A} = \tfrac{1}{3}\operatorname{tr}(\mathbf{A})\mathbf{I}\) and \(\mathbb{K}:\mathbf{A} = \operatorname{dev}(\mathbf{A})\), the expansion \(c_1\mathbb{J}+c_2\mathbb{K}\) then reproduces the familiar form \(\mathbb{C}:\boldsymbol{\varepsilon} = \kappa\operatorname{tr}(\boldsymbol{\varepsilon})\mathbf{I} + 2\mu\operatorname{dev}(\boldsymbol{\varepsilon})\).

Batched use (one tensor per material point) is supported by passing array-valued kappa and mu of shape (...,).

J = SymmetricTensor4(shape=(6, 6))
K = SymmetricTensor4(shape=(6, 6))
property inv: IsotropicTensor4

Inverse \(\mathbb{C}^{-1}\).

Because \(\mathbb{J}\) and \(\mathbb{K}\) are orthogonal projectors, the inverse is obtained by inverting each coefficient: \(\mathbb{C}^{-1} = \tfrac{1}{3\kappa}\mathbb{J} + \tfrac{1}{2\mu}\mathbb{K}\).

Return type:

IsotropicTensor4
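The sketch below checks the coefficient-wise inverse numerically. It works with (6 × 6) Kelvin-Mandel matrices in the ordering [11, 22, 33, 12, 13, 23]. In that basis, \(\mathbb{J}\) has 1/3 in every entry of the normal-normal block and \(\mathbb{K} = \mathbf{I}_6 - \mathbb{J}\). The moduli are illustrative values, and the sketch does not use the jaxmat classes.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Kelvin-Mandel matrices of the isotropic projectors: J acts only on the normal block.
J = jnp.zeros((6, 6)).at[:3, :3].set(1.0 / 3.0)
K = jnp.eye(6) - J

kappa, mu = 160.0e3, 80.0e3  # illustrative bulk and shear moduli
C = 3.0 * kappa * J + 2.0 * mu * K
C_inv = J / (3.0 * kappa) + K / (2.0 * mu)  # invert each coefficient
```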

property kappa: Array

Bulk modulus \(\kappa = c_1 / 3\).

property mu: Array

Shear modulus \(\mu = c_2 / 2\).

classmethod project(C)

Project a SymmetricTensor4 onto the isotropic subspace.

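Extracts the isotropic part by computing \(\kappa = \tfrac{1}{3}\mathbb{C}::\mathbb{J}\) and \(\mu = \tfrac{1}{10}\mathbb{C}::\mathbb{K}\), where \(::\) denotes the full contraction \(\mathbb{A}::\mathbb{B} = A_{ijkl}B_{ijkl}\). The factors follow from \(\mathbb{J}::\mathbb{J} = 1\) and \(\mathbb{K}::\mathbb{K} = 5\) (the dimensions of the volumetric and deviatoric subspaces). Consequently, an isotropic \(\mathbb{C} = 3\kappa\mathbb{J} + 2\mu\mathbb{K}\) is recovered exactly.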

Parameters:

C (SymmetricTensor4)

Return type:

IsotropicTensor4
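The projection formulas can be tried out with Kelvin-Mandel matrices. For minor-symmetric tensors, the full contraction \(\mathbb{A}::\mathbb{B}\) equals the sum of the entry-wise products of their (6 × 6) Kelvin-Mandel matrices. `project_isotropic` is a hypothetical helper. The cubic test tensor assumes cube axes aligned with the coordinate axes. Because \(\mathbb{K}_a::\mathbb{K}_a = 2\) and \(\mathbb{K}_b::\mathbb{K}_b = 3\), projecting a cubic tensor gives \(\mu = (2\mu_a + 3\mu_b)/5\) and leaves \(\kappa\) unchanged.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Kelvin-Mandel projector matrices (ordering [11, 22, 33, 12, 13, 23]).
J = jnp.zeros((6, 6)).at[:3, :3].set(1.0 / 3.0)
K = jnp.eye(6) - J
Ka = jnp.zeros((6, 6)).at[:3, :3].set(jnp.eye(3) - 1.0 / 3.0)  # diagonal deviatoric part
Kb = jnp.zeros((6, 6)).at[3:, 3:].set(jnp.eye(3))               # off-diagonal shear part


def project_isotropic(C):
    """Hypothetical helper: kappa = C::J / 3, mu = C::K / 10 on Kelvin-Mandel matrices."""
    kappa = jnp.sum(C * J) / 3.0
    mu = jnp.sum(C * K) / 10.0
    return kappa, mu


# A cubic tensor is not isotropic; its projection averages the two shear moduli.
kappa, mua, mub = 100.0, 30.0, 50.0
C_cubic = 3.0 * kappa * J + 2.0 * mua * Ka + 2.0 * mub * Kb
kappa_iso, mu_iso = project_isotropic(C_cubic)
```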

rotate(R)

Isotropic tensors are invariant under all rotations.

Parameters:

R (array_like, shape (3, 3)) – Orthogonal rotation matrix (unused).

Returns:

self unchanged.

Return type:

IsotropicTensor4

class TransverseIsotropicTensor4

Bases: AbstractStructuredTensor4

Transversely isotropic fourth-rank tensor in the Walpole basis.

Defined with respect to a unit symmetry axis \(\hat{\mathbf{a}}\) and expanded in the six Walpole basis tensors \(\mathbb{E}_1, \mathbb{E}_2, \mathbb{E}_3, \mathbb{E}_4, \mathbb{F}, \mathbb{G}\):

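\[\mathbb{C} = \sum_{\alpha=1}^{6} c_\alpha \mathbb{P}_\alpha, \qquad (\mathbb{P}_1, \ldots, \mathbb{P}_6) = (\mathbb{E}_1, \mathbb{E}_2, \mathbb{E}_3, \mathbb{E}_4, \mathbb{F}, \mathbb{G})\]
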
Parameters:
  • axis (array_like, shape (3,)) – Unit symmetry axis \(\hat{\mathbf{a}}\). Need not be pre-normalised.

  • coeffs (array_like, shape (..., 6)) – Walpole basis coefficients \([c_1, c_2, c_3, c_4, c_5, c_6]\).

Notes

Inversion uses the \(2\times 2\) block structure of \(\mathbb{E}_1\ldots\mathbb{E}_4\) and scalar inversion for \(\mathbb{F}\) and \(\mathbb{G}\).

property inv: TransverseIsotropicTensor4

Inverse operator within the transverse-isotropic subspace.

The sub-block \([c_1, c_3; c_4, c_2]\) corresponding to \(\mathbb{E}_1\ldots\mathbb{E}_4\) is inverted as a \(2\times 2\) matrix, while \(c_5\) and \(c_6\) (for \(\mathbb{F}\) and \(\mathbb{G}\)) are inverted individually.

Return type:

TransverseIsotropicTensor4
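The coefficient update can be sketched as follows. `walpole_inverse` is a hypothetical helper that works on coefficients only and never builds the basis tensors. As stated above, it assumes that the \(\mathbb{E}_1\ldots\mathbb{E}_4\) part composes like the \(2\times 2\) matrix \([c_1, c_3; c_4, c_2]\). It also assumes that \(c_5\) and \(c_6\) are inverted individually.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def walpole_inverse(c):
    """Hypothetical helper: inverse Walpole coefficients [c1, ..., c6].

    Assumes the E1..E4 part composes like the 2x2 matrix [[c1, c3], [c4, c2]]
    and that the F and G coefficients are inverted individually.
    """
    c1, c2, c3, c4, c5, c6 = (c[..., i] for i in range(6))
    det = c1 * c2 - c3 * c4
    # Closed-form 2x2 inverse [[c2, -c3], [-c4, c1]] / det, mapped back to coefficients.
    return jnp.stack(
        [c2 / det, c1 / det, -c3 / det, -c4 / det, 1.0 / c5, 1.0 / c6], axis=-1
    )


def block(c):
    """The 2x2 matrix representing the E1..E4 part."""
    return jnp.array([[c[0], c[2]], [c[3], c[1]]])


c = jnp.array([3.0, 2.0, 0.5, 0.75, 4.0, 5.0])
c_inv = walpole_inverse(c)
```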

classmethod project(axis, C)

Project a SymmetricTensor4 onto the transverse-isotropic subspace.

Parameters:
  • axis (array_like, shape (3,)) – Unit symmetry axis \(\hat{\mathbf{a}}\).

  • C (SymmetricTensor4)

Return type:

TransverseIsotropicTensor4

axis: jax.Array
cubic_projectors()

Construct cubic-symmetry fourth-rank projectors.

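Returns:

  • J (SymmetricTensor4) – Volumetric projector \(\mathbb{J}\).

  • Ka (SymmetricTensor4) – Projector \(\mathbb{K}_a\) onto the diagonal deviatoric part.

  • Kb (SymmetricTensor4) – Projector \(\mathbb{K}_b\) onto the off-diagonal shear part.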

Notes

The three projectors are mutually orthogonal and partition the identity: \(\mathbb{J}+\mathbb{K}_a+\mathbb{K}_b = \mathbb{I}^s\).
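In Kelvin-Mandel form these projectors have a simple block layout. The sketch below assumes that the cube axes coincide with the coordinate axes and that components are ordered [11, 22, 33, 12, 13, 23]. It checks the partition of identity and mutual orthogonality. It then checks the coefficient-wise inverse described for CubicTensor4. The moduli are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Kelvin-Mandel (6x6) matrices, cube axes aligned with the coordinate axes.
J = jnp.zeros((6, 6)).at[:3, :3].set(1.0 / 3.0)                # volumetric
Ka = jnp.zeros((6, 6)).at[:3, :3].set(jnp.eye(3) - 1.0 / 3.0)  # diagonal deviatoric
Kb = jnp.zeros((6, 6)).at[3:, 3:].set(jnp.eye(3))              # off-diagonal shear
P = (J, Ka, Kb)

# C = 3 kappa J + 2 mu_a Ka + 2 mu_b Kb and its coefficient-wise inverse.
coeffs = jnp.array([3.0 * 100.0, 2.0 * 30.0, 2.0 * 50.0])
C = sum(c * p for c, p in zip(coeffs, P))
C_inv = sum(p / c for c, p in zip(coeffs, P))
```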

isotropic_projectors()

Construct isotropic fourth-rank projectors.

Returns:

  • J (SymmetricTensor4) – Volumetric projector \(\mathbb{J}_{ijkl} = \frac{1}{3}\delta_{ij}\delta_{kl}\).

  • K (SymmetricTensor4) – Deviatoric projector \(\mathbb{K} = \mathbb{I}^s - \mathbb{J}\).

Notes

The projectors satisfy \(\mathbb{J}:\mathbb{J}=\mathbb{J}\), \(\mathbb{K}:\mathbb{K}=\mathbb{K}\), \(\mathbb{J}:\mathbb{K}=0\), and \(\mathbb{J}+\mathbb{K}=\mathbb{I}^s\).
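The sketch below confirms these identities on dense (3, 3, 3, 3) arrays built directly from the definitions above. `ddot44` and `ddot42` are illustrative helper names, not jaxmat functions. The test matrix A is arbitrary and symmetric.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

d = jnp.eye(3)
Is = 0.5 * (jnp.einsum("ik,jl->ijkl", d, d) + jnp.einsum("il,jk->ijkl", d, d))
J = jnp.einsum("ij,kl->ijkl", d, d) / 3.0  # J_ijkl = (1/3) delta_ij delta_kl
K = Is - J                                 # deviatoric projector


def ddot44(A, B):
    """(A : B)_ijmn = A_ijkl B_klmn."""
    return jnp.einsum("ijkl,klmn->ijmn", A, B)


def ddot42(A, x):
    """(A : x)_ij = A_ijkl x_kl."""
    return jnp.einsum("ijkl,kl->ij", A, x)


A = jnp.array([[2.0, 1.0, 0.0], [1.0, -1.0, 0.5], [0.0, 0.5, 3.0]])
```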

transverse_isotropic_projectors(axis)

Construct transverse-isotropic (Walpole) fourth-rank projectors.

Parameters:

axis (array_like, shape (3,)) – Unit symmetry axis \(\hat{\mathbf{a}}\).

Returns:

E1, E2, E3, E4, F, G – Six Walpole basis tensors spanning the transverse-isotropic subspace. E1 and E2 are projectors; E3, E4 are their cross terms; F and G are the remaining orthogonal complements.

Return type:

SymmetricTensor4

Notes

The Walpole basis is not orthogonal in the usual sense. Composition follows a \(2\times 2\) matrix rule for E1..E4 and acts coefficient-wise on F and G. Inversion therefore reduces to one \(2\times 2\) matrix inverse plus two scalar reciprocals. See TransverseIsotropicTensor4 for the inversion formula.

axl(A)

Axial vector of a skew-symmetric tensor.

The axial vector \(\mathbf{w}\) associated with \(\mathbf{W} = \operatorname{skw}(\mathbf{A})\) satisfies \(\mathbf{W}\mathbf{v} = \mathbf{w} \times \mathbf{v}\) and is given by \(w_i = -\tfrac{1}{2}\,\varepsilon_{ijk}\,W_{jk}\).

Parameters:

A (Tensor2)

Returns:

Shape (..., 3).

Return type:

jax.Array
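The sketch below implements the formula with an explicit Levi-Civita array. It checks the defining property \(\mathbf{W}\mathbf{v} = \mathbf{w} \times \mathbf{v}\) against `jnp.cross`. `axl_sketch` is a hypothetical stand-in for axl() that works on plain (..., 3, 3) arrays rather than Tensor2 objects.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Levi-Civita symbol eps_ijk: +1 for even permutations, -1 for odd ones.
LEVI = jnp.zeros((3, 3, 3))
for i, j, k in [(0, 1, 2), (1, 2, 0), (2, 0, 1)]:
    LEVI = LEVI.at[i, j, k].set(1.0).at[i, k, j].set(-1.0)


def axl_sketch(A):
    """Hypothetical helper: w_i = -1/2 eps_ijk W_jk with W = skw(A)."""
    W = 0.5 * (A - jnp.swapaxes(A, -1, -2))
    return -0.5 * jnp.einsum("ijk,...jk->...i", LEVI, W)


A = jnp.array([[1.0, -2.0, 4.0], [5.0, 0.0, -1.0], [0.0, 3.0, 2.0]])
W = 0.5 * (A - A.T)
w = axl_sketch(A)
v = jnp.array([0.3, -1.2, 2.0])
```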

dev(A)

Deviatoric part \(\mathbf{A} - \tfrac{1}{3}\operatorname{tr}(\mathbf{A})\,\mathbf{I}\).

Parameters:

A (Tensor2 or SymmetricTensor2)

Return type:

SymmetricTensor2

norm(A)

Frobenius norm \(\|\mathbf{A}\| = \sqrt{\mathbf{A}:\mathbf{A}}\).

For SymmetricTensor2 operands the double contraction is evaluated as a Kelvin dot product with no dense intermediate.

Parameters:

A (Tensor2 or SymmetricTensor2)

Returns:

Scalar (or batch of scalars).

Return type:

jax.Array

polar(F, mode='RU')

Polar decomposition \(\mathbf{F} = \mathbf{R}\mathbf{U}\) or \(\mathbf{F} = \mathbf{V}\mathbf{R}\).

Parameters:
  • F (Tensor2) – Deformation gradient.

  • mode ({"RU", "VR"}, optional) – Selects the right polar decomposition ("RU", default) or the left polar decomposition ("VR").

Returns:

(R, U) for mode="RU" where R is a Tensor2 (rotation) and U a SymmetricTensor2 (right stretch), or (V, R) for mode="VR" where V is a SymmetricTensor2 (left stretch).

Return type:

tuple
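One standard way to obtain both decompositions is the singular value decomposition \(\mathbf{F} = \mathbf{W}\,\boldsymbol{\Sigma}\,\mathbf{V}^{\mathsf{T}}\). It gives \(\mathbf{R} = \mathbf{W}\mathbf{V}^{\mathsf{T}}\), \(\mathbf{U} = \mathbf{V}\boldsymbol{\Sigma}\mathbf{V}^{\mathsf{T}}\) and \(\mathbf{V}_\text{left} = \mathbf{W}\boldsymbol{\Sigma}\mathbf{W}^{\mathsf{T}}\). The sketch below uses this SVD route as a swapped-in technique for illustration; polar() is not necessarily implemented this way. It also assumes \(\det\mathbf{F} > 0\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def polar_svd(F):
    """Polar decomposition via the SVD F = Wl diag(s) Vh.

    Illustrative technique; assumes det(F) > 0 so that R is a proper rotation.
    """
    Wl, s, Vh = jnp.linalg.svd(F)
    R = Wl @ Vh                    # rotation
    U = Vh.T @ jnp.diag(s) @ Vh    # right stretch, F = R U
    V = Wl @ jnp.diag(s) @ Wl.T    # left stretch,  F = V R
    return R, U, V


F = jnp.array([[1.2, 0.3, 0.0], [0.1, 0.9, 0.2], [0.0, 0.1, 1.1]])
R, U, V = polar_svd(F)
```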

skw(A)

Skew-symmetric part \((\mathbf{A} - \mathbf{A}^{\mathsf{T}}) / 2\).

Parameters:

A (Tensor2)

Return type:

Tensor2

stretch_tensor(F)

Right stretch tensor \(\mathbf{U} = (\mathbf{F}^{\mathsf{T}}\mathbf{F})^{1/2}\).

Convenience wrapper around polar().

Parameters:

F (Tensor2) – Deformation gradient.

Return type:

SymmetricTensor2

sym(A)

Symmetric part \((\mathbf{A} + \mathbf{A}^{\mathsf{T}}) / 2\).

Parameters:

A (Tensor2)

Return type:

SymmetricTensor2

tr(A)

Trace \(\operatorname{tr}(\mathbf{A}) = A_{ii}\).

Parameters:

A (Tensor2)

Returns:

Scalar (or batch of scalars).

Return type:

jax.Array

vol(A)

Volumetric (spherical) part \(\tfrac{1}{3}\operatorname{tr}(\mathbf{A})\,\mathbf{I}\).

Complement of dev(): vol(A) + dev(A) == A for symmetric A.

Parameters:

A (Tensor2 or SymmetricTensor2)

Return type:

SymmetricTensor2

von_mises(sig)

Von Mises equivalent stress.

\[\sigma_\text{VM} = \sqrt{\tfrac{3}{2}\,\mathbf{s}:\mathbf{s}}, \qquad \mathbf{s} = \operatorname{dev}(\boldsymbol{\sigma})\]

Parameters:

sig (Tensor2 or SymmetricTensor2) – Cauchy stress tensor.

Returns:

Scalar (or batch of scalars).

Return type:

jax.Array
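The sketch below applies the formula to plain (3, 3) arrays with three sanity checks. Uniaxial stress \(\sigma\) gives \(\sigma_\text{VM} = |\sigma|\). Pure shear \(\tau\) gives \(\sqrt{3}\,\tau\). A superposed hydrostatic stress leaves \(\sigma_\text{VM}\) unchanged. `dev_sketch` and `von_mises_sketch` are hypothetical helpers, not the jaxmat functions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def dev_sketch(sig):
    return sig - jnp.trace(sig) / 3.0 * jnp.eye(3)


def von_mises_sketch(sig):
    s = dev_sketch(sig)
    return jnp.sqrt(1.5 * jnp.sum(s * s))  # sqrt(3/2 s:s)


uniaxial = jnp.diag(jnp.array([200.0, 0.0, 0.0]))                           # sigma_11 = 200
shear = jnp.array([[0.0, 100.0, 0.0], [100.0, 0.0, 0.0], [0.0, 0.0, 0.0]])  # sigma_12 = 100
```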

det33(A)

Determinant \(\det(\mathbf{A})\) of a \(3 \times 3\) matrix.

Evaluated via the explicit Sarrus formula, avoiding jnp.linalg.det overhead for the fixed-size case.

Parameters:

A (array_like, shape (3, 3))

Returns:

Scalar determinant.

Return type:

jax.Array
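For reference, here is the rule of Sarrus written out on the trailing (3, 3) axes, so that leading batch dimensions pass through unchanged. `det33_sketch` is an illustrative re-implementation compared against `jnp.linalg.det`. It is not jaxmat's source.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def det33_sketch(A):
    """Rule of Sarrus on the trailing (3, 3) axes; batch dimensions pass through."""
    a00, a01, a02 = A[..., 0, 0], A[..., 0, 1], A[..., 0, 2]
    a10, a11, a12 = A[..., 1, 0], A[..., 1, 1], A[..., 1, 2]
    a20, a21, a22 = A[..., 2, 0], A[..., 2, 1], A[..., 2, 2]
    return (a00 * a11 * a22 + a01 * a12 * a20 + a02 * a10 * a21
            - a02 * a11 * a20 - a00 * a12 * a21 - a01 * a10 * a22)


A = jnp.array([[2.0, -1.0, 0.5], [1.0, 3.0, 2.0], [0.0, 1.0, 4.0]])
```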

eig33(A, rtol=1e-16)
expm(A)

Matrix exponential \(\exp(\mathbf{A})\) of a symmetric \(3 \times 3\) matrix.

Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.

Parameters:

A (array_like, shape (3, 3))

Returns:

Shape (3, 3).

Return type:

jax.Array

inv_sqrtm(A)

Inverse square root \(\mathbf{A}^{-1/2}\) of a symmetric positive definite \(3 \times 3\) matrix.

Computed jointly with sqrtm() via the same closed-form expression, so both are available at identical cost.

Parameters:

A (array_like, shape (3, 3)) – Symmetric positive definite matrix.

Return type:

Array

Returns:

Shape (3, 3).

References

Simo, J. C., & Hughes, T. J. R. (1998). Computational Inelasticity. Springer. p. 244.

isotropic_function(fun, A)

Isotropic matrix function \(f(\mathbf{A})\) of a symmetric \(3 \times 3\) matrix.

Evaluates the spectral decomposition \(f(\mathbf{A}) = \sum_{i=1}^{3} f(\lambda_i)\,\mathbf{n}_i \otimes \mathbf{n}_i\) where \(\lambda_i\) are the eigenvalues and \(\mathbf{n}_i\) the corresponding eigenvectors of \(\mathbf{A}\).

Parameters:
  • fun (callable) – Scalar function \(f : \mathbb{R} \to \mathbb{R}\) applied to each eigenvalue.

  • A (array_like, shape (3, 3)) – Real symmetric matrix.

Returns:

Shape (3, 3).

Return type:

jax.Array
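The sketch below builds the spectral construction with `jnp.linalg.eigh`, used here as a stand-in for jaxmat's own eigensolver. The derivatives of `eigh` can be ill-defined when eigenvalues coincide, so this sketch is not a drop-in replacement when differentiating through repeated eigenvalues. The checks reproduce what logm(), sqrtm() and powm(A, -1) compute, on an SPD matrix.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def isotropic_function_sketch(fun, A):
    """f(A) = sum_i f(lambda_i) n_i (x) n_i, using jnp.linalg.eigh as a stand-in."""
    lam, N = jnp.linalg.eigh(A)   # columns of N are the eigenvectors n_i
    return (N * fun(lam)) @ N.T   # scale column i by f(lambda_i), then recombine


A = jnp.array([[4.0, 1.0, 0.0], [1.0, 3.0, 0.5], [0.0, 0.5, 2.0]])  # SPD
logA = isotropic_function_sketch(jnp.log, A)
sqrtA = isotropic_function_sketch(jnp.sqrt, A)
```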

logm(A)

Matrix logarithm \(\log(\mathbf{A})\) of a symmetric positive definite \(3 \times 3\) matrix.

Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.

Parameters:

A (array_like, shape (3, 3)) – Symmetric positive definite matrix.

Returns:

Shape (3, 3).

Return type:

jax.Array

main_invariants(A)

Main (trace-power) invariants \((J_1, J_2, J_3)\) of a \(3 \times 3\) matrix.

\[J_k = \operatorname{tr}(\mathbf{A}^k), \quad k = 1, 2, 3\]

Parameters:

A (array_like, shape (3, 3))

Returns:

J1, J2, J3 – Three scalar invariants.

Return type:

jax.Array

powm(A, m)

Matrix power \(\mathbf{A}^m\) of a symmetric \(3 \times 3\) matrix.

Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.

Parameters:
  • A (array_like, shape (3, 3))

  • m (float) – Exponent.

Returns:

Shape (3, 3).

Return type:

jax.Array

pq_invariants(sig)

Hydrostatic pressure \(p\) and deviatoric equivalent stress \(q\).

Commonly used in soil mechanics and pressure-sensitive plasticity.

\[p = -\tfrac{1}{3}\operatorname{tr}(\boldsymbol{\sigma}), \qquad q = \sqrt{\tfrac{3}{2}\,\mathbf{s}:\mathbf{s}}\]

where \(\mathbf{s} = \operatorname{dev}(\boldsymbol{\sigma})\).

Parameters:

sig (array_like, shape (3, 3)) – Cauchy stress tensor.

Return type:

tuple[Array, Array]

Returns:

  • p (jax.Array) – Mean pressure (positive in compression).

  • q (jax.Array) – Von Mises equivalent stress.
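The sketch below computes both quantities under the sign convention implied above (tension positive, so \(p > 0\) in compression). For a conventional triaxial compression state, the deviatoric stress reduces to \(q = |\sigma_\text{axial} - \sigma_\text{confining}|\). `pq_sketch` is a hypothetical helper operating on plain arrays.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def pq_sketch(sig):
    p = -jnp.trace(sig) / 3.0   # mean pressure, positive in compression
    s = sig + p * jnp.eye(3)    # dev(sig) = sig - tr(sig)/3 I
    return p, jnp.sqrt(1.5 * jnp.sum(s * s))


# Conventional triaxial compression: axial -100, confining -50 (tension positive).
sig = jnp.diag(jnp.array([-100.0, -50.0, -50.0]))
p, q = pq_sketch(sig)
```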

principal_invariants(A)

Principal invariants \((I_1, I_2, I_3)\) of a \(3 \times 3\) matrix \(\mathbf{A}\).

\[I_1 = \operatorname{tr}(\mathbf{A}), \quad I_2 = \tfrac{1}{2}\bigl(\operatorname{tr}(\mathbf{A})^2 - \operatorname{tr}(\mathbf{A}^2)\bigr), \quad I_3 = \det(\mathbf{A})\]

Parameters:

A (array_like, shape (3, 3))

Returns:

I1, I2, I3 – Three scalar invariants.

Return type:

jax.Array
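The two sets of invariants are related by Newton's identities: \(I_1 = J_1\), \(I_2 = \tfrac{1}{2}(J_1^2 - J_2)\) and \(I_3 = \tfrac{1}{6}(J_1^3 - 3J_1J_2 + 2J_3)\). The principal invariants also satisfy the Cayley-Hamilton theorem \(\mathbf{A}^3 - I_1\mathbf{A}^2 + I_2\mathbf{A} - I_3\mathbf{I} = \mathbf{0}\). The sketch below checks both on a sample matrix, using hypothetical helpers that mirror main_invariants() and principal_invariants().

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def main_invariants_sketch(A):
    A2 = A @ A
    return jnp.trace(A), jnp.trace(A2), jnp.trace(A2 @ A)  # J1, J2, J3


def principal_invariants_sketch(A):
    I1 = jnp.trace(A)
    I2 = 0.5 * (I1**2 - jnp.trace(A @ A))
    I3 = jnp.linalg.det(A)
    return I1, I2, I3


A = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]])
J1, J2, J3 = main_invariants_sketch(A)
I1, I2, I3 = principal_invariants_sketch(A)
```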

sqrtm(A)

Matrix square root \(\mathbf{A}^{1/2}\) of a symmetric positive definite \(3 \times 3\) matrix.

Uses the closed-form expression of Simo & Hughes (1998) based on the principal invariants of \(\mathbf{A}^{1/2}\). Accepts any object that implements __jax_array__ (e.g. a SymmetricTensor2).

Parameters:

A (array_like, shape (3, 3)) – Symmetric positive definite matrix.

Return type:

Array

Returns:

Shape (3, 3).

References

Simo, J. C., & Hughes, T. J. R. (1998). Computational Inelasticity. Springer. p. 244.

safe_fun(fun, x, norm=None, eps=1e-16)

Apply a function safely, avoiding evaluation at or near zero.

The input x is replaced by a small positive sentinel eps whenever norm(x) <= eps before calling fun. The final result is then masked to return zero in that case. This sentinel strategy ensures that fun is always evaluated at a numerically safe point, so that gradients through fun remain finite under automatic differentiation.

This is consistent with safe_sqrt(): safe_fun(jnp.sqrt, x) produces the same values and gradients as safe_sqrt(x).

Parameters:
  • fun (callable) – Scalar or array function to apply safely.

  • x (array_like) – Input value.

  • norm (callable, optional) – Scalar-valued function of x used to test proximity to zero. Defaults to the identity (i.e. x itself is the magnitude).

  • eps (float, optional) – Threshold below which x is considered zero. Defaults to 1e-16.

Returns:

fun(x) where norm(x) > eps, otherwise 0.

Return type:

jax.Array

Notes

The key property is that the sentinel-substituted input eps (not 0) is passed to fun in the masked branch. This prevents jax.grad from encountering undefined derivatives (e.g. 1 / (2 sqrt(0)) for fun = jnp.sqrt).
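The sentinel-and-mask strategy (often called the "double where" pattern in JAX) can be sketched as below. `safe_fun_sketch` is a hypothetical re-implementation written from the description above, not jaxmat's source. For contrast, `naive` masks only the output. Its gradient at 0 is NaN, because the zero cotangent from `jnp.where` is multiplied by the infinite derivative of `jnp.sqrt` at 0.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def safe_fun_sketch(fun, x, norm=None, eps=1e-16):
    """Sentinel substitution followed by masking, as described above."""
    magnitude = x if norm is None else norm(x)
    small = magnitude <= eps
    x_safe = jnp.where(small, eps, x)          # fun never sees 0 ...
    return jnp.where(small, 0.0, fun(x_safe))  # ... and the masked result is 0


def naive(x):
    return jnp.where(x > 0.0, jnp.sqrt(x), 0.0)  # masks the value only


g_safe = jax.grad(lambda x: safe_fun_sketch(jnp.sqrt, x))(0.0)
g_naive = jax.grad(naive)(0.0)
```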

safe_norm(x, eps=1e-16, **kwargs)

Wrapper around optax.safe_norm that computes a numerically stable norm.

This function prevents numerical instability when computing vector norms for small magnitudes by internally applying a stability threshold.

Parameters:
  • x (array-like) – Input vector or tensor.

  • eps (float, optional) – Small constant added for numerical stability. Defaults to 1e-16.

  • **kwargs – Additional arguments passed to optax.safe_norm.

Returns:

The numerically stable norm of x.

Return type:

array-like

safe_sqrt(x, eps=1e-16)

Computes a numerically safe square root.

Ensures the argument to the square root is greater than eps to avoid taking the square root of zero or negative values, which could cause instability or NaNs.

Parameters:
  • x (array-like) – Input array or tensor.

  • eps (float, optional) – Minimum threshold for x before taking the square root. Defaults to 1e-16.

Returns:

The square root of x for x > eps, otherwise eps.

Return type:

array-like