
JAX Stack Integration

This page details how SOFAx integrates the JAX ecosystem libraries Equinox, Optimistix, Lineax, and Diffrax to handle model definition, linear and nonlinear solves, and time integration. Because each library covers one concern, a specific component (e.g., a solver) can be replaced or upgraded without rewriting the core physics logic.


1. Lineax Integration: Linear Solvers

Lineax handles the solution of linear systems arising from the linearization of the physical residual.

The Linear Problem

In each Newton step, we solve a linear system of the form:

\[ J(x) \, \Delta x = -r(x) \]

where:

  • \(x\) is the state vector (DOFs).
  • \(r(x)\) is the residual vector.
  • \(J(x) = \frac{\partial r}{\partial x}\) is the Jacobian operator.
  • \(\Delta x\) is the update direction.

Linear Operator Implementation

We do not assemble the matrix \(J(x)\) explicitly. Instead, we define a custom linear operator by subclassing lineax.AbstractLinearOperator.

  • Matrix-Vector Product (mv): The action \(J \cdot v\) is computed via JAX's Forward-Mode AD (JVP) on the residual function:
    def mv(self, v):
        # J(x) v by forward-mode AD (JVP) at the stored linearization point; J is never formed.
        _, jvp_val = jax.jvp(self.residual_fn, (self.primal_x,), (v,))
        return jvp_val
    
  • Structure: The operator also defines its input and output shape and dtype (Lineax's in_structure and out_structure) and its transpose (via VJP), as the Lineax interface requires.
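  • Preconditioning: A preconditioner \(P^{-1} \approx J^{-1}\) can be provided as a separate AbstractLinearOperator to accelerate Krylov subspace methods (e.g., GMRES, CG); Lineax's iterative solvers accept it through the options argument of lineax.linear_solve (options={"preconditioner": ...}). CG is only applicable when \(J\) is symmetric positive definite.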

2. Optimistix Integration: Nonlinear Solvers

Optimistix provides high-level nonlinear solvers (Newton-Raphson, Levenberg-Marquardt) that drive the simulation step.

The Nonlinear Problem

Optimistix treats the physics as a root-finding problem:

\[ f(x) = 0 \]

where \(f(x)\) corresponds to the global residual of the simulation scene.

Solver Interface

  • Solve Variable: The variable \(x\) typically includes the full dynamic state (positions, velocities, or accelerations depending on the formulation) and Lagrange multipliers.
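  • Adapter Function: We wrap the core SOFAx residual in a function with the fn(y, args) signature that Optimistix root finders call:
    def root_fn(x, args):
        # Unpack the extra inputs carried in args and forward them to the SOFAx residual.
        params, t, dt, ctx = args
        return residual_fn(x, params, t, dt, ctx)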
    

Convergence & Differentiation

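  • Convergence Criteria: Controlled by the solver tolerances (rtol, atol). For Newton-type solvers such as optx.Newton with its default Cauchy termination, the solve stops once both the update step and the residual norm are small relative to these tolerances.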
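  • Differentiation: Gradients through the solve use implicit differentiation (Optimistix's default adjoint, optx.ImplicitAdjoint). By the implicit function theorem, at the solution \(x^*\) we have \(\frac{\partial x^*}{\partial \theta} = -\left(\frac{\partial f}{\partial x}\right)^{-1} \frac{\partial f}{\partial \theta}\) for any parameter \(\theta\) carried in args, so backpropagating through the equilibrium state costs one linear solve with the transposed Jacobian instead of unrolling the solver iterations. This is crucial for efficient learning and inverse problems.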

3. Diffrax Integration: Time Stepping

Diffrax manages the time evolution of the system, treating the simulation as a differential equation solve.

ODE/DAE Formulation

We formulate the simulation as a state evolution problem:

  • Explicit ODE: \(x' = g(x, t)\) for explicit schemes.
  • Implicit ODE/DAE: \(F(x_{n+1}, x_n, t_{n+1}, \Delta t) = 0\), the discrete residual that implicit integration solves for \(x_{n+1}\) at each step.
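For example, backward (implicit) Euler applied to the explicit form \(x' = g(x, t)\) yields one such residual, and its Jacobian is the operator that the Newton step of Section 1 inverts:

\[ F(x_{n+1}, x_n, t_{n+1}, \Delta t) = x_{n+1} - x_n - \Delta t \, g(x_{n+1}, t_{n+1}) = 0, \qquad J = \frac{\partial F}{\partial x_{n+1}} = I - \Delta t \, \frac{\partial g}{\partial x}\Big|_{x_{n+1}} \]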

Solver Mapping

  • Explicit Methods: Map directly to diffrax.diffeqsolve: the vector field vf(t, y, args) -> dy/dt is wrapped in diffrax.ODETerm and integrated with an explicit solver such as diffrax.Tsit5.
  • Implicit Methods:
    • We use Diffrax's implicit solvers (e.g., the Kvaerno ESDIRK family, or the IMEX KenCarp family for split stiff/non-stiff terms) when applicable.
    • For custom mechanical integrators (e.g., variational integrators), we may implement a custom "stepper" that calls Optimistix internally for the nonlinear solve at each step.

Advanced Features

  • Event Handling: Diffrax events detect discrete occurrences such as contact impacts and terminate the solve at the triggering time, which allows the non-smooth dynamics at impact to be handled precisely.
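  • Adjoints: Diffrax offers adjoint methods for backpropagating through a time integration with less memory than naive backpropagation, which stores every intermediate solver state. The default, diffrax.RecursiveCheckpointAdjoint, backpropagates through the solver steps with checkpointing; diffrax.BacksolveAdjoint solves a continuous adjoint ODE backwards in time, with memory independent of the number of steps but only approximate gradients of the discretized solve.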

4. Equinox Integration: Model Components

Equinox provides the building blocks for neural networks and parameter management within the simulation.

Model Components

Equinox modules (eqx.Module) are used to define:

  • Learned Constitutive Laws: Neural networks replacing classical stress-strain relationships.
  • Correction Terms: Additive terms in the residual to correct model errors.
  • Operators: Parametric representations of mass or damping matrices.

These modules live inside the ResidualFn or specific physics operators.

JAX Transformations

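  • eqx.filter_jit: Used to compile the step functions. It traces all JAX and NumPy arrays and holds every other leaf static, so configuration kept as Python data or static fields (e.g., mesh topology) is fixed at compile time while the dynamic state (DOFs) is traced.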
  • eqx.filter_grad: Computes gradients with respect to the floating-point arrays in its first argument (e.g., network weights, material properties); all other leaves, including static data, are left out of differentiation.

Serialization

State and parameters are stored as PyTrees, so they can be serialized leaf by leaf. Saving and reloading the same PyTree structure restores trained models and simulation checkpoints exactly, which supports reproducible runs.