from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
from .utils import safe_norm, safe_sqrt
[docs]
def dim(A):
    r"""Size of the leading axis of $\bA$.

    The array is assumed to have the same extent along every axis, i.e.
    ``shape=(dim, dim, ..., dim)``, so the first entry of ``A.shape`` is
    taken as *the* dimension. No check of the other axes is performed.
    """
    leading_extent = A.shape[0]
    return leading_extent
[docs]
def tr(A):
    r"""
    Sum of the diagonal entries of a matrix $\bA$.

    $$\tr(\bA)=A_{ii}$$

    Thin wrapper around ``jnp.trace``.
    """
    diagonal_sum = jnp.trace(A)
    return diagonal_sum
[docs]
def dev(A):
    r"""
    Deviatoric (trace-free) part of a $d\times d$ matrix $\bA$.

    $$\dev(\bA) = \bA - \dfrac{1}{d}\tr(\bA)\bI$$

    Parameters
    ----------
    A : jax.Array
        Square matrix; its size $d$ is read from the leading axis.

    Returns
    -------
    jax.Array
        ``A`` with ``tr(A) / d`` removed from every diagonal entry, so that
        the result has zero trace.
    """
    d = dim(A)
    Id = jnp.eye(d)
    # Subtract the spherical part tr(A)/d * I.
    return A - tr(A) / d * Id
[docs]
def det33(A):
    r"""Determinant $\det(\bA)$ of a 3x3 matrix $\bA$ via Laplace expansion.

    The expansion is carried out along the first row, using the three
    2x2 minors built from the second and third rows.
    """
    # 2x2 minors obtained by striking the first row and column 0, 1, 2.
    minor_0 = A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1]
    minor_1 = A[1, 0] * A[2, 2] - A[1, 2] * A[2, 0]
    minor_2 = A[1, 0] * A[2, 1] - A[1, 1] * A[2, 0]
    # Alternating-sign combination with the first-row entries.
    return A[0, 0] * minor_0 - A[0, 1] * minor_1 + A[0, 2] * minor_2
[docs]
def inv33(A):
    r"""Inverse $\bA^{-1}$ of a 3x3 matrix $\bA$ from the adjugate formula.

    The adjugate (transposed cofactor matrix) is assembled entry by entry
    and divided by the determinant, itself expanded along the first row.
    No check for a vanishing determinant is made.
    """
    m00, m01, m02 = A[0, 0], A[0, 1], A[0, 2]
    m10, m11, m12 = A[1, 0], A[1, 1], A[1, 2]
    m20, m21, m22 = A[2, 0], A[2, 1], A[2, 2]

    # First-row expansion of the determinant.
    determinant = (
        m00 * (m11 * m22 - m12 * m21)
        - m01 * (m10 * m22 - m12 * m20)
        + m02 * (m10 * m21 - m11 * m20)
    )

    # Rows of the adjugate, i.e. the cofactor matrix already transposed.
    adj_row_0 = [m11 * m22 - m12 * m21, m02 * m21 - m01 * m22, m01 * m12 - m02 * m11]
    adj_row_1 = [m12 * m20 - m10 * m22, m00 * m22 - m02 * m20, m02 * m10 - m00 * m12]
    adj_row_2 = [m10 * m21 - m11 * m20, m01 * m20 - m00 * m21, m00 * m11 - m01 * m10]
    adjugate = jnp.array([adj_row_0, adj_row_1, adj_row_2])

    return adjugate / determinant
[docs]
def principal_invariants(A):
    r"""Principal invariants $(I_1, I_2, I_3)$ of a 3x3 matrix $\bA$.

    $$\begin{align*}
    I_1 &= \tr(\bA)\\
    I_2 &= \frac{1}{2}(\tr(\bA)^2-\tr(\bA^2))\\
    I_3 &= \det(\bA)
    \end{align*}$$

    Returns
    -------
    tuple
        The three invariants ``(I_1, I_2, I_3)``, in that order.
    """
    trace_A = jnp.trace(A)
    trace_A_squared = jnp.trace(A @ A)
    second_invariant = (trace_A**2 - trace_A_squared) / 2
    return trace_A, second_invariant, det33(A)
[docs]
def main_invariants(A):
    r"""Main invariants of a 3x3 matrix $\bA$, i.e. the traces of its powers:

    $$\tr(\bA),\: \tr(\bA^2),\: \tr(\bA^3)$$

    Returns
    -------
    tuple
        ``(tr(A), tr(A @ A), tr(A @ A @ A))``.
    """
    A_squared = A @ A
    # (A @ A) @ A: same association order as the chained product A @ A @ A.
    A_cubed = A_squared @ A
    return jnp.trace(A), jnp.trace(A_squared), jnp.trace(A_cubed)
