Composable and differentiable material models
Every material model in jaxmat inherits from equinox.Module.
An `eqx.Module` extends standard JAX PyTrees (nested data structures composed of tuples, lists, dictionaries, arrays, and other PyTrees) with the convenience of lightweight Python classes: each instance is itself a PyTree whose fields are its children.
Equinox modules are:
- JAX-compatible (registered as PyTree nodes),
- immutable (frozen dataclasses), and
- composable and differentiable (submodules can be nested, and gradients can be taken with respect to their array fields).
In effect, each material model is a structured container of differentiable parameters.
Note
State variables are also represented as `equinox.Module` instances. Other JAX-based packages, such as diffrax and optimistix, use Equinox modules in the same way, for instance to represent solvers.
Hierarchical model composition
While internal state variables (\(\boldsymbol{\alpha}\)) and material parameters (\(\boldsymbol{\theta}\)) could in principle be flattened into a single large vector, in practice they are organized hierarchically into modules and submodules.
For example:
- An elastoplastic model may be represented by a parent module containing:
  - an elastic submodule, and
  - a plastic submodule.
- The plastic submodule may itself include submodules for the yield surface, hardening law, and flow rule.
This modular structure promotes both clarity and reusability—complex constitutive models can be built from simple, well-defined components.
Benefits of using Equinox PyTrees
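- `equinox.Module` instances remain valid PyTrees, so they can be batched, mapped, or differentiated with JAX transformations.
- Functions such as `jax.vmap` or `jax.grad` operate over the entire module hierarchy without special handling, provided every leaf is an array; when a module also holds non-array leaves, the filtered variants `eqx.filter_vmap` and `eqx.filter_grad` can be used instead.
- When fine-grained edits are needed (e.g. replacing a single subcomponent), standard PyTree utilities such as `jax.tree_util.tree_map` or `optax.tree_utils`, as well as Equinox's `eqx.tree_at`, can be used for selective modification.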
For advanced manipulation, see the Equinox documentation.
Common Equinox patterns in jaxmat
Automatic conversion to JAX arrays
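Many model attributes represent scalar or tensor-valued material parameters. These should be stored as `jax.Array` objects: Equinox's filtered transformations (such as `eqx.filter_jit` and `eqx.filter_grad`) treat plain Python floats as static, so such parameters are neither traced nor differentiated, and `numpy.ndarray` values remain in host memory instead of being placed on a JAX device.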
To enforce this automatically, we use:
`x: float = eqx.field(converter=jnp.asarray)`
This ensures the input (float, list, or NumPy array) is converted to a JAX-compatible array.
Attention
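Beware of the difference between weakly and strongly typed arrays in JAX: with 64-bit precision enabled, `jnp.asarray(0.0)` is a weakly typed `float64` (`weak_f64`), whereas `jnp.asarray(0.0, dtype=jnp.float64)` is a strongly typed `float64` (`f64`). In type promotion a weakly typed value defers to the other operand, so adding `jnp.asarray(0.0)` to a `float32` array yields `float32`.

Declarative defaults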
Default attribute values can be defined directly:
`p: float = eqx.field(default=jnp.float64(0.0))`
or can be combined with a converter:
`p: float = eqx.field(default=0.0, converter=jnp.asarray)`
This avoids writing explicit `__init__` methods while keeping the model declarative.

Attention
With `default`, the default value is created once, when the class is defined, and stored as a class attribute, so the same object is shared by all instances. Use `default_factory` instead to create a separate default value for each instance (as required, for example, for mutable defaults), e.g.:
`F: Tensor2 = eqx.field(default_factory=lambda: Tensor2.identity())`
Safe JIT compilation with `eqx.filter_jit`
We frequently wrap key methods (e.g. constitutive updates) with the `@eqx.filter_jit` decorator.
Unlike `jax.jit`, this decorator traces only JAX and NumPy arrays and holds every other leaf (Python scalars, strings, callables, solver or configuration objects, static metadata) static, i.e. as compile-time constants that are kept out of the JIT trace.
As a result, entire `equinox.Module` instances, potentially containing both static and dynamic fields, can be passed directly into JIT-compiled functions without causing tracing errors.