Source code for jaxmat.tensors.linear_algebra

"""
jaxmat/tensors/linear_algebra.py

Linear algebra operations on plain JAX arrays.

All functions operate on ``jax.Array`` objects of shape ``(3, 3)`` (or
batched equivalents).  Because tensor wrapper classes implement
``__jax_array__``, any function here accepts a :class:`~jaxmat.tensors.Tensor2`
or :class:`~jaxmat.tensors.SymmetricTensor2` directly — no explicit
conversion is needed.

Public API
----------
det33, inv33, eig33, eig33_HA,
principal_invariants, main_invariants, pq_invariants,
isotropic_function, sqrtm, inv_sqrtm, expm, logm, powm

Private (internal helpers)
--------------------------
_dim, _tr, _dev, _sqrtm
"""

from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

from .utils import safe_norm, safe_sqrt

# ─────────────────────────────────────────────────────────────────────────────
# Internal helpers
# ─────────────────────────────────────────────────────────────────────────────


def _dim(A) -> int:
    r"""
    Spatial dimension $d$ of a square matrix $\mathbf{A}$.

    The value is the length of the leading axis of ``A`` once converted to
    an array.
    """
    as_array = jnp.asarray(A)
    return as_array.shape[0]


def _tr(A) -> jax.Array:
    r"""
    Trace of $\mathbf{A}$, i.e. the sum of its diagonal entries $A_{ii}$.

    Thin wrapper around ``jnp.trace``.
    """
    diagonal_sum = jnp.trace(A)
    return diagonal_sum


def _dev(A) -> jax.Array:
    r"""
    Deviator of a $d \times d$ matrix $\mathbf{A}$.

    The spherical part $\tfrac{1}{d}\operatorname{tr}(\mathbf{A})\,\mathbf{I}$
    is removed from $\mathbf{A}$:

    .. math:: \operatorname{dev}(\mathbf{A}) = \mathbf{A} -
        \frac{1}{d}\operatorname{tr}(\mathbf{A})\,\mathbf{I}

    The dimension $d$ is obtained from :func:`_dim`.
    """
    n = _dim(A)
    mean_diagonal = _tr(A) / n
    return A - mean_diagonal * jnp.eye(n)


# ─────────────────────────────────────────────────────────────────────────────
# Public array-level operations
# ─────────────────────────────────────────────────────────────────────────────


def det33(A) -> jax.Array:
    r"""
    Determinant $\det(\mathbf{A})$ of a $3 \times 3$ matrix.

    Expands along the first row using the three $2 \times 2$ minors, written
    out explicitly instead of calling ``jnp.linalg.det``.

    Parameters
    ----------
    A : array_like, shape (3, 3)

    Returns
    -------
    jax.Array
        Scalar determinant.
    """
    # 2x2 minors obtained by deleting row 0 and column 0, 1, 2 respectively.
    minor_0 = A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1]
    minor_1 = A[1, 0] * A[2, 2] - A[1, 2] * A[2, 0]
    minor_2 = A[1, 0] * A[2, 1] - A[1, 1] * A[2, 0]
    # Alternating-sign cofactor expansion along the first row.
    return A[0, 0] * minor_0 - A[0, 1] * minor_1 + A[0, 2] * minor_2
def inv33(A) -> jax.Array:
    r"""
    Inverse $\mathbf{A}^{-1}$ of a $3 \times 3$ matrix.

    Builds the adjugate (transposed cofactor matrix) entry by entry and
    divides it by the determinant, instead of calling a general linear
    solver.  No check is made for a vanishing determinant.

    Parameters
    ----------
    A : array_like, shape (3, 3)

    Returns
    -------
    jax.Array
        Shape (3, 3).
    """
    r0c0, r0c1, r0c2 = A[0, 0], A[0, 1], A[0, 2]
    r1c0, r1c1, r1c2 = A[1, 0], A[1, 1], A[1, 2]
    r2c0, r2c1, r2c2 = A[2, 0], A[2, 1], A[2, 2]

    # First-row minors, shared between the determinant and the adjugate.
    minor_0 = r1c1 * r2c2 - r1c2 * r2c1
    minor_1 = r1c0 * r2c2 - r1c2 * r2c0
    minor_2 = r1c0 * r2c1 - r1c1 * r2c0
    determinant = r0c0 * minor_0 - r0c1 * minor_1 + r0c2 * minor_2

    # adjugate[i][j] is the (j, i) cofactor of A.
    adjugate = jnp.array(
        [
            [minor_0, r0c2 * r2c1 - r0c1 * r2c2, r0c1 * r1c2 - r0c2 * r1c1],
            [r1c2 * r2c0 - r1c0 * r2c2, r0c0 * r2c2 - r0c2 * r2c0, r0c2 * r1c0 - r0c0 * r1c2],
            [minor_2, r0c1 * r2c0 - r0c0 * r2c1, r0c0 * r1c1 - r0c1 * r1c0],
        ]
    )
    return adjugate / determinant