[docs]
def pq_invariants(sig):
    r"""Hydrostatic and deviatoric equivalent stresses $(p,q)$ of a stress $\bsig$.

    Typically used in soil mechanics.

    $$p = - \tr(\bsig)/3 = -I_1/3$$
    $$q = \sqrt{\frac{3}{2}\bs:\bs} = \sqrt{3 J_2}$$

    where $\bs$ is the deviatoric part of $\bsig$.

    Returns
    -------
    tuple
        ``(p, q)``; the square root is evaluated with ``safe_sqrt``.
    """
    deviator = dev(sig)
    mean_pressure = -jnp.trace(sig) / 3
    # s:s as a full double contraction over both indices.
    deviator_sq = jnp.vdot(deviator, deviator)
    equivalent_stress = safe_sqrt(1.5 * deviator_sq)
    return mean_pressure, equivalent_stress
[docs]
@partial(jax.jit, static_argnums=1)
def eig33(A, rtol=1e-16):
    """
    Eigenvalues and eigenvalue derivatives of a 3 x 3 real symmetric matrix.

    The eigenvalues are computed with the closed-form, numerically stable
    method of Harari & Albocher (2023), which avoids catastrophic
    cancellation when two or more eigenvalues are nearly equal. Their
    derivatives with respect to ``A`` are obtained by differentiating these
    formulas with forward-mode automatic differentiation.

    Parameters
    ----------
    A : array_like of shape (3, 3)
        Real symmetric matrix.
    rtol : float, optional
        Relative tolerance (scaled by ``safe_norm(A)``) used to select the
        special branches below. Default is ``1e-16``. It is a static
        argument of the jitted function, so it must be hashable and a new
        value triggers a recompilation.

    Returns
    -------
    eigvals : jax.Array of shape (3,)
        Eigenvalues of ``A``. They are not guaranteed to be sorted: the
        ordering differs between branches (the ``|1 - d|`` branch lists the
        largest value first, the general branch lists it last).
    eigendyads : jax.Array of shape (3, 3, 3)
        ``eigendyads[i]`` is the derivative of ``eigvals[i]`` with respect to
        the components of ``A``, computed with ``jax.jacfwd``.

    Notes
    -----
    - With ``s = sqrt(J2 / 3)`` and ``d`` the ratio
      ``||T - s S|| / ||T + s S||`` (``S`` the deviator of ``A``), three
      branches are distinguished:

      1. Near-isotropic case (``s < rtol * ||A||``): all eigenvalues are set
         to ``tr(A) / 3``.
      2. ``|1 - d| < rtol * ||A||``: the deviatoric eigenvalues are set to
         ``(sqrt(3) s, 0, -sqrt(3) s)``.
      3. Otherwise: eigenvalues are computed via trigonometric relations.
    - ``safe_norm`` and ``safe_sqrt`` are used for numerical safety.
    - Input ``A`` must be symmetric; asymmetry may lead to inaccurate results.

    .. admonition:: References
        :class: seealso

        Harari, I., & Albocher, U. (2023). Computation of eigenvalues of a real,
        symmetric 3 x 3 matrix with particular reference to the pernicious case of
        two nearly equal eigenvalues. *International Journal for Numerical Methods in
        Engineering*, 124(5), 1089-1110.
    """

    def compute_eigvals_HarariAlbocher(A):
        """
        Eigenvalues of a 3x3 symmetric matrix based on Harari, I., & Albocher, U. (2023).

        Returns the eigenvalues twice: once as the differentiated output and
        once as the auxiliary output of ``jax.jacfwd(..., has_aux=True)``.
        """
        A = jnp.asarray(A)
        norm = safe_norm(A)
        Id = jnp.eye(dim(A))
        I1 = jnp.trace(A)
        S = dev(A)
        J2 = tr(S.T @ S) / 2
        s = safe_sqrt(J2 / 3)

        def branch_near_iso(_):
            # Vanishing deviator: every eigenvalue equals the mean tr(A)/3.
            eigvals = jnp.ones((3,)) * I1 / 3
            return eigvals, eigvals

        def branch_general(_):
            T = S @ S - 2 * J2 / 3 * Id
            d = safe_norm(T - s * S) / safe_norm(T + s * S)
            sj = jnp.sign(1 - d)
            # sj * (1 - d) is |1 - d|.
            cond = sj * (1 - d) < rtol * norm

            def branch_two_eigvals(_):
                # d ~ 1: deviatoric spectrum taken as (lamb_max, 0, -lamb_max).
                lamb_max = jnp.sqrt(3) * s
                eigvals_dev = jnp.array([lamb_max, 0.0, -lamb_max])
                eigvals = eigvals_dev + I1 / 3
                return eigvals, eigvals

            def branch_three_eigvals(_):
                # d**sj maps the ratio into [0, 1] whichever side of 1 it is on.
                alpha = 2 / 3 * jnp.arctan(d**sj)
                lambda_d = 2 * sj * s * jnp.cos(alpha)
                sd = jnp.sqrt(3) * s * jnp.sin(alpha)
                # Place lambda_d last when positive, first otherwise.
                eigvals_dev = lax.cond(
                    lambda_d > 0,
                    lambda _: jnp.array(
                        [-lambda_d / 2 - sd, -lambda_d / 2 + sd, lambda_d]
                    ),
                    lambda _: jnp.array(
                        [lambda_d, -lambda_d / 2 - sd, -lambda_d / 2 + sd]
                    ),
                    operand=None,
                )
                eigvals = eigvals_dev + I1 / 3
                return eigvals, eigvals

            return lax.cond(
                cond, branch_two_eigvals, branch_three_eigvals, operand=None
            )

        return lax.cond(s < rtol * norm, branch_near_iso, branch_general, operand=None)

    # jacfwd returns (jacobian, aux): the jacobian holds the eigenvalue
    # derivatives, the aux output carries the eigenvalues themselves.
    eigendyads, eigvals = jax.jacfwd(compute_eigvals_HarariAlbocher, has_aux=True)(A)
    return eigvals, eigendyads
