jaxmat.tensors.linear_algebra module
jaxmat/tensors/linear_algebra.py
Linear algebra operations on plain JAX arrays.
All functions operate on jax.Array objects of shape (3, 3) (or batched equivalents). Because the tensor wrapper classes implement __jax_array__, any function here accepts a Tensor2 or SymmetricTensor2 directly; no explicit conversion is needed.
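The wrapper mechanism can be illustrated with a hypothetical stand-in class. In this minimal sketch, WrappedTensor and trace_sketch are illustrative names, not part of jaxmat, and the explicit jnp.asarray call is an assumption about one reliable way to unwrap an object that implements __jax_array__ inside a function.

```python
import jax.numpy as jnp


class WrappedTensor:
    """Hypothetical stand-in for a tensor wrapper class (not jaxmat's Tensor2)."""

    def __init__(self, array):
        self._array = jnp.asarray(array)

    def __jax_array__(self):
        # JAX calls this hook when converting the object to an array.
        return self._array


def trace_sketch(A):
    """Hypothetical helper: accepts a plain array or any __jax_array__ object."""
    A = jnp.asarray(A)  # unwraps WrappedTensor through __jax_array__
    return jnp.trace(A)


D = jnp.diag(jnp.array([1.0, 2.0, 3.0]))
tr_wrapped = trace_sketch(WrappedTensor(D))
tr_plain = trace_sketch(D)
```

Both calls go through the same code path once the object has been converted, which is what makes wrapper and plain-array inputs interchangeable.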
Public API
det33, inv33, eig33, principal_invariants, main_invariants, pq_invariants, isotropic_function, sqrtm, inv_sqrtm, expm, logm, powm
Private (internal helpers)
_dim, _tr, _dev, _sqrtm
- det33(A)
Determinant \(\det(\mathbf{A})\) of a \(3 \times 3\) matrix.
Evaluated via the explicit Sarrus formula, avoiding the overhead of jnp.linalg.det for this fixed-size case.
- Parameters:
A (array_like, shape (3, 3))
- Returns:
Scalar determinant.
- Return type:
jax.Array
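The Sarrus rule expands the determinant into six triple products: three along the main diagonals with a plus sign and three along the anti-diagonals with a minus sign. A minimal sketch in plain jax.numpy; det33_sketch is a hypothetical name, and indexing with A[..., i, j] is an assumption chosen so that the same code also works on a stack of matrices.

```python
import jax.numpy as jnp


def det33_sketch(A):
    """Determinant of a 3x3 matrix (or a stack of them) via the Sarrus rule."""
    A = jnp.asarray(A)
    return (
        A[..., 0, 0] * A[..., 1, 1] * A[..., 2, 2]
        + A[..., 0, 1] * A[..., 1, 2] * A[..., 2, 0]
        + A[..., 0, 2] * A[..., 1, 0] * A[..., 2, 1]
        - A[..., 0, 2] * A[..., 1, 1] * A[..., 2, 0]
        - A[..., 0, 0] * A[..., 1, 2] * A[..., 2, 1]
        - A[..., 0, 1] * A[..., 1, 0] * A[..., 2, 2]
    )


A = jnp.array([[4.0, 1.0, 2.0], [0.0, 3.0, 1.0], [1.0, 0.0, 2.0]])
d = det33_sketch(A)  # only nonzero terms: 24 + 1 - 6 = 19
batch = det33_sketch(jnp.stack([A, 2.0 * A]))  # det(2A) = 8 det(A) for 3x3
```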
- inv33(A)
Inverse \(\mathbf{A}^{-1}\) of a \(3 \times 3\) matrix.
Computed via the explicit cofactor formula (adjugate divided by determinant), avoiding the overhead of jnp.linalg.solve for this fixed-size case.
- Parameters:
A (array_like, shape (3, 3))
- Returns:
Shape (3, 3).
- Return type:
jax.Array
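For a 3 × 3 matrix the adjugate can be written compactly with cross products: if \(\mathbf{c}_0, \mathbf{c}_1, \mathbf{c}_2\) are the columns of \(\mathbf{A}\), the rows of \(\operatorname{adj}(\mathbf{A})\) are \(\mathbf{c}_1 \times \mathbf{c}_2\), \(\mathbf{c}_2 \times \mathbf{c}_0\) and \(\mathbf{c}_0 \times \mathbf{c}_1\), and \(\det(\mathbf{A})\) is the triple product \(\mathbf{c}_0 \cdot (\mathbf{c}_1 \times \mathbf{c}_2)\). A minimal sketch of the cofactor route under that formulation; inv33_sketch is a hypothetical name and it includes no guard against a zero determinant.

```python
import jax.numpy as jnp


def inv33_sketch(A):
    """Inverse of a 3x3 matrix (or a stack of them) via the adjugate."""
    A = jnp.asarray(A)
    c0, c1, c2 = A[..., :, 0], A[..., :, 1], A[..., :, 2]  # columns of A
    # Rows of the adjugate are cross products of pairs of columns.
    adj = jnp.stack(
        [jnp.cross(c1, c2), jnp.cross(c2, c0), jnp.cross(c0, c1)], axis=-2
    )
    det = jnp.sum(c0 * jnp.cross(c1, c2), axis=-1)  # scalar triple product
    return adj / det[..., None, None]


A = jnp.array([[4.0, 1.0, 2.0], [0.0, 3.0, 1.0], [1.0, 0.0, 2.0]])
Ainv = inv33_sketch(A)
```

Each row of the adjugate is orthogonal to two of the columns and has dot product \(\det(\mathbf{A})\) with the third, which is exactly the condition \(\mathbf{A}^{-1}\mathbf{A} = \mathbf{I}\).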
- principal_invariants(A)
Principal invariants \((I_1, I_2, I_3)\) of a \(3 \times 3\) matrix \(\mathbf{A}\).
\[I_1 = \operatorname{tr}(\mathbf{A}), \quad I_2 = \tfrac{1}{2}\bigl(\operatorname{tr}(\mathbf{A})^2 - \operatorname{tr}(\mathbf{A}^2)\bigr), \quad I_3 = \det(\mathbf{A})\]
- Parameters:
A (array_like, shape (3, 3))
- Returns:
I1, I2, I3 – Three scalar invariants.
- Return type:
jax.Array
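A direct transcription of the three formulas, as a minimal sketch: principal_invariants_sketch is a hypothetical name, and \(I_3\) is taken from jnp.linalg.det rather than a Sarrus expansion. For the non-symmetric test matrix below, \(I_2\) can be cross-checked by hand as the sum of the principal 2 × 2 minors, 12 + 6 + 6 = 24.

```python
import jax.numpy as jnp


def principal_invariants_sketch(A):
    """(I1, I2, I3) of a 3x3 matrix from traces and the determinant."""
    A = jnp.asarray(A)
    tr = jnp.trace(A)
    I1 = tr
    I2 = 0.5 * (tr**2 - jnp.trace(A @ A))
    I3 = jnp.linalg.det(A)
    return I1, I2, I3


A = jnp.array([[4.0, 1.0, 2.0], [0.0, 3.0, 1.0], [1.0, 0.0, 2.0]])
I1, I2, I3 = principal_invariants_sketch(A)  # expected: 9, 24, 19
```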
- main_invariants(A)
Main (trace-power) invariants \((J_1, J_2, J_3)\) of a \(3 \times 3\) matrix.
\[J_k = \operatorname{tr}(\mathbf{A}^k), \quad k = 1, 2, 3\]
- Parameters:
A (array_like, shape (3, 3))
- Returns:
J1, J2, J3 – Three scalar invariants.
- Return type:
jax.Array
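The trace powers carry the same information as the principal invariants; Newton's identities give \(J_1 = I_1\), \(J_2 = I_1^2 - 2 I_2\) and \(J_3 = I_1^3 - 3 I_1 I_2 + 3 I_3\). A minimal sketch (main_invariants_sketch is a hypothetical name) that computes the traces directly and compares them with these identities for a matrix whose principal invariants are (9, 24, 19).

```python
import jax.numpy as jnp


def main_invariants_sketch(A):
    """J_k = tr(A^k) for k = 1, 2, 3."""
    A = jnp.asarray(A)
    A2 = A @ A
    return jnp.trace(A), jnp.trace(A2), jnp.trace(A2 @ A)


A = jnp.array([[4.0, 1.0, 2.0], [0.0, 3.0, 1.0], [1.0, 0.0, 2.0]])
J1, J2, J3 = main_invariants_sketch(A)

# Newton's identities with I1, I2, I3 = 9, 24, 19 for this matrix
I1, I2, I3 = 9.0, 24.0, 19.0
J2_newton = I1**2 - 2 * I2                # 81 - 48 = 33
J3_newton = I1**3 - 3 * I1 * I2 + 3 * I3  # 729 - 648 + 57 = 138
```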
- pq_invariants(sig)
Hydrostatic pressure \(p\) and deviatoric equivalent stress \(q\).
Commonly used in soil mechanics and pressure-sensitive plasticity.
\[p = -\tfrac{1}{3}\operatorname{tr}(\boldsymbol{\sigma}), \qquad q = \sqrt{\tfrac{3}{2}\,\mathbf{s}:\mathbf{s}}\]
where \(\mathbf{s} = \operatorname{dev}(\boldsymbol{\sigma})\) is the deviatoric stress.
- Parameters:
sig (array_like, shape (3, 3)) – Cauchy stress tensor.
- Return type:
tuple[Array, Array]
- Returns:
p (jax.Array) – Mean pressure (positive in compression).
q (jax.Array) – Von Mises equivalent stress.
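A minimal sketch of the two formulas (pq_invariants_sketch is a hypothetical name), with two quick sanity checks: under uniaxial tension \(\sigma_{11} = \sigma_0\) one expects \(p = -\sigma_0/3\) and \(q = |\sigma_0|\), and under pure hydrostatic compression \(q = 0\).

```python
import jax.numpy as jnp


def pq_invariants_sketch(sig):
    """Mean pressure p (positive in compression) and von Mises stress q."""
    sig = jnp.asarray(sig)
    tr = jnp.trace(sig)
    p = -tr / 3.0                          # mean pressure
    s = sig - (tr / 3.0) * jnp.eye(3)      # deviatoric stress dev(sig)
    q = jnp.sqrt(1.5 * jnp.sum(s * s))     # s : s is the full double contraction
    return p, q


p_uni, q_uni = pq_invariants_sketch(jnp.diag(jnp.array([300.0, 0.0, 0.0])))
p_hyd, q_hyd = pq_invariants_sketch(-50.0 * jnp.eye(3))
```

The square root has an unbounded derivative at zero, so differentiating this sketch with jax.grad at a purely hydrostatic state yields non-finite values.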
- eig33_HA(A, rtol=1e-16)
Eigenvalues and eigenvalue dyads of a \(3 \times 3\) real symmetric matrix.
Implements the numerically stable method of Harari & Albocher (2023), which avoids catastrophic cancellation when two or more eigenvalues are nearly equal. Eigenvalue dyads (derivatives of eigenvalues with respect to \(\mathbf{A}\)) are obtained via jax.jacfwd.
- Parameters:
A (array_like, shape (3, 3)) – Real symmetric matrix.
rtol (float, optional) – Relative tolerance for the near-isotropic and two-equal-eigenvalue branches. Defaults to 1e-16.
- Return type:
tuple[Array, Array]
- Returns:
eigvals (jax.Array, shape (3,)) – Eigenvalues in ascending order.
eigendyads (jax.Array, shape (3, 3, 3)) – Rank-1 projectors \(\mathbf{n}_i \otimes \mathbf{n}_i\), one per eigenvalue, which give the derivatives \(\partial \lambda_i / \partial \mathbf{A}\).
Notes
The input must be symmetric; any skew-symmetric part of \(\mathbf{A}\) is silently ignored by the algorithm.
References
Harari, I., & Albocher, U. (2023). Computation of eigenvalues of a real, symmetric 3x3 matrix with particular reference to the pernicious case of two nearly equal eigenvalues. International Journal for Numerical Methods in Engineering, 124(5), 1089-1110.
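The relation between eigenvalue dyads and eigenvalue derivatives can be checked numerically. This minimal sketch does NOT implement the Harari & Albocher algorithm: it uses jnp.linalg.eigh as a stand-in (eig_dyads_sketch is a hypothetical name) and compares the dyads \(\mathbf{n}_i \otimes \mathbf{n}_i\) with jax.jacfwd of jnp.linalg.eigvalsh. When two eigenvalues coincide, the individual dyads are no longer uniquely defined (only the projector onto the shared eigenspace is), so the check uses a matrix with distinct eigenvalues.

```python
import jax
import jax.numpy as jnp


def eig_dyads_sketch(A):
    """Eigenvalues and eigenvalue dyads via jnp.linalg.eigh (stand-in solver)."""
    A = jnp.asarray(A)
    eigvals, eigvecs = jnp.linalg.eigh(A)  # ascending eigenvalues; column i is n_i
    # dyads[i, j, k] = n_i[j] * n_i[k], the outer product of n_i with itself
    dyads = jnp.einsum("ji,ki->ijk", eigvecs, eigvecs)
    return eigvals, dyads


# Symmetric tridiagonal matrix with nonzero off-diagonals: distinct eigenvalues.
A = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]])
eigvals, dyads = eig_dyads_sketch(A)

# Derivative of the ascending eigenvalues with respect to A, shape (3, 3, 3).
jac = jax.jacfwd(jnp.linalg.eigvalsh)(A)
```

The dyads also sum to the identity and reconstruct \(\mathbf{A} = \sum_i \lambda_i\,\mathbf{n}_i \otimes \mathbf{n}_i\).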
- isotropic_function(fun, A)
Isotropic matrix function \(f(\mathbf{A})\) of a symmetric \(3 \times 3\) matrix.
Evaluates the spectral decomposition \(f(\mathbf{A}) = \sum_{i=1}^{3} f(\lambda_i)\,\mathbf{n}_i \otimes \mathbf{n}_i\) where \(\lambda_i\) are the eigenvalues and \(\mathbf{n}_i\) the corresponding eigenvectors of \(\mathbf{A}\).
- Parameters:
fun (callable) – Scalar function \(f : \mathbb{R} \to \mathbb{R}\) applied to each eigenvalue.
A (array_like, shape (3, 3)) – Real symmetric matrix.
- Returns:
Shape (3, 3).
- Return type:
jax.Array
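A minimal sketch of the spectral evaluation, assuming an eigendecomposition from jnp.linalg.eigh (jaxmat's own implementation may use a different eigensolver) and assuming fun is vectorized over an array of eigenvalues; isotropic_function_sketch is a hypothetical name. The sum over dyads is written as \(\mathbf{N}\,\operatorname{diag}(f(\boldsymbol{\lambda}))\,\mathbf{N}^\mathsf{T}\), with the eigenvectors as the columns of \(\mathbf{N}\).

```python
import jax.numpy as jnp


def isotropic_function_sketch(fun, A):
    """f(A) = sum_i f(lam_i) n_i n_i^T for a symmetric 3x3 matrix A."""
    A = jnp.asarray(A)
    lam, N = jnp.linalg.eigh(A)  # column i of N is the eigenvector n_i
    # Scale column i by f(lam_i), then multiply by N^T.
    return (N * fun(lam)) @ N.T


A = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]])
same = isotropic_function_sketch(lambda x: x, A)        # f(x) = x recovers A
squared = isotropic_function_sketch(lambda x: x**2, A)  # f(x) = x^2 gives A @ A
```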
- sqrtm(A)
Matrix square root \(\mathbf{A}^{1/2}\) of a symmetric positive definite \(3 \times 3\) matrix.
Uses the closed-form expression of Simo & Hughes (1998) based on the principal invariants of \(\mathbf{A}^{1/2}\). Accepts any object that implements __jax_array__ (e.g. a SymmetricTensor2).
- Parameters:
A (array_like, shape (3, 3)) – Symmetric positive definite matrix.
- Return type:
Array
- Returns:
jax.Array – Shape (3, 3).
References
Simo, J. C., & Hughes, T. J. R. (1998). Computational Inelasticity. Springer. p. 244.
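One closed-form route of this kind follows from the Cayley-Hamilton theorem for \(\mathbf{U} = \mathbf{A}^{1/2}\) together with \(\mathbf{U}^2 = \mathbf{A}\):
\[\mathbf{U} = \frac{-\mathbf{A}^2 + (I_1^2 - I_2)\,\mathbf{A} + I_1 I_3\,\mathbf{I}}{I_1 I_2 - I_3}\]
where \(I_1, I_2, I_3\) are the principal invariants of \(\mathbf{U}\), not of \(\mathbf{A}\). In terms of the eigenvalues \(\mu_i\) of \(\mathbf{U}\), the denominator equals \((\mu_1+\mu_2)(\mu_2+\mu_3)(\mu_3+\mu_1)\), so it is positive for positive definite input. A minimal sketch; as an assumption, the invariants of \(\mathbf{U}\) are obtained here from the square roots of jnp.linalg.eigvalsh(A), which need not match how jaxmat or Simo & Hughes compute them, and sqrtm_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def sqrtm_sketch(A):
    """U = A^(1/2) for symmetric positive definite 3x3 A via a Cayley-Hamilton closed form."""
    A = jnp.asarray(A)
    # Assumption: eigenvalues of U are square roots of eigvalsh(A).
    mu = jnp.sqrt(jnp.linalg.eigvalsh(A))
    I1 = jnp.sum(mu)                                    # tr(U)
    I2 = mu[0] * mu[1] + mu[1] * mu[2] + mu[2] * mu[0]  # second invariant of U
    I3 = jnp.prod(mu)                                   # det(U)
    num = -(A @ A) + (I1**2 - I2) * A + I1 * I3 * jnp.eye(3)
    return num / (I1 * I2 - I3)


U_true = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]])  # SPD
A = U_true @ U_true
U = sqrtm_sketch(A)
```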
- inv_sqrtm(A)
Inverse square root \(\mathbf{A}^{-1/2}\) of a symmetric positive definite \(3 \times 3\) matrix.
Computed jointly with sqrtm() via the same closed-form expression, so both are available at identical cost.
- Parameters:
A (array_like, shape (3, 3)) – Symmetric positive definite matrix.
- Return type:
Array
- Returns:
jax.Array – Shape (3, 3).
References
Simo, J. C., & Hughes, T. J. R. (1998). Computational Inelasticity. Springer. p. 244.
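A short derivation sketch of why the inverse comes almost for free once \(\mathbf{U} = \mathbf{A}^{1/2}\) and its principal invariants \(I_1, I_2, I_3\) are known. This is a standard consequence of the Cayley-Hamilton theorem; that jaxmat evaluates exactly this expression is an assumption. Multiplying the Cayley-Hamilton identity for \(\mathbf{U}\) by \(\mathbf{U}^{-1}\) and using \(\mathbf{U}^2 = \mathbf{A}\):
\[\mathbf{U}^3 - I_1\,\mathbf{U}^2 + I_2\,\mathbf{U} - I_3\,\mathbf{I} = \mathbf{0} \quad\Longrightarrow\quad \mathbf{A}^{-1/2} = \mathbf{U}^{-1} = \frac{1}{I_3}\bigl(\mathbf{A} - I_1\,\mathbf{U} + I_2\,\mathbf{I}\bigr)\]
For example, if \(\mathbf{U}\) has eigenvalues 1, 2, 3, then \(I_1 = 6\), \(I_2 = 11\), \(I_3 = 6\), and the eigenvalue 2 of \(\mathbf{U}\) (eigenvalue 4 of \(\mathbf{A}\)) maps to \((4 - 12 + 11)/6 = 1/2\), as expected.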
- expm(A)
Matrix exponential \(\exp(\mathbf{A})\) of a symmetric \(3 \times 3\) matrix.
Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.
- Parameters:
A (array_like, shape (3, 3))
- Returns:
Shape (3, 3).
- Return type:
jax.Array
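A minimal spectral sketch (expm_sketch is a hypothetical name; the eigendecomposition comes from jnp.linalg.eigh, an assumption about how the spectrum is obtained), cross-checked against jax.scipy.linalg.expm, which handles general square matrices and serves only as a reference here. For symmetric input the result is symmetric positive definite, since its eigenvalues are \(e^{\lambda_i} > 0\).

```python
import jax.numpy as jnp
import jax.scipy.linalg


def expm_sketch(A):
    """exp(A) = sum_i exp(lam_i) n_i n_i^T for a symmetric 3x3 matrix A."""
    A = jnp.asarray(A)
    lam, N = jnp.linalg.eigh(A)
    return (N * jnp.exp(lam)) @ N.T


A = jnp.array([[0.5, 0.1, 0.0], [0.1, 0.2, 0.3], [0.0, 0.3, -0.4]])
E = expm_sketch(A)
E_ref = jax.scipy.linalg.expm(A)  # general-purpose reference
```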
- logm(A)
Matrix logarithm \(\log(\mathbf{A})\) of a symmetric positive definite \(3 \times 3\) matrix.
Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.
- Parameters:
A (array_like, shape (3, 3)) – Symmetric positive definite matrix.
- Returns:
Shape (3, 3).
- Return type:
jax.Array
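A minimal spectral sketch (logm_sketch is a hypothetical name; eigh is an assumed choice of eigensolver). The logarithm is real only if all eigenvalues are positive, which is why positive definiteness is required; with a non-positive eigenvalue, jnp.log returns nan or -inf and the result is meaningless. Because \(\mathbf{A}\) and \(\mathbf{A}^2\) share eigenvectors, \(\log(\mathbf{A}^2) = 2\log(\mathbf{A})\) gives a convenient check.

```python
import jax.numpy as jnp


def logm_sketch(A):
    """log(A) = sum_i log(lam_i) n_i n_i^T for symmetric positive definite A."""
    A = jnp.asarray(A)
    lam, N = jnp.linalg.eigh(A)  # lam must be strictly positive
    return (N * jnp.log(lam)) @ N.T


A = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]])  # SPD
L = logm_sketch(A)
L2 = logm_sketch(A @ A)                # should equal 2 * L
L_e = logm_sketch(jnp.e * jnp.eye(3))  # should equal the identity
```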
- powm(A, m)
Matrix power \(\mathbf{A}^m\) of a symmetric \(3 \times 3\) matrix.
Computed via the spectral decomposition; see isotropic_function(). Accepts any object that implements __jax_array__.
- Parameters:
A (array_like, shape (3, 3))
m (float) – Exponent.
- Returns:
Shape (3, 3).
- Return type:
jax.Array
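A minimal spectral sketch (powm_sketch is a hypothetical name; eigh is an assumed choice of eigensolver). For integer m this agrees with repeated multiplication, and with the inverse for m = -1 when \(\mathbf{A}\) is invertible; for non-integer m the eigenvalue powers are real only when all eigenvalues are positive, and m = 1/2 reproduces the matrix square root.

```python
import jax.numpy as jnp


def powm_sketch(A, m):
    """A^m = sum_i lam_i^m n_i n_i^T for a symmetric 3x3 matrix A."""
    A = jnp.asarray(A)
    lam, N = jnp.linalg.eigh(A)
    return (N * lam**m) @ N.T


A = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]])  # SPD
A_sq = powm_sketch(A, 2.0)
A_inv = powm_sketch(A, -1.0)
A_half = powm_sketch(A, 0.5)
```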