def principal_invariants(A) -> tuple[jax.Array, jax.Array, jax.Array]:
    r"""
    Principal invariants $(I_1, I_2, I_3)$ of a $3 \times 3$ matrix $\mathbf{A}$.

    .. math::
        I_1 = \operatorname{tr}(\mathbf{A}), \quad
        I_2 = \tfrac{1}{2}\bigl(\operatorname{tr}(\mathbf{A})^2
              - \operatorname{tr}(\mathbf{A}^2)\bigr), \quad
        I_3 = \det(\mathbf{A})

    The determinant is evaluated with :func:`det33`.

    Parameters
    ----------
    A : array_like, shape (3, 3)

    Returns
    -------
    I1, I2, I3 : jax.Array
        Three scalar invariants.
    """
    trace_A = jnp.trace(A)
    trace_A_squared = jnp.trace(A @ A)
    second = (trace_A**2 - trace_A_squared) / 2
    third = det33(A)
    return trace_A, second, third
def main_invariants(A) -> tuple[jax.Array, jax.Array, jax.Array]:
    r"""
    Trace-power ("main") invariants $(J_1, J_2, J_3)$ of a $3 \times 3$ matrix.

    .. math:: J_k = \operatorname{tr}(\mathbf{A}^k), \quad k = 1, 2, 3

    Parameters
    ----------
    A : array_like, shape (3, 3)

    Returns
    -------
    J1, J2, J3 : jax.Array
        Three scalar invariants.
    """
    # A^2 is formed once and reused for A^3 = (A A) A.
    A_squared = A @ A
    A_cubed = A_squared @ A
    return jnp.trace(A), jnp.trace(A_squared), jnp.trace(A_cubed)
def pq_invariants(sig) -> tuple[jax.Array, jax.Array]:
    r"""
    Mean pressure $p$ and deviatoric equivalent stress $q$ of a stress tensor.

    This pair is the usual choice in soil mechanics and pressure-sensitive
    plasticity.

    .. math::
        p = -\tfrac{1}{3}\operatorname{tr}(\boldsymbol{\sigma}), \qquad
        q = \sqrt{\tfrac{3}{2}\,\mathbf{s}:\mathbf{s}}

    with $\mathbf{s} = \operatorname{dev}(\boldsymbol{\sigma})$.

    Parameters
    ----------
    sig : array_like, shape (3, 3)
        Cauchy stress tensor.

    Returns
    -------
    p : jax.Array
        Mean pressure (positive in compression).
    q : jax.Array
        Von Mises equivalent stress.
    """
    deviator = _dev(sig)
    # Full double contraction s:s, taken over the flattened deviator.
    dev_contraction = jnp.vdot(deviator, deviator)
    equivalent_stress = safe_sqrt(3.0 / 2.0 * dev_contraction)
    pressure = -jnp.trace(sig) / 3
    return pressure, equivalent_stress
