JAX Stack Integration
This page details how SOFAx integrates the JAX ecosystem libraries—Equinox, Optimistix, Lineax, and Diffrax—to handle model definition, linear/nonlinear solves, and time integration. This modular approach allows replacing or upgrading specific components (e.g., solvers) without rewriting the core physics logic.
1. Lineax Integration: Linear Solvers
Lineax handles the solution of linear systems arising from the linearization of the physical residual.
The Linear Problem
In each Newton step, we solve a linear system of the form:
\[ J(x)\,\Delta x = -r(x) \]
where:
- \(x\) is the state vector (DOFs).
- \(r(x)\) is the residual vector.
- \(J(x) = \frac{\partial r}{\partial x}\) is the Jacobian operator.
- \(\Delta x\) is the update direction.
Linear Operator Implementation
We do not assemble the matrix \(J(x)\) explicitly. Instead, we define a custom matrix-free linear operator by subclassing lineax.AbstractLinearOperator.
- Matrix-Vector Product (mv): The action \(J \cdot v\) is computed via JAX's forward-mode AD (JVP) on the residual function:

      def mv(self, v):
          _, jvp_val = jax.jvp(self.residual_fn, (self.primal_x,), (v,))
          return jvp_val

- Structure: The operator defines shape, dtype, and transpose (via VJP) to satisfy the Lineax interface.
- Preconditioning: A preconditioner \(P^{-1} \approx J^{-1}\) can be provided as a separate AbstractLinearOperator to accelerate Krylov subspace methods (e.g., GMRES, CG).
2. Optimistix Integration: Nonlinear Solvers
Optimistix provides high-level nonlinear solvers (Newton-Raphson, Levenberg-Marquardt) that drive the simulation step.
The Nonlinear Problem
Optimistix sees the physics as a root-finding problem:
\[ f(x) = 0 \]
where \(f(x)\) corresponds to the global residual of the simulation scene.
Solver Interface
- Solve Variable: The variable \(x\) typically includes the full dynamic state (positions, velocities, or accelerations depending on the formulation) and Lagrange multipliers.
- Adapter Function: We wrap the core SOFAx residual in an adapter with the fn(x, args) signature that Optimistix expects:

      def root_fn(x, args):
          params, t, dt, ctx = args
          return residual_fn(x, params, t, dt, ctx)
Convergence & Differentiation
- Convergence Criteria: Controlled via residual norm relative to tolerance (rtol, atol).
- Differentiation: Gradients through the solve are handled via Implicit Differentiation (provided by Optimistix), allowing backpropagation through the equilibrium state without unrolling the solver iterations. This is crucial for efficient learning and inverse problems.
3. Diffrax Integration: Time Stepping
Diffrax manages the time evolution of the system, treating the simulation as a differential equation solve.
ODE/DAE Formulation
We formulate the simulation as a state evolution problem:
- Explicit ODE: \(x' = g(x, t)\) for explicit schemes.
- Implicit ODE/DAE: \(F(x_{n+1}, x_n, t_{n+1}, \Delta t) = 0\) for implicit integration.
Solver Mapping
- Explicit Methods: Map directly to diffrax.diffeqsolve with a vector field vf(t, y, args) -> dy/dt.
- Implicit Methods:
Advanced Features
- Event Handling: Used for detecting discrete events like contact impacts, enabling precise handling of non-smooth dynamics.
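- Adjoints: Diffrax supports adjoint methods (selected through the adjoint argument of diffrax.diffeqsolve) for differentiating through the time integration. These can compute gradients with less memory than storing every step for direct backpropagation through time (BPTT).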
4. Equinox Integration: Model Components
Equinox provides the building blocks for neural networks and parameter management within the simulation.
Model Components
Equinox modules (eqx.Module) are used to define:
- Learned Constitutive Laws: Neural networks replacing classical stress-strain relationships.
- Correction Terms: Additive terms in the residual to correct model errors.
- Operators: Parametric representations of mass or damping matrices.
These modules live inside the ResidualFn or specific physics operators.
JAX Transformations
- eqx.filter_jit: Used to compile the step functions, handling the distinction between static configuration (mesh topology) and dynamic state (DOFs).
- eqx.filter_grad: Computes gradients with respect to trainable parameters (weights, material properties) while ignoring static data.
Serialization
State and parameters are serialized as PyTrees. This ensures reproducibility and allows saving/loading trained models or simulation checkpoints seamlessly.