def _sqrtm(C):
    r"""
    Square root and inverse square root of a symmetric 3x3 matrix $\bC$.

    Both are evaluated from the unified closed-form expression of
    Simo, J. C., & Hughes, T. J. (1998). Computational inelasticity., p.244,
    using the invariants of the square roots of the eigenvalues of $\bC$.

    Returns
    -------
    tuple
        ``(U, U_inv)`` with ``U`` the square root and ``U_inv`` its inverse.
    """
    principal_values, _ = eig33(C)
    stretches = safe_sqrt(principal_values)

    # Invariants of the square-root matrix, built from its eigenvalues.
    first_inv = jnp.sum(stretches)
    second_inv = (
        stretches[0] * stretches[1]
        + stretches[1] * stretches[2]
        + stretches[0] * stretches[2]
    )
    third_inv = jnp.prod(stretches)

    identity = jnp.eye(3)
    C_squared = C @ C
    denominator = first_inv * second_inv - third_inv

    root = 1 / denominator * (
        -C_squared + (first_inv**2 - second_inv) * C + first_inv * third_inv * identity
    )
    inverse_root = 1 / third_inv * (C - first_inv * root + second_inv * identity)
    return root, inverse_root
[docs]
def sqrtm(A):
    r"""
    Matrix square-root $\bA^{1/2}$ of a symmetric 3x3 matrix $\bA$.

    Returns the first output of the unified square root / inverse square
    root formula, see Simo & Hughes, 1998.

    .. admonition:: References
        :class: seealso

        Simo, J. C., & Hughes, T. J. (1998). Computational inelasticity., p.244
    """
    root, _ = _sqrtm(A)
    return root
[docs]
def inv_sqrtm(A):
    r"""
    Matrix inverse square-root $\bA^{-1/2}$ of a symmetric 3x3 matrix $\bA$.

    Returns the second output of the unified square root / inverse square
    root formula, see Simo & Hughes, 1998.

    .. admonition:: References
        :class: seealso

        Simo, J. C., & Hughes, T. J. (1998). Computational inelasticity., p.244
    """
    _, inverse_root = _sqrtm(A)
    return inverse_root
[docs]
def isotropic_function(fun, A):
    r"""Evaluates an isotropic tensor function of a symmetric 3x3 matrix $\bA$.

    Parameters
    ----------
    fun : callable
        A scalar function $f(x)$. It is called once on the whole array of
        eigenvalues, so it must act elementwise on arrays.
    A : jax.Array
        A symmetric 3x3 matrix

    Returns
    -------
    jax.Array
        A new 3x3 matrix $f_{\bA}$ such that

        $$f_{\bA} = \sum_{i=1}^3 f(\lambda_i) \bn_i \otimes \bn_i$$

        where $\lambda_i$ and $\bn_i$ are the eigenvalues and eigenvectors of $\bA$.
        The dyads $\bn_i \otimes \bn_i$ are taken from the second output of ``eig33``.
    """
    eigvals, eigendyads = eig33(A)
    weights = fun(eigvals)
    # Accumulate the weighted dyads one by one, starting from 0 as sum() does.
    result = 0
    for weight, dyad in zip(weights, eigendyads):
        result = result + weight * dyad
    return result
[docs]
def expm(A):
    r"""Matrix exponential $\exp(\bA)$ of a symmetric 3x3 matrix $\bA$.

    Evaluated as the isotropic function of $\bA$ generated by ``jnp.exp``.
    """
    exponential = isotropic_function(jnp.exp, A)
    return exponential
[docs]
def logm(A):
    r"""Matrix logarithm $\log(\bA)$ of a symmetric 3x3 matrix $\bA$.

    Evaluated as the isotropic function of $\bA$ generated by ``jnp.log``.
    """
    logarithm = isotropic_function(jnp.log, A)
    return logarithm
[docs]
def powm(A, m):
    r"""Matrix power $\bA^m$ of exponent $m$ of a symmetric 3x3 matrix $\bA$.

    Evaluated as the isotropic function of $\bA$ generated by
    $x \mapsto x^m$ (``jnp.power``).
    """

    def raise_to_m(x):
        return jnp.power(x, m)

    return isotropic_function(raise_to_m, A)