def eig33_HA(A, rtol=1e-16) -> tuple[jax.Array, jax.Array]:
    r"""
    Eigenvalues and eigenvalue dyads of a $3 \times 3$ real symmetric matrix.

    Implements the numerically stable method of Harari & Albocher (2023),
    which avoids catastrophic cancellation when two or more eigenvalues are
    nearly equal. Eigenvalue dyads (derivatives of eigenvalues with respect
    to $\mathbf{A}$) are obtained via ``jax.jacfwd``.

    Parameters
    ----------
    A : array_like, shape (3, 3)
        Real symmetric matrix.
    rtol : float, optional
        Tolerance, scaled by the norm of ``A``, used by the two
        branch-selection tests (near-isotropic input, and ratio ``d`` close
        to one). Defaults to ``1e-16``.

    Returns
    -------
    eigvals : jax.Array, shape (3,)
        Eigenvalues in ascending order.
    eigendyads : jax.Array, shape (3, 3, 3)
        Symmetrised derivatives $\partial \lambda_i / \partial \mathbf{A}$,
        ordered like ``eigvals``. For distinct eigenvalues these are the
        rank-1 projectors $\mathbf{n}_i \otimes \mathbf{n}_i$; in the
        near-isotropic branch each of them equals $\mathbf{I}/3$.

    Notes
    -----
    The input must be symmetric. The invariants are formed from
    ``S.T @ S`` and ``S @ S``, which read all nine entries of ``A``, so a
    non-symmetric input is *not* symmetrised first; only the returned
    eigendyads are symmetrised.

    The near-isotropic test is non-strict (``<=``) so that an all-zero input,
    for which both sides of the comparison vanish when ``safe_norm`` returns
    zero, takes the isotropic branch instead of forming a 0/0 ratio in the
    general branch.

    .. admonition:: References
        :class: seealso

        Harari, I., & Albocher, U. (2023). Computation of eigenvalues of a
        real, symmetric 3x3 matrix with particular reference to the pernicious
        case of two nearly equal eigenvalues. *International Journal for
        Numerical Methods in Engineering*, 124(5), 1089-1110.
    """

    def _compute(A):
        # Returns the eigenvalues twice: once as the differentiated output and
        # once as the auxiliary value handed back unchanged by ``jacfwd``.
        A = jnp.asarray(A)
        norm = safe_norm(A)
        Id = jnp.eye(_dim(A))
        I1 = jnp.trace(A)
        S = _dev(A)
        J2 = _tr(S.T @ S) / 2
        s = safe_sqrt(J2 / 3)

        def _near_iso(_):
            # Vanishing deviator: all three eigenvalues equal the mean I1 / 3.
            ev = jnp.ones((3,)) * I1 / 3
            return ev, ev

        def _general(_):
            T = S @ S - 2 * J2 / 3 * Id
            d = safe_norm(T - s * S) / safe_norm(T + s * S)
            sj = jnp.sign(1 - d)
            # sj * (1 - d) == |1 - d|: true when the ratio d is close to one,
            # where sj itself degenerates to zero.
            cond = sj * (1 - d) < rtol * norm

            def _two(_):
                # sqrt(3) * s == sqrt(J2): deviatoric eigenvalues (+lm, 0, -lm).
                lm = jnp.sqrt(3) * s
                ev = jnp.array([lm, 0.0, -lm]) + I1 / 3
                return ev, ev

            def _three(_):
                alpha = 2 / 3 * jnp.arctan2(safe_norm(T - s * S) ** sj, safe_norm(T + s * S) ** sj)
                ld = 2 * sj * s * jnp.cos(alpha)
                sd = jnp.sqrt(3) * s * jnp.sin(alpha)
                ev_dev = jnp.array([-ld / 2 - sd, -ld / 2 + sd, ld])
                return ev_dev + I1 / 3, ev_dev + I1 / 3

            return lax.cond(cond, _two, _three, operand=None)

        # Non-strict comparison: with a zero matrix both sides are zero and the
        # general branch would divide zero by zero.
        return lax.cond(s <= rtol * norm, _near_iso, _general, operand=None)

    eigendyads, eigvals = jax.jacfwd(_compute, has_aux=True)(A)
    # Sort ascending and apply the same permutation to the dyads.
    order = jnp.argsort(eigvals)
    eigvals = eigvals[order]
    eigendyads = eigendyads[order]
    # jacfwd perturbs A_ij and A_ji independently; keep the symmetric part.
    eigendyads = 0.5 * (eigendyads + jnp.swapaxes(eigendyads, -1, -2))
    return eigvals, eigendyads
