Computational aspects
jaxmat inherits all the standard features of the JAX ecosystem: automatic differentiation (AD), automatic vectorization (vmap), and just-in-time (JIT) compilation, while also adhering to JAX’s strict rules such as function purity, immutability of arrays, and program tracing. This section highlights aspects specific to jaxmat; for a general introduction to these concepts, see the JAX documentation.
Precision
Material models generally require solving relatively complex nonlinear equations. For better stability and accuracy, we therefore recommend working in float64 precision rather than in float32, which is JAX's default and is common in machine learning applications. Importing jaxmat enables float64 precision by default. Precision can also be set manually, at the start of a program and before any array is created, with:
jax.config.update("jax_enable_x64", True)
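One way to check which precision is active is to inspect the dtype of a freshly created array. The minimal sketch below uses only standard JAX calls and nothing jaxmat-specific. It also prints the machine epsilon of both precisions for comparison.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats at startup, before any JAX array is created
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # float64 when x64 is enabled, float32 otherwise

# Machine epsilon of both precisions, for comparison
print(jnp.finfo(jnp.float32).eps, jnp.finfo(jnp.float64).eps)
```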
Just-In-Time Compilation and backend device
In JAX, it is usually recommended to JIT only the outermost function of a computation.
In jaxmat, this means we typically JIT the constitutive update function, or a batched wrapper around it.
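The minimal sketch below JITs a batched wrapper around a single-point update. Here toy_constitutive_update is a hypothetical stand-in for a jaxmat material (1D linear elasticity with no internal variables). Only its call signature mimics constitutive_update(material, strain, state, dt). The material dict and numerical values are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical stand-in for a jaxmat update: 1D linear elasticity,
# the state is passed through unchanged (no internal variables)
def toy_constitutive_update(material, strain, state, dt):
    stress = material["E"] * strain
    return stress, state

# Batched wrapper over material points (material and dt are shared)
batched_update = jax.vmap(toy_constitutive_update, in_axes=(None, 0, 0, None))

# JIT only the outermost function: inner calls are traced into one compiled program
jitted_update = jax.jit(batched_update)

material = {"E": 200e3}
strains = jnp.linspace(0.0, 1e-3, 5)
states = jnp.zeros(5)
stress, new_states = jitted_update(material, strains, states, 0.1)
```

Because material, dt and the arrays are all traced arguments, calling jitted_update again with new values of the same shapes and dtypes reuses the compiled program.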
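When JIT-tracing, concrete values are replaced by tracers, so operations must remain compatible with JAX's functional semantics. Care must be taken with conditionals and loops. A Python if or while whose condition depends on a traced value cannot be evaluated during tracing and raises an error. Such control flow should instead be expressed with jnp.where, jax.lax.cond, jax.lax.while_loop or jax.lax.fori_loop. Python for loops over a fixed range remain valid, but they are unrolled in the traced program.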
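The sketch below shows both patterns on hypothetical toy functions. The first is a 1D elastic-perfectly-plastic stress update, where a branch becomes jnp.where. E and sigma_y are illustrative values. The second is a Newton iteration for a square root, where a tolerance-based while loop becomes jax.lax.while_loop.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

@jax.jit
def toy_return_mapping(strain):
    E, sigma_y = 200e3, 250.0  # illustrative parameters
    sigma_trial = E * strain
    # `if abs(sigma_trial) > sigma_y:` would raise an error under jit because
    # sigma_trial is a tracer; jnp.where evaluates both branches and selects
    return jnp.where(jnp.abs(sigma_trial) > sigma_y,
                     jnp.sign(sigma_trial) * sigma_y,
                     sigma_trial)

@jax.jit
def newton_sqrt(a):
    # A tolerance-based Python `while` loop cannot be traced;
    # jax.lax.while_loop carries (iterate, iteration count) instead
    def cond(carry):
        x, it = carry
        return (jnp.abs(x * x - a) > 1e-12) & (it < 50)

    def body(carry):
        x, it = carry
        return 0.5 * (x + a / x), it + 1

    x, _ = jax.lax.while_loop(cond, body, (a, jnp.asarray(0)))
    return x

print(toy_return_mapping(jnp.array([1e-4, 5e-3, -5e-3])))
print(newton_sqrt(jnp.array(2.0)))
```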
As in JAX, jaxmat supports device-portable batched constitutive updates; users may run the same code on CPU or GPU.
Observed performance depends strongly on the hardware device, the batch size and the computational intensity of the material model; see the demos/performance.md demo for more details.
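The same jitted code runs on whichever backend JAX selects, so checking or controlling device placement needs only standard JAX calls. In the minimal sketch below, the array of 6-component strain vectors is a hypothetical batch used for illustration.

```python
import jax
import jax.numpy as jnp

# Backend chosen by JAX (e.g. "cpu" or "gpu") and the devices it exposes
print(jax.default_backend())
print(jax.devices())

# Hypothetical batch of 10 000 six-component strain vectors, placed explicitly
# on the first CPU device; computations on this array then run on that device
cpu = jax.devices("cpu")[0]
strain_batch = jax.device_put(jnp.zeros((10_000, 6)), cpu)
```

Setting the environment variable JAX_PLATFORMS=cpu before starting Python is another way to force CPU execution, for instance when comparing timings across devices.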
On Automatic Vectorization
jaxmat relies heavily on batched constitutive updates. Instead of evaluating the stress for one strain input at a time, we can evaluate many inputs simultaneously — for example, one per quadrature point in a finite element assembly loop.
This is achieved with jax.vmap, which transforms a function such as:
constitutive_update(material, strain, state, dt)
into a batched version:
batched_constitutive_update(material, strain_batch, state_batch, dt)
that operates efficiently across a whole array of strains or states. For more details, see the Batching across material points demo.
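The following is a minimal sketch of this transformation with jax.vmap. It assumes a hypothetical toy update whose state is a dict storing the current strain as a history variable; a real jaxmat state is a richer PyTree. in_axes=(None, 0, 0, None) shares material and dt across the batch and maps over the leading axis of every leaf of strain and state.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical single-point update: 1D linear elasticity that also stores
# the current strain as a history variable in its state dict
def toy_constitutive_update(material, strain, state, dt):
    stress = material["E"] * strain
    new_state = {"eps_prev": strain}
    return stress, new_state

batched_constitutive_update = jax.vmap(
    toy_constitutive_update,
    in_axes=(None, 0, 0, None),  # material, dt shared; strain, state batched
)

n_points = 4  # e.g. quadrature points of a mesh
material = {"E": 200e3}
strain_batch = jnp.linspace(0.0, 1e-3, n_points)
state_batch = {"eps_prev": jnp.zeros(n_points)}  # each leaf has a leading batch axis

stress_batch, new_state_batch = batched_constitutive_update(
    material, strain_batch, state_batch, 0.1
)
print(stress_batch.shape, new_state_batch["eps_prev"].shape)
```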
It is also possible to batch over a set of material parameters sharing a common PyTree structure; see the Batching across material parameters demo.
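Batching over parameters only changes in_axes: the leaves of the material PyTree carry a leading axis, while the strain is shared. In the sketch below, toy_stress and the Young's moduli are hypothetical. Nesting two vmap calls evaluates every (material, strain) pair.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical single-point stress function (1D linear elasticity)
def toy_stress(material, strain):
    return material["E"] * strain

# Three materials sharing the same PyTree structure: {"E": array of shape (3,)}
materials = {"E": jnp.array([70e3, 110e3, 200e3])}

# Map over materials (axis 0 of each leaf), share the strain
stress_per_material = jax.vmap(toy_stress, in_axes=(0, None))
print(stress_per_material(materials, 1e-3))

# Nested vmap: inner over strains, outer over materials -> shape (3, 5)
strains = jnp.linspace(0.0, 1e-3, 5)
stress_grid = jax.vmap(jax.vmap(toy_stress, in_axes=(None, 0)), in_axes=(0, None))(
    materials, strains
)
print(stress_grid.shape)
```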
On Automatic Differentiation
AD is central to jaxmat, most notably for computing the consistent tangent operator[1]. In practice, AD is applied directly to the constitutive_update function.
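As a minimal sketch, consider a hypothetical 1D nonlinear elastic law sigma = sigma0 tanh(E eps / sigma0). Its tangent can be obtained by applying jax.jacfwd to the update function. With has_aux=True, the same call also returns the stress and updated state alongside the derivative. The material values are illustrative, and the analytical tangent E (1 - tanh^2) serves only as a check.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical single-point update: 1D nonlinear elasticity (no internal variables)
def toy_constitutive_update(material, strain, state, dt):
    E, sigma0 = material["E"], material["sigma0"]
    stress = sigma0 * jnp.tanh(E * strain / sigma0)
    return stress, state

# Reorder arguments so strain comes first; return (output to differentiate, aux data)
def stress_with_aux(strain, material, state, dt):
    stress, new_state = toy_constitutive_update(material, strain, state, dt)
    return stress, (stress, new_state)

# Forward-mode Jacobian of stress w.r.t. strain, i.e. the tangent operator
tangent_fn = jax.jacfwd(stress_with_aux, has_aux=True)

material = {"E": 200e3, "sigma0": 250.0}
strain = jnp.array(2e-3)
tangent, (stress, new_state) = tangent_fn(strain, material, jnp.zeros(()), 0.1)

# Analytical check: d(sigma)/d(eps) = E * (1 - tanh^2(E eps / sigma0))
expected = material["E"] * (1.0 - jnp.tanh(material["E"] * strain / material["sigma0"]) ** 2)
print(tangent, expected)
```

For a batch of points, tangent_fn can itself be wrapped in jax.vmap, exactly like the update function.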
However, the way the update is implemented matters. Consider implicit systems solved with a Newton method:
If the Newton iterations are written out explicitly, AD will differentiate through all iterations (algorithmic unrolling). This leads to:
unnecessarily large and complex computational graphs,
possible numerical instability due to accumulation of floating-point errors.
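The sketch below illustrates algorithmic unrolling on a hypothetical scalar residual R(x, p) = x + x^3 - p. The residual is solved for x by a fixed number of Newton iterations written as a Python loop, and jax.grad then differentiates through every iteration. Counting the equations of the traced gradient program (via jax.make_jaxpr) shows that the graph grows with the number of iterations.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical scalar residual R(x, p) = x + x^3 - p; for p = 2 the root is x = 1
def residual(x, p):
    return x + x**3 - p

def newton_unrolled(p, n_iter=20):
    x = jnp.zeros_like(p)
    for _ in range(n_iter):  # Python loop: each iteration is recorded for AD
        x = x - residual(x, p) / jax.grad(residual)(x, p)
    return x

p = jnp.array(2.0)
dxdp = jax.grad(newton_unrolled)(p)  # differentiates through all iterations
print(dxdp)  # close to the exact value 1 / (1 + 3 x^2) = 0.25

# Size of the traced gradient program for 5 vs 20 iterations
n_eqns = {
    n: len(jax.make_jaxpr(jax.grad(partial(newton_unrolled, n_iter=n)))(p).jaxpr.eqns)
    for n in (5, 20)
}
print(n_eqns)
```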
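A better approach is to use the implicit function theorem (IFT). Instead of differentiating through the iterations, one differentiates the residual equation R(x, p) = 0 at the converged solution x and solves the auxiliary linear system (∂R/∂x) dx/dp = -∂R/∂p. The cost of this linear solve does not depend on the number of Newton iterations, and the derivative is evaluated only at the converged solution. The result is a more efficient and more accurate evaluation.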
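For the same hypothetical scalar residual R(x, p) = x + x^3 - p, the IFT can be hand-rolled with jax.custom_jvp. The forward solve runs a Newton loop with jax.lax.fori_loop. The custom derivative rule uses only the converged solution, x_dot = -(∂R/∂p / ∂R/∂x) p_dot. This sketch only illustrates the mechanism and is not how optimistix implements it. For a vector unknown, the division becomes a linear solve with the Jacobian ∂R/∂x.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical scalar residual R(x, p) = x + x^3 - p
def residual(x, p):
    return x + x**3 - p

def newton_solve(p, n_iter=20):
    def body(i, x):
        return x - residual(x, p) / jax.grad(residual, argnums=0)(x, p)
    return jax.lax.fori_loop(0, n_iter, body, jnp.zeros_like(p))

@jax.custom_jvp
def implicit_solve(p):
    return newton_solve(p)

@implicit_solve.defjvp
def implicit_solve_jvp(primals, tangents):
    (p,), (p_dot,) = primals, tangents
    x = newton_solve(p)  # converged solution; no AD through the iterations
    # IFT: dR/dx * x_dot + dR/dp * p_dot = 0 at R(x, p) = 0
    dRdx = jax.grad(residual, argnums=0)(x, p)
    dRdp = jax.grad(residual, argnums=1)(x, p)
    x_dot = -dRdp * p_dot / dRdx
    return x, x_dot

p = jnp.array(2.0)
print(implicit_solve(p), jax.grad(implicit_solve)(p))
```

Because x_dot is linear in p_dot, JAX can transpose the rule, so both forward mode (jax.jvp) and reverse mode (jax.grad) work without tracing through the Newton loop.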
In jaxmat, we typically use solvers from optimistix and diffrax (the same issue arises for ODEs), which already implement this implicit differentiation technique.