
CPU Setup of JAX static fields

The setup phase builds an immutable PyTree representing the complete scene graph. This PyTree contains:

  • Static data: precomputed mesh topology, finite element data, mappings, and constraints (frozen, GPU-constant)
  • Dynamic data: DOFs, solver state, and learnable parameters (traced by JAX, updated each step)
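One way to express this split is sketched below with a hypothetical `NodeSketch` class, which is not part of the framework described on this page. Hashable, non-array metadata goes into the PyTree's auxiliary data, so it becomes part of the tree structure and is never traced. Arrays are leaves. Precomputed index arrays cannot live in the auxiliary data because arrays are not hashable, so in this sketch they stay as leaves that are simply never replaced after setup.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class NodeSketch:
    """Hypothetical scene node with a static/dynamic split."""

    # Static: hashable metadata, stored in the treedef (never traced).
    name: str
    # Static: precomputed on the CPU; a leaf, but never replaced after setup.
    boundary_idx: jax.Array
    # Dynamic: DOFs that the solver replaces at every step.
    u: jax.Array
    v: jax.Array

    def tree_flatten(self):
        children = (self.boundary_idx, self.u, self.v)  # array leaves
        aux_data = (self.name,)                         # must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (name,) = aux_data
        return cls(name, *children)


node = NodeSketch("beam", jnp.array([0, 1]), jnp.zeros(6), jnp.zeros(6))
leaves, treedef = jax.tree_util.tree_flatten(node)
# leaves holds the three arrays; "beam" is recorded in treedef only.
```

The `frozen=True` dataclass makes attribute assignment raise, which mirrors the immutability requirement at the Python level.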

Because the PyTree is immutable, its dynamic leaves are clearly identified, and every operation on it is JAX-compatible, JAX can trace the data dependencies of an entire time step and XLA can compile them into fused, optimized GPU kernels. Static data is transferred to the GPU once and stays constant for the rest of the simulation. At each time step only the dynamic leaves change, and they are replaced by new arrays rather than modified in place.
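A minimal sketch of a step over such a PyTree follows. It assumes a hypothetical `SceneSketch` NamedTuple and uses a semi-implicit (symplectic) Euler update chosen only for illustration; the actual solver is not described on this page. The scene is placed on the device once with `jax.device_put`, and the jitted `step` returns a new PyTree in which only the dynamic leaves differ, while the static `mass_inv` leaf is passed through unchanged.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SceneSketch(NamedTuple):
    """Hypothetical scene PyTree (NamedTuples are PyTrees out of the box)."""
    mass_inv: jax.Array  # static: precomputed once, never replaced
    u: jax.Array         # dynamic: positions
    v: jax.Array         # dynamic: velocities
    time: jax.Array      # dynamic: solver state


@jax.jit
def step(scene: SceneSketch, force: jax.Array, dt: float) -> SceneSketch:
    # Semi-implicit Euler: update v first, then u with the new v.
    v_new = scene.v + dt * scene.mass_inv * force
    u_new = scene.u + dt * v_new
    # _replace builds a new tuple; the input scene is left untouched.
    return scene._replace(u=u_new, v=v_new, time=scene.time + dt)


# Transfer the whole scene to the default device once.
scene = jax.device_put(SceneSketch(
    mass_inv=jnp.full(3, 0.5),
    u=jnp.zeros(3),
    v=jnp.zeros(3),
    time=jnp.asarray(0.0),
))
new = step(scene, jnp.ones(3), 0.1)
```

Because `step` never mutates its input, the previous `scene` stays valid. Feeding `new` back into `step` reuses the compiled executable as long as the tree structure, shapes, and dtypes are unchanged.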

Setup phase (CPU)
```mermaid
sequenceDiagram
    autonumber
    actor User as User Script
    participant Builder as SceneBuilder
    participant Node as SceneNode
    participant Mesh as MeshTopology
    participant FE as FEData
    participant Graph as GraphData
    participant Map as Mapping
    participant BC as ProjectiveBC
    participant Field as PhysicalField
    participant Tree as PyTree

    note over User,Tree: Goal: Build immutable PyTree scene graph with all static data precomputed on CPU

    User->>Builder: create scene

    rect rgba(227,242,253,0.25)
        note over Builder,Node: A) Build scene graph hierarchy
        Builder->>Builder: create root
        loop for each node in scene description
            Builder->>Node: instantiate SolverNode/MappedNode
            Builder->>Node: set parent/children relationships
            Node-->>Builder: node created
        end
    end

    rect rgba(232,245,233,0.25)
        note over Builder,FE: B) Precompute mesh & topology data
        loop for each SceneNode
            Builder->>Node: setup topology
            Node->>Mesh: load mesh
            Mesh->>Mesh: precompute boundary indices
            Mesh->>FE: precompute FE data
            FE->>FE: compute jacobian_det at integration points
            FE->>FE: compute shape_fn_derivatives
            FE-->>Mesh: FiniteElementData ready
            Mesh->>Graph: build graph adjacency
            Graph->>Graph: compute node-node/edge connectivity
            Graph-->>Mesh: GraphData ready
            Mesh-->>Node: MeshTopology ready
            Node-->>Builder: mesh attached
        end
    end

    rect rgba(255,243,224,0.25)
        note over Builder,Map: C) Precompute mappings
        loop for each MappedNode
            Builder->>Node: setup mapping
            Node->>Map: compute mapping to parent
            Map->>Map: identify parent_node_indices
            Map->>Map: compute weights
            Map-->>Node: Mapping ready
            Node-->>Builder: mapping attached
        end
    end

    rect rgba(255,235,238,0.25)
        note over Builder,BC: D) Build projective constraints
        loop for each SolverNode
            Builder->>Node: setup boundary conditions
            Node->>BC: create projective constraints
            BC->>Mesh: query boundary nodes from mesh
            Mesh-->>BC: point_data['bc']
            BC->>BC: build mask_free_dofs
            BC->>BC: set fixed_values
            BC-->>Node: ProjectiveBC ready
            Node-->>Builder: bc attached
        end
    end

    rect rgba(237,231,246,0.25)
        note over Builder,Field: E) Initialize physical fields
        loop for each SceneNode
            Builder->>Node: setup fields
            Node->>Field: create PhysicalField/ConstraintField
            Field->>Mesh: precompute field data from mesh
            Mesh-->>Field: point_data / cell_data
            Field->>Field: initialize static parameters
            Field->>Field: allocate dynamic parameters (θ) as zeros
            Field-->>Node: PhysicalField ready
            Node-->>Builder: fields attached
        end
    end

    rect rgba(255,248,225,0.25)
        note over Builder,Tree: F) Initialize dynamic state
        loop for each SceneNode
            Builder->>Node: create NodeDOFs (u=u0, v=0, a=0)
        end
        loop for each SolverNode
            Builder->>Node: create SolverState (time=0, step_id=0)
        end
    end

    Builder->>Tree: validate PyTree structure
    Tree->>Tree: check all static leaves are frozen
    Tree->>Tree: check all dynamic leaves are jnp.ndarray
    Tree->>Tree: verify no circular references
    Tree-->>Builder: PyTree valid
    Builder-->>User: return immutable scene PyTree

    note over User,Tree: Scene ready for JIT compilation & GPU execution
```
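The per-element quantities in step B can be illustrated for the simplest case, linear (P1) triangles in 2D, where shape-function gradients are constant per cell and a single integration point suffices. `precompute_p1_triangle_fe` is a hypothetical helper: it returns arrays named after the diagram's `jacobian_det` and `shape_fn_derivatives`, but the framework's actual element types, quadrature rules, and array layouts are not specified on this page.

```python
import numpy as np

# Reference-element gradients of the three P1 shape functions
# (rows: local nodes, columns: d/dxi, d/deta).
DN_REF = np.array([[-1.0, -1.0],
                   [ 1.0,  0.0],
                   [ 0.0,  1.0]])


def precompute_p1_triangle_fe(points: np.ndarray, cells: np.ndarray):
    """Hypothetical helper: jacobian_det and shape_fn_derivatives for P1 triangles.

    points: (n_points, 2) coordinates; cells: (n_cells, 3) vertex indices.
    """
    x = points[cells]                                  # (n_cells, 3, 2)
    # J[c, i, j] = sum_a x[c, a, i] * dN_a/dxi_j
    jac = np.einsum("cai,aj->cij", x, DN_REF)          # (n_cells, 2, 2)
    jacobian_det = np.linalg.det(jac)                  # (n_cells,)
    if np.any(jacobian_det <= 0):
        raise ValueError("degenerate or inverted triangle")
    # dN_a/dx_i = sum_j dN_a/dxi_j * (J^-1)[j, i]
    shape_fn_derivatives = np.einsum("aj,cji->cai", DN_REF, np.linalg.inv(jac))
    return jacobian_det, shape_fn_derivatives          # second: (n_cells, 3, 2)
```

The cell area is `jacobian_det / 2` for this element, so both arrays together are enough to assemble P1 stiffness terms on the GPU without touching the mesh coordinates again.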
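For step C, this sketch assumes barycentric interpolation as the mapping type (the page does not say which mapping types exist): each child point is written as a weighted sum of the vertices of the parent triangle that contains it. The hypothetical `barycentric_mapping` returns `parent_node_indices` and `weights` and locates the containing cell by brute force, which is only practical for small meshes.

```python
import numpy as np


def barycentric_mapping(parent_points, parent_cells, child_points, tol=1e-12):
    """Hypothetical sketch: map child points onto a parent P1 triangle mesh.

    Returns parent_node_indices (n_child, 3) and weights (n_child, 3) with
    child ≈ sum_k weights[:, k, None] * parent_points[parent_node_indices[:, k]].
    """
    tri = parent_points[parent_cells]                  # (n_cells, 3, 2)
    a, b, c = tri[:, 0], tri[:, 1], tri[:, 2]
    # Per cell, solve p - a = l1 * (b - a) + l2 * (c - a).
    mat = np.stack([b - a, c - a], axis=-1)            # (n_cells, 2, 2)
    inv = np.linalg.inv(mat)
    rhs = child_points[:, None, :] - a[None, :, :]     # (n_child, n_cells, 2)
    l12 = np.einsum("cij,pcj->pci", inv, rhs)          # (n_child, n_cells, 2)
    lam = np.concatenate([1.0 - l12.sum(-1, keepdims=True), l12], axis=-1)
    inside = np.all(lam >= -tol, axis=-1)              # (n_child, n_cells)
    if not np.all(inside.any(axis=1)):
        raise ValueError("some child points lie outside the parent mesh")
    cell_id = inside.argmax(axis=1)                    # first containing cell
    rows = np.arange(len(child_points))
    return parent_cells[cell_id], lam[rows, cell_id]
```

At runtime, a mapped position is then a gather followed by a weighted sum, e.g. `(weights[..., None] * u_parent[parent_node_indices]).sum(axis=1)`, which is cheap on the GPU because the indices and weights never change.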
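Step D can be sketched as follows, assuming `point_data['bc']` holds one boolean flag per mesh node and that constrained nodes keep their initial position. `build_projective_bc` runs once on the CPU; `project` is the kind of operation a jitted step could apply after each update, overwriting constrained DOFs with `fixed_values` via `jnp.where`. Both function names are hypothetical.

```python
import jax.numpy as jnp
import numpy as np


def build_projective_bc(bc_flags: np.ndarray, u0: np.ndarray, dim: int):
    """Hypothetical CPU-side setup of mask_free_dofs and fixed_values.

    bc_flags: (n_points,) booleans, e.g. taken from point_data['bc'].
    u0: (n_points, dim) initial positions; fixed DOFs keep these values.
    """
    if u0.shape != (len(bc_flags), dim):
        raise ValueError("u0 must have shape (n_points, dim)")
    # Fix every component of a flagged node; layout matches u0.reshape(-1).
    fixed_dofs = np.repeat(bc_flags.astype(bool), dim)
    mask_free_dofs = ~fixed_dofs
    fixed_values = np.where(fixed_dofs, u0.reshape(-1), 0.0)
    return mask_free_dofs, fixed_values


def project(u_flat, mask_free_dofs, fixed_values):
    # Keep free DOFs, replace constrained ones (no in-place mutation).
    return jnp.where(mask_free_dofs, u_flat, fixed_values)
```

Storing a boolean mask instead of a list of indices keeps the projection a single elementwise operation with a fixed shape, which suits XLA.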
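Step F and part of the final validation might look like this. `NodeDOFsSketch` and `SolverStateSketch` are hypothetical stand-ins whose fields mirror the diagram (`u`, `v`, `a`, `time`, `step_id`); the framework's own `NodeDOFs` and `SolverState` may differ. `validate_dynamic` covers only the "dynamic leaves are arrays" check and tests against `jax.Array`, the type that `jnp.ndarray` refers to in current JAX releases. The frozen-static and circular-reference checks are not sketched here.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class NodeDOFsSketch(NamedTuple):  # hypothetical, fields from the diagram
    u: jax.Array
    v: jax.Array
    a: jax.Array


class SolverStateSketch(NamedTuple):  # hypothetical, fields from the diagram
    time: jax.Array
    step_id: jax.Array


def init_dynamic_state(u0: np.ndarray):
    # u = u0, v = 0, a = 0; time = 0, step_id = 0
    dofs = NodeDOFsSketch(u=jnp.asarray(u0), v=jnp.zeros(u0.shape), a=jnp.zeros(u0.shape))
    state = SolverStateSketch(time=jnp.asarray(0.0), step_id=jnp.asarray(0, dtype=jnp.int32))
    return dofs, state


def validate_dynamic(tree) -> None:
    """Raise TypeError if any leaf is not a jax.Array (e.g. a NumPy array or float)."""
    for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
        if not isinstance(leaf, jax.Array):
            raise TypeError(
                f"dynamic leaf {jax.tree_util.keystr(path)} is {type(leaf).__name__}"
            )
```

Catching a stray NumPy array or Python scalar at setup time is useful because such leaves would otherwise be converted on every call or, for Python scalars in static positions, trigger recompilation.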