@partial(jax.jit, static_argnums=1)
def eig33(A, rtol=1e-16) -> tuple[jax.Array, jax.Array]:
    r"""
    Eigenvalues and eigenvalue dyads of a $3 \times 3$ real symmetric matrix.

    The eigenvalues are evaluated with a closed-form trigonometric expression
    built from the trace, the second and third invariants of the deviator and
    a discriminant written as a sum of squares. The eigenvalue dyads are the
    forward-mode Jacobian of the eigenvalues with respect to ``A``
    (``jax.jacfwd``), symmetrised over their last two axes.

    The function is jit-compiled with ``rtol`` as a static argument, so
    ``rtol`` must be hashable and each distinct value triggers a separate
    compilation.

    Parameters
    ----------
    A : array_like, shape (3, 3)
        Real symmetric matrix. Only the diagonal and the entries ``A[0, 1]``,
        ``A[0, 2]`` and ``A[1, 2]`` are read; the lower triangle is ignored.
    rtol : float, optional
        Tolerance, scaled by the norm of ``A``, below which the input is
        treated as isotropic. Defaults to ``1e-16``.

    Returns
    -------
    eigvals : jax.Array, shape (3,)
        Eigenvalues in ascending order.
    eigendyads : jax.Array, shape (3, 3, 3)
        Symmetrised derivatives $\partial \lambda_i / \partial \mathbf{A}$,
        ordered like ``eigvals``. In the near-isotropic branch each of them
        equals $\mathbf{I}/3$.

    Notes
    -----
    The isotropy test is non-strict (``<=``) so that an all-zero input, for
    which both sides of the comparison vanish when ``safe_norm`` returns zero,
    takes the isotropic branch. The general branch would otherwise evaluate
    ``arctan2(0, 0)``, whose forward-mode derivative is a 0/0 expression.
    """

    def J2s(A):
        # Second invariant of the deviator, written with differences of the
        # diagonal entries so the trace never has to be subtracted.
        d0 = A[0, 0] - A[1, 1]
        d1 = A[0, 0] - A[2, 2]
        d2 = A[1, 1] - A[2, 2]
        offdiag = A[0, 1] ** 2 + A[0, 2] ** 2 + A[1, 2] ** 2
        diag = (d0**2 + d1**2 + d2**2) / 6.0
        return offdiag + diag

    def J3s(A):
        # Determinant of the deviator; -t1/3, -t2/3, -t3/3 are its diagonal
        # entries s33, s22, s11.
        d0 = A[0, 0] - A[1, 1]
        d1 = A[0, 0] - A[2, 2]
        d2 = A[1, 1] - A[2, 2]
        t1 = d1 + d2
        t2 = d0 - d2
        t3 = -d0 - d1
        offdiag = 2.0 * A[0, 1] * A[1, 2] * A[0, 2]
        mixed = (A[0, 1] ** 2 * t1 + A[0, 2] ** 2 * t2 + A[1, 2] ** 2 * t3) / 3.0
        diag = (t1 * t2 * t3) / 27.0
        return offdiag + mixed - diag

    def dxs(A):
        # Five polynomial terms whose squares are summed in ``discs``.
        d0 = A[0, 0] - A[1, 1]
        d1 = A[0, 0] - A[2, 2]
        d2 = A[1, 1] - A[2, 2]
        w = A[0, 1]
        v = A[0, 2]
        u = A[1, 2]
        alpha = d2
        beta = -d1
        gamma = d0
        return jnp.asarray(
            [
                3.0 * jnp.sqrt(3.0) * (v * w * alpha + u * (v * v - w * w)),
                alpha * beta * gamma + alpha * u * u + beta * v * v + gamma * w * w,
                2.0 * u * beta * gamma - v * w * (beta - gamma) + u * (2.0 * u * u - v * v - w * w),
                2.0 * (v * alpha * gamma + u * w * (beta - gamma) + v * (v * v + w * w - 2.0 * u * u)),
                2.0 * (w * alpha * beta + u * v * (beta - gamma) + w * (v * v + w * w - 2.0 * u * u)),
            ],
            dtype=A.dtype,
        )

    def discs(A):
        # Sum of squares, hence non-negative by construction.
        terms = dxs(A)
        return jnp.sum(terms * terms)

    def compute_eigvals(A):
        # Returns the eigenvalues twice: once as the differentiated output and
        # once as the auxiliary value handed back unchanged by ``jacfwd``.
        A = jnp.asarray(A)
        I1 = jnp.trace(A)
        j2 = J2s(A)
        j3 = J3s(A)
        discriminant = discs(A)
        normA = safe_norm(A)

        def branch_near_iso(_):
            # Vanishing deviator: all three eigenvalues equal the mean I1 / 3.
            eigvals = jnp.ones((3,), dtype=A.dtype) * I1 / 3.0
            return eigvals, eigvals

        def branch_general(_):
            phi = jnp.arctan2(safe_sqrt(27.0 * discriminant), 27.0 * j3)
            amplitude = 2.0 * safe_sqrt(3.0 * j2)
            # Three angles separated by 2*pi/3 once divided by three below.
            shifts = 2.0 * jnp.pi * jnp.asarray([1.0, 2.0, 3.0], dtype=A.dtype)
            eigvals = (amplitude * jnp.cos((phi + shifts) / 3.0) + I1) / 3.0
            return eigvals, eigvals

        # Non-strict comparison so that a zero matrix (0 <= 0) is routed to the
        # isotropic branch.
        # NOTE(review): j2 is quadratic in the entries of A whereas normA is
        # linear, so this threshold is scale-dependent; eig33_HA compares
        # sqrt(J2 / 3) against rtol * norm instead. Confirm which is intended.
        return lax.cond(j2 <= rtol * normA, branch_near_iso, branch_general, operand=None)

    eigendyads, eigvals = jax.jacfwd(compute_eigvals, has_aux=True)(A)
    # Sort ascending and apply the same permutation to the dyads.
    order = jnp.argsort(eigvals)
    eigvals = eigvals[order]
    eigendyads = eigendyads[order]
    # jacfwd perturbs A_ij and A_ji independently; keep the symmetric part.
    eigendyads = 0.5 * (eigendyads + jnp.swapaxes(eigendyads, -1, -2))
    return eigvals, eigendyads
