jaxmat.tensors.linear_algebra module

jaxmat/tensors/linear_algebra.py

Linear algebra operations on plain JAX arrays.

All functions operate on jax.Array objects of shape (3, 3) (or batched equivalents). Because tensor wrapper classes implement __jax_array__, any function here accepts a Tensor2 or SymmetricTensor2 directly — no explicit conversion is needed.

Public API

det33, inv33, eig33, principal_invariants, main_invariants, pq_invariants, isotropic_function, sqrtm, inv_sqrtm, expm, logm, powm

Private (internal helpers)

_dim, _tr, _dev, _sqrtm

det33(A)

Determinant \(\det(\mathbf{A})\) of a \(3 \times 3\) matrix.

Evaluated via the explicit Sarrus formula, avoiding jnp.linalg.det overhead for the fixed-size case.

Parameters:

A (array_like, shape (3, 3))

Returns:

Scalar determinant.

Return type:

jax.Array
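To illustrate the Sarrus expansion, here is a minimal stand-alone sketch. The name `det33_sketch` is hypothetical and this is not the library's source; it writes out the six triple products and can be checked against `jnp.linalg.det`.

```python
import jax.numpy as jnp


def det33_sketch(A):
    """Determinant of a 3x3 matrix via the Sarrus rule (illustrative sketch)."""
    A = jnp.asarray(A)
    # Three "down-right" diagonals minus three "down-left" diagonals.
    return (
        A[0, 0] * A[1, 1] * A[2, 2]
        + A[0, 1] * A[1, 2] * A[2, 0]
        + A[0, 2] * A[1, 0] * A[2, 1]
        - A[0, 2] * A[1, 1] * A[2, 0]
        - A[0, 0] * A[1, 2] * A[2, 1]
        - A[0, 1] * A[1, 0] * A[2, 2]
    )


A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])
print(det33_sketch(A))  # 18.0
```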

inv33(A)

Inverse \(\mathbf{A}^{-1}\) of a \(3 \times 3\) matrix.

Computed via the explicit cofactor formula (adjugate divided by determinant), avoiding jnp.linalg.solve overhead for the fixed-size case.

Parameters:

A (array_like, shape (3, 3))

Returns:

Shape (3, 3).

Return type:

jax.Array
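A minimal sketch of the adjugate-over-determinant formula, under the hypothetical name `inv33_sketch` (not the library's code). It relies on the identity that the rows of the cofactor matrix are cross products of pairs of rows of \(\mathbf{A}\), which is one compact way to write the explicit cofactor formula.

```python
import jax.numpy as jnp


def inv33_sketch(A):
    """Inverse of a 3x3 matrix as adjugate / determinant (illustrative sketch)."""
    A = jnp.asarray(A)
    a0, a1, a2 = A[0], A[1], A[2]
    # Rows of the cofactor matrix are cross products of pairs of rows of A.
    cof = jnp.stack([jnp.cross(a1, a2), jnp.cross(a2, a0), jnp.cross(a0, a1)])
    # The triple product a0 . (a1 x a2) is the determinant.
    det = jnp.dot(a0, cof[0])
    # Adjugate = transpose of the cofactor matrix.
    return cof.T / det


A = jnp.array([[1.0, 2.0, 3.0],
               [0.0, 1.0, 4.0],
               [5.0, 6.0, 0.0]])
Ainv = inv33_sketch(A)
print(A @ Ainv)  # identity up to rounding
```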

principal_invariants(A)

Principal invariants \((I_1, I_2, I_3)\) of a \(3 \times 3\) matrix \(\mathbf{A}\).

\[I_1 = \operatorname{tr}(\mathbf{A}), \quad I_2 = \tfrac{1}{2}\bigl(\operatorname{tr}(\mathbf{A})^2 - \operatorname{tr}(\mathbf{A}^2)\bigr), \quad I_3 = \det(\mathbf{A})\]

Parameters:

A (array_like, shape (3, 3))

Returns:

I1, I2, I3 – Three scalar invariants.

Return type:

jax.Array
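The three trace formulas translate directly into JAX. This is a sketch with the hypothetical name `principal_invariants_sketch`; it uses `jnp.linalg.det` for \(I_3\) instead of the module's own `det33` so that it runs without jaxmat installed.

```python
import jax.numpy as jnp


def principal_invariants_sketch(A):
    """(I1, I2, I3) of a 3x3 matrix from the trace formulas (illustrative sketch)."""
    A = jnp.asarray(A)
    tr = jnp.trace(A)
    I1 = tr
    I2 = 0.5 * (tr**2 - jnp.trace(A @ A))
    I3 = jnp.linalg.det(A)
    return I1, I2, I3


# For a diagonal matrix the invariants reduce to the elementary symmetric
# polynomials of the diagonal entries (which are also its eigenvalues).
A = jnp.diag(jnp.array([1.0, 2.0, 3.0]))
I1, I2, I3 = principal_invariants_sketch(A)
print(I1, I2, I3)  # I1 = 6, I2 = 11, I3 = 6
```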

main_invariants(A)

Main (trace-power) invariants \((J_1, J_2, J_3)\) of a \(3 \times 3\) matrix.

\[J_k = \operatorname{tr}(\mathbf{A}^k), \quad k = 1, 2, 3\]

Parameters:

A (array_like, shape (3, 3))

Returns:

J1, J2, J3 – Three scalar invariants.

Return type:

jax.Array
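A sketch of the trace-power invariants (hypothetical name `main_invariants_sketch`), together with Newton's identities that relate them to the principal invariants: \(I_2 = (J_1^2 - J_2)/2\) and \(I_3 = (J_1^3 - 3 J_1 J_2 + 2 J_3)/6\).

```python
import jax.numpy as jnp


def main_invariants_sketch(A):
    """(J1, J2, J3) = (tr A, tr A^2, tr A^3) (illustrative sketch)."""
    A = jnp.asarray(A)
    A2 = A @ A
    return jnp.trace(A), jnp.trace(A2), jnp.trace(A2 @ A)


A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])
J1, J2, J3 = main_invariants_sketch(A)

# Newton's identities recover the principal invariants from the J_k.
I2 = 0.5 * (J1**2 - J2)
I3 = (J1**3 - 3.0 * J1 * J2 + 2.0 * J3) / 6.0
print(J1, J2, J3, I2, I3)  # J = (9, 33, 135), I2 = 24, I3 = det(A) = 18
```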

pq_invariants(sig)

Hydrostatic pressure \(p\) and deviatoric equivalent stress \(q\).

Commonly used in soil mechanics and pressure-sensitive plasticity.

\[p = -\tfrac{1}{3}\operatorname{tr}(\boldsymbol{\sigma}), \qquad q = \sqrt{\tfrac{3}{2}\,\mathbf{s}:\mathbf{s}}\]

where \(\mathbf{s} = \operatorname{dev}(\boldsymbol{\sigma})\).

Parameters:

sig (array_like, shape (3, 3)) – Cauchy stress tensor.

Return type:

tuple[Array, Array]

Returns:

  • p (jax.Array) – Mean pressure (positive in compression).

  • q (jax.Array) – Von Mises equivalent stress.
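The definitions of \(p\) and \(q\) can be checked on a uniaxial stress state \(\boldsymbol{\sigma} = \operatorname{diag}(\sigma_0, 0, 0)\), for which \(p = -\sigma_0/3\) and \(q = |\sigma_0|\). The sketch below (hypothetical name `pq_invariants_sketch`, not the library's implementation) forms the deviator explicitly and evaluates \(\mathbf{s}:\mathbf{s}\) as an element-wise sum.

```python
import jax.numpy as jnp


def pq_invariants_sketch(sig):
    """Mean pressure p (positive in compression) and von Mises stress q (sketch)."""
    sig = jnp.asarray(sig)
    tr = jnp.trace(sig)
    p = -tr / 3.0
    s = sig - (tr / 3.0) * jnp.eye(3)    # deviatoric part dev(sig)
    q = jnp.sqrt(1.5 * jnp.sum(s * s))   # s : s as an element-wise sum
    return p, q


# Uniaxial tension of magnitude 100: p = -100/3, q = 100.
sig = jnp.diag(jnp.array([100.0, 0.0, 0.0]))
p, q = pq_invariants_sketch(sig)
```

Since `jnp.sqrt` has an infinite derivative at zero, differentiating this \(q\) at a purely hydrostatic state would need a safeguard; the sketch does not add one.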

eig33_HA(A, rtol=1e-16)

Eigenvalues and eigenvalue dyads of a \(3 \times 3\) real symmetric matrix.

Implements the numerically stable method of Harari & Albocher (2023), which avoids catastrophic cancellation when two or more eigenvalues are nearly equal. Eigenvalue dyads (derivatives of eigenvalues with respect to \(\mathbf{A}\)) are obtained via jax.jacfwd.

Parameters:
  • A (array_like, shape (3, 3)) – Real symmetric matrix.

  • rtol (float, optional) – Relative tolerance for the near-isotropic and two-equal-eigenvalue branches. Defaults to 1e-16.

Return type:

tuple[Array, Array]

Returns:

  • eigvals (jax.Array, shape (3,)) – Eigenvalues in ascending order.

  • eigendyads (jax.Array, shape (3, 3, 3)) – Rank-1 projectors \(\mathbf{n}_i \otimes \mathbf{n}_i\) for each eigenvalue, forming the derivative \(\partial \lambda_i / \partial \mathbf{A}\).

Notes

The input must be symmetric; asymmetric components are silently ignored by the algorithm.

References

Harari, I., & Albocher, U. (2023). Computation of eigenvalues of a real, symmetric 3x3 matrix with particular reference to the pernicious case of two nearly equal eigenvalues. International Journal for Numerical Methods in Engineering, 124(5), 1089-1110.
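The Harari & Albocher algorithm itself is not reproduced here. The sketch below only illustrates the second half of the description: obtaining the eigenvalue dyads as the forward-mode Jacobian of the eigenvalues. It uses `jnp.linalg.eigvalsh` as a stand-in eigenvalue solver (an assumption, not the library's method), and the name `eigvals_and_dyads_sketch` is hypothetical. For distinct eigenvalues, the Jacobian should match \(\mathbf{n}_i \otimes \mathbf{n}_i\) built from the eigenvectors returned by `jnp.linalg.eigh`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def eigvals_and_dyads_sketch(A):
    """Ascending eigenvalues and d(lambda_i)/dA via jax.jacfwd (illustrative sketch)."""
    A = jnp.asarray(A)
    eigvals = jnp.linalg.eigvalsh(A)
    # Jacobian of the 3 eigenvalues w.r.t. the 3x3 entries: shape (3, 3, 3).
    eigendyads = jax.jacfwd(jnp.linalg.eigvalsh)(A)
    return eigvals, eigendyads


A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])
lam, N = eigvals_and_dyads_sketch(A)   # eigenvalues 3 - sqrt(3), 3, 3 + sqrt(3)

# Cross-check: N[i] should equal n_i (x) n_i, eigenvectors being columns of vecs.
_, vecs = jnp.linalg.eigh(A)
N_ref = jnp.einsum("ki,li->ikl", vecs, vecs)
```

This stand-in does nothing special for nearly equal eigenvalues, which is precisely the case the Harari & Albocher method is designed to handle.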

eig33(A, rtol=1e-16)

isotropic_function(fun, A)

Isotropic matrix function \(f(\mathbf{A})\) of a symmetric \(3 \times 3\) matrix.

Evaluated via the spectral decomposition \(f(\mathbf{A}) = \sum_{i=1}^{3} f(\lambda_i)\,\mathbf{n}_i \otimes \mathbf{n}_i\), where \(\lambda_i\) are the eigenvalues of \(\mathbf{A}\) and \(\mathbf{n}_i\) the corresponding unit eigenvectors.

Parameters:
  • fun (callable) – Scalar function \(f : \mathbb{R} \to \mathbb{R}\) applied to each eigenvalue.

  • A (array_like, shape (3, 3)) – Real symmetric matrix.

Returns:

Shape (3, 3).

Return type:

jax.Array
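A minimal sketch of the spectral formula, assuming `jnp.linalg.eigh` as the eigensolver (the library's own eigensolver is not shown here) and using the hypothetical name `isotropic_function_sketch`. Writing the sum as \(\mathbf{V}\,\operatorname{diag}\bigl(f(\lambda_i)\bigr)\,\mathbf{V}^\mathsf{T}\), with the eigenvectors as the columns of \(\mathbf{V}\), avoids building the three dyads explicitly.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def isotropic_function_sketch(fun, A):
    """f(A) = sum_i f(lam_i) n_i (x) n_i for symmetric A (illustrative sketch)."""
    lam, vecs = jnp.linalg.eigh(jnp.asarray(A))
    # Scale each eigenvector column by f(lam_i), then recombine: V diag(f) V^T.
    return (vecs * fun(lam)) @ vecs.T


A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])
# With f(x) = x**2 the result must equal A @ A.
A2 = isotropic_function_sketch(lambda x: x**2, A)
```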

sqrtm(A)

Matrix square root \(\mathbf{A}^{1/2}\) of a symmetric positive definite \(3 \times 3\) matrix.

Uses the closed-form expression of Simo & Hughes (1998) based on the principal invariants of \(\mathbf{A}^{1/2}\). Accepts any object that implements __jax_array__ (e.g. a SymmetricTensor2).

Parameters:

A (array_like, shape (3, 3)) – Symmetric positive definite matrix.

Return type:

Array

Returns:

jax.Array – Shape (3, 3).

References

Simo, J. C., & Hughes, T. J. R. (1998). Computational Inelasticity. Springer. p. 244.
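A closed form of this type can be derived from the Cayley-Hamilton theorem. The derivation below is a sketch, not a transcription of the reference: it assumes the eigenvalues \(\lambda_i > 0\) of \(\mathbf{U} = \mathbf{A}^{1/2}\) are already known (e.g. as square roots of the eigenvalues of \(\mathbf{A}\)), writes \(I_U, II_U, III_U\) for the principal invariants of \(\mathbf{U}\), and uses \(\mathbf{U}^2 = \mathbf{A}\). Multiplying the theorem by \(\mathbf{U}\) once more and substituting the expression for \(\mathbf{U}\mathbf{A}\) eliminates it:

\[
\begin{aligned}
\mathbf{0} &= \mathbf{U}^3 - I_U\,\mathbf{U}^2 + II_U\,\mathbf{U} - III_U\,\mathbf{1}
\quad\Longrightarrow\quad \mathbf{U}\mathbf{A} = I_U\,\mathbf{A} - II_U\,\mathbf{U} + III_U\,\mathbf{1}, \\
\mathbf{0} &= \mathbf{A}^2 - I_U\,\mathbf{U}\mathbf{A} + II_U\,\mathbf{A} - III_U\,\mathbf{U}
= \mathbf{A}^2 - (I_U^2 - II_U)\,\mathbf{A} - I_U\,III_U\,\mathbf{1} + (I_U\,II_U - III_U)\,\mathbf{U}, \\
\mathbf{U} &= \frac{-\mathbf{A}^2 + (I_U^2 - II_U)\,\mathbf{A} + I_U\,III_U\,\mathbf{1}}{I_U\,II_U - III_U},
\qquad
\mathbf{U}^{-1} = \frac{\mathbf{A} - I_U\,\mathbf{U} + II_U\,\mathbf{1}}{III_U}.
\end{aligned}
\]

The denominator \(I_U\,II_U - III_U = (\lambda_1+\lambda_2)(\lambda_2+\lambda_3)(\lambda_3+\lambda_1)\) is strictly positive for a positive definite \(\mathbf{A}\), and no eigenvalue differences appear, so repeated eigenvalues need no special treatment.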

inv_sqrtm(A)

Inverse square root \(\mathbf{A}^{-1/2}\) of a symmetric positive definite \(3 \times 3\) matrix.

Computed jointly with sqrtm() via the same closed-form expression, so both are available at identical cost.

Parameters:

A (array_like, shape (3, 3)) – Symmetric positive definite matrix.

Return type:

Array

Returns:

jax.Array – Shape (3, 3).

References

Simo, J. C., & Hughes, T. J. R. (1998). Computational Inelasticity. Springer. p. 244.
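A minimal sketch of computing both roots jointly. It evaluates \(\mathbf{U} = [-\mathbf{A}^2 + (I_U^2 - II_U)\,\mathbf{A} + I_U\,III_U\,\mathbf{1}]/(I_U\,II_U - III_U)\) and \(\mathbf{U}^{-1} = (\mathbf{A} - I_U\,\mathbf{U} + II_U\,\mathbf{1})/III_U\), where \(I_U, II_U, III_U\) are the principal invariants of \(\mathbf{U} = \mathbf{A}^{1/2}\). It assumes the eigenvalues of \(\mathbf{A}\) come from `jnp.linalg.eigvalsh`; the library's own route to these invariants may differ, and the name `sqrtm_inv_sqrtm_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def sqrtm_inv_sqrtm_sketch(A):
    """U = A^(1/2) and U^(-1) for SPD 3x3 A via Cayley-Hamilton (illustrative sketch)."""
    A = jnp.asarray(A)
    eye = jnp.eye(3, dtype=A.dtype)
    # Assumption: eigenvalues of A from eigvalsh; those of U are their square roots.
    lam = jnp.sqrt(jnp.linalg.eigvalsh(A))
    I_U = jnp.sum(lam)
    II_U = lam[0] * lam[1] + lam[1] * lam[2] + lam[2] * lam[0]
    III_U = jnp.prod(lam)
    A2 = A @ A
    U = (-A2 + (I_U**2 - II_U) * A + I_U * III_U * eye) / (I_U * II_U - III_U)
    # U^(-1) reuses U and the same invariants, so it comes at almost no extra cost.
    U_inv = (A - I_U * U + II_U * eye) / III_U
    return U, U_inv


A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 0.5],
               [0.0, 0.5, 2.0]])
U, U_inv = sqrtm_inv_sqrtm_sketch(A)
```

Because the expression divides only by sums of eigenvalues, it stays well-defined for repeated eigenvalues: for example, `sqrtm_inv_sqrtm_sketch(4.0 * jnp.eye(3))` returns \(\mathbf{U} = 2\,\mathbf{1}\).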

expm(A)

Matrix exponential \(\exp(\mathbf{A})\) of a symmetric \(3 \times 3\) matrix.

Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.

Parameters:

A (array_like, shape (3, 3))

Returns:

Shape (3, 3).

Return type:

jax.Array
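For a symmetric matrix, the spectral route should agree with a general-purpose matrix exponential. The sketch below (hypothetical name `expm_sym_sketch`, with `jnp.linalg.eigh` as an assumed eigensolver) compares against `jax.scipy.linalg.expm`, which uses a Padé approximation with scaling and squaring and applies to any square matrix.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

jax.config.update("jax_enable_x64", True)


def expm_sym_sketch(A):
    """exp(A) for symmetric 3x3 A via eigen-decomposition (illustrative sketch)."""
    lam, vecs = jnp.linalg.eigh(jnp.asarray(A))
    return (vecs * jnp.exp(lam)) @ vecs.T


A = jnp.array([[0.2, 0.1, 0.0],
               [0.1, -0.3, 0.05],
               [0.0, 0.05, 0.1]])
E = expm_sym_sketch(A)
E_ref = jax.scipy.linalg.expm(A)  # general-purpose Pade-based reference
```

As a further check, \(\det\exp(\mathbf{A}) = \exp(\operatorname{tr}\mathbf{A})\), which equals 1 here because \(\operatorname{tr}\mathbf{A} = 0\).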

logm(A)

Matrix logarithm \(\log(\mathbf{A})\) of a symmetric positive definite \(3 \times 3\) matrix.

Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.

Parameters:

A (array_like, shape (3, 3)) – Symmetric positive definite matrix.

Returns:

Shape (3, 3).

Return type:

jax.Array
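A sketch of the spectral logarithm (hypothetical name `logm_spd_sketch`, again assuming `jnp.linalg.eigh` as the eigensolver). A round trip through `jax.scipy.linalg.expm` and the identity \(\operatorname{tr}\log\mathbf{A} = \log\det\mathbf{A}\) provide two quick consistency checks.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

jax.config.update("jax_enable_x64", True)


def logm_spd_sketch(A):
    """log(A) for SPD 3x3 A via eigen-decomposition (illustrative sketch)."""
    lam, vecs = jnp.linalg.eigh(jnp.asarray(A))
    # log is real only for positive eigenvalues; a non-SPD input yields NaN here.
    return (vecs * jnp.log(lam)) @ vecs.T


A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 0.5],
               [0.0, 0.5, 2.0]])
L = logm_spd_sketch(A)
# Round trip: exp(log(A)) should recover A.
A_back = jax.scipy.linalg.expm(L)
```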

powm(A, m)

Matrix power \(\mathbf{A}^m\) of a symmetric \(3 \times 3\) matrix.

Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.

Parameters:
  • A (array_like, shape (3, 3))

  • m (float) – Exponent.

Returns:

Shape (3, 3).

Return type:

jax.Array
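A sketch of the spectral power (hypothetical name `powm_sym_sketch`, with `jnp.linalg.eigh` assumed as the eigensolver). For integer \(m\) it reproduces repeated products, \(m = -1\) gives the inverse, and \(m = 1/2\) gives the symmetric positive definite square root. For non-integer \(m\), the eigenvalues must be positive for the result to be real, so \(\mathbf{A}\) should then be positive definite.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def powm_sym_sketch(A, m):
    """A**m for symmetric 3x3 A via eigen-decomposition (illustrative sketch)."""
    lam, vecs = jnp.linalg.eigh(jnp.asarray(A))
    # For non-integer m, negative eigenvalues give NaN: A should then be SPD.
    return (vecs * lam**m) @ vecs.T


A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 0.5],
               [0.0, 0.5, 2.0]])
A_cubed = powm_sym_sketch(A, 3.0)     # equals A @ A @ A
A_half = powm_sym_sketch(A, 0.5)      # SPD square root of A
A_minus1 = powm_sym_sketch(A, -1.0)   # inverse of A
```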