# ─────────────────────────────────────────────────────────────────────────────
# Private (implementation detail)
# ─────────────────────────────────────────────────────────────────────────────


def _sqrtm(C) -> tuple[jax.Array, jax.Array]:
    r"""
    Square root and inverse square root of a symmetric positive definite
    $3 \times 3$ matrix $\mathbf{C}$.

    Both $\mathbf{U} = \mathbf{C}^{1/2}$ and $\mathbf{U}^{-1}$ are assembled
    from $\mathbf{C}$, $\mathbf{C}^2$ and the identity, with coefficients
    given by the principal invariants of $\mathbf{U}$ (closed form of
    Simo & Hughes (1998), p. 244). The invariants are computed from the
    square roots of the eigenvalues of $\mathbf{C}$ returned by
    :func:`eig33`.

    Parameters
    ----------
    C : array_like, shape (3, 3)
        Symmetric positive definite matrix (typically the right Cauchy-Green
        deformation tensor $\mathbf{C} = \mathbf{F}^{\mathsf{T}}\mathbf{F}$).

    Returns
    -------
    U : jax.Array, shape (3, 3)
        Matrix square root $\mathbf{C}^{1/2}$.
    U_inv : jax.Array, shape (3, 3)
        Inverse square root $\mathbf{C}^{-1/2}$.

    .. admonition:: References
        :class: seealso

        Simo, J. C., & Hughes, T. J. R. (1998). *Computational Inelasticity*.
        Springer. p. 244.
    """
    # Eigenvalues of U are the square roots of those of C; the dyads returned
    # by eig33 are not needed here.
    c_eigvals, _ = eig33(C)
    stretches = safe_sqrt(c_eigvals)

    # Principal invariants of U expressed through its eigenvalues.
    inv1 = jnp.sum(stretches)
    inv2 = stretches[0] * stretches[1] + stretches[1] * stretches[2] + stretches[0] * stretches[2]
    inv3 = jnp.prod(stretches)
    denom = inv1 * inv2 - inv3

    identity = jnp.eye(3)
    C_squared = C @ C
    root = 1 / denom * (-C_squared + (inv1**2 - inv2) * C + inv1 * inv3 * identity)
    inv_root = 1 / inv3 * (C - inv1 * root + inv2 * identity)
    return root, inv_root
def isotropic_function(fun, A) -> jax.Array:
    r"""
    Isotropic matrix function $f(\mathbf{A})$ of a symmetric $3 \times 3$ matrix.

    Uses the spectral representation

    .. math::
        f(\mathbf{A}) = \sum_{i=1}^{3} f(\lambda_i)\,
        \mathbf{n}_i \otimes \mathbf{n}_i

    with eigenvalues $\lambda_i$ and eigenvectors $\mathbf{n}_i$ of
    $\mathbf{A}$. Eigenvalues and projectors are taken from :func:`eig33`.

    Parameters
    ----------
    fun : callable
        Scalar function $f : \mathbb{R} \to \mathbb{R}$. It is called once on
        the length-3 array of eigenvalues, so it must act element-wise.
    A : array_like, shape (3, 3)
        Real symmetric matrix.

    Returns
    -------
    jax.Array
        Shape (3, 3).
    """
    lam, dyads = eig33(jnp.asarray(A))
    # The dyads are derivatives d(lambda_i)/dA obtained by perturbing every
    # entry of A independently, so they are not constrained to be symmetric.
    # Taking the symmetric part guarantees a symmetric f(A).
    dyads_transposed = jnp.swapaxes(dyads, -1, -2)
    sym_dyads = 0.5 * (dyads + dyads_transposed)
    weights = fun(lam)
    return jnp.einsum("a,aij->ij", weights, sym_dyads)
def sqrtm(A) -> jax.Array:
    r"""
    Matrix square root $\mathbf{A}^{1/2}$ of a symmetric positive definite
    $3 \times 3$ matrix.

    Delegates to the closed-form expression of Simo & Hughes (1998), written
    in terms of the principal invariants of $\mathbf{A}^{1/2}$, and returns
    the first of the two matrices it produces. The argument is converted with
    ``jnp.asarray``, so any object implementing ``__jax_array__`` (e.g. a
    :class:`~jaxmat.tensors.SymmetricTensor2`) is accepted.

    Parameters
    ----------
    A : array_like, shape (3, 3)
        Symmetric positive definite matrix.

    Returns
    -------
    jax.Array
        Shape (3, 3).

    .. admonition:: References
        :class: seealso

        Simo, J. C., & Hughes, T. J. R. (1998). *Computational Inelasticity*.
        Springer. p. 244.
    """
    root, _ = _sqrtm(jnp.asarray(A))
    return root
def inv_sqrtm(A) -> jax.Array:
    r"""
    Inverse square root $\mathbf{A}^{-1/2}$ of a symmetric positive definite
    $3 \times 3$ matrix.

    The same closed-form routine that backs :func:`sqrtm` yields both the
    root and its inverse in one pass; this function returns the second of
    the two.

    Parameters
    ----------
    A : array_like, shape (3, 3)
        Symmetric positive definite matrix.

    Returns
    -------
    jax.Array
        Shape (3, 3).

    .. admonition:: References
        :class: seealso

        Simo, J. C., & Hughes, T. J. R. (1998). *Computational Inelasticity*.
        Springer. p. 244.
    """
    _, inverse_root = _sqrtm(jnp.asarray(A))
    return inverse_root
def expm(A) -> jax.Array:
    r"""
    Matrix exponential $\exp(\mathbf{A})$ of a symmetric $3 \times 3$ matrix.

    Applies ``jnp.exp`` to the eigenvalues through
    :func:`isotropic_function`. The argument is converted with
    ``jnp.asarray``, so objects implementing ``__jax_array__`` are accepted.

    Parameters
    ----------
    A : array_like, shape (3, 3)

    Returns
    -------
    jax.Array
        Shape (3, 3).
    """
    matrix = jnp.asarray(A)
    return isotropic_function(jnp.exp, matrix)
def logm(A) -> jax.Array:
    r"""
    Matrix logarithm $\log(\mathbf{A})$ of a symmetric positive definite
    $3 \times 3$ matrix.

    Applies ``jnp.log`` to the eigenvalues through
    :func:`isotropic_function`. The argument is converted with
    ``jnp.asarray``, so objects implementing ``__jax_array__`` are accepted.

    Parameters
    ----------
    A : array_like, shape (3, 3)
        Symmetric positive definite matrix.

    Returns
    -------
    jax.Array
        Shape (3, 3).
    """
    matrix = jnp.asarray(A)
    return isotropic_function(jnp.log, matrix)
def powm(A, m) -> jax.Array:
    r"""
    Matrix power $\mathbf{A}^m$ of a symmetric $3 \times 3$ matrix.

    Raises each eigenvalue to the power ``m`` with ``jnp.power`` through
    :func:`isotropic_function`. The argument is converted with
    ``jnp.asarray``, so objects implementing ``__jax_array__`` are accepted.

    Parameters
    ----------
    A : array_like, shape (3, 3)
    m : float
        Exponent.

    Returns
    -------
    jax.Array
        Shape (3, 3).
    """

    def _raise_to_m(eigenvalues):
        return jnp.power(eigenvalues, m)

    matrix = jnp.asarray(A)
    return isotropic_function(_raise_to_m, matrix)