CPU Setup of JAX static fields
The setup phase builds an immutable PyTree representing the complete scene graph. This PyTree contains:
- Static data: precomputed mesh topology, finite element data, mappings, and constraints (frozen; uploaded to the GPU once and constant thereafter)
- Dynamic data: DOFs, solver state, and learnable parameters (traced by JAX; updated at every step)
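The immutability of the PyTree, combined with clearly identified dynamic leaves and JAX-compatible operations, lets jax.jit trace the time-step function once and gives XLA a complete view of its data dependencies, so it can fuse operations into optimized GPU kernels. All static data is transferred to the GPU once and remains constant throughout the simulation. Dynamic leaves are never modified in place: each time step returns a new PyTree whose dynamic leaves hold the updated values, while the static leaves are reused unchanged.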
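A minimal JAX sketch of this static/dynamic split, assuming a single solver node with a lumped mass matrix and a mask-based projective constraint. The names StaticData, DynamicState, inv_mass and step, and the symplectic Euler update, are illustrative assumptions rather than this project's API; only mask_free_dofs and fixed_values mirror names used in the setup diagram. The step function reads the static leaves, never returns them, and builds a fresh DynamicState instead of mutating the old one.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class StaticData(NamedTuple):
    """Hypothetical container: frozen arrays built once on the CPU, uploaded once."""
    inv_mass: jax.Array        # lumped inverse mass per DOF (assumed)
    mask_free_dofs: jax.Array  # 1.0 where the DOF is free, 0.0 where it is fixed
    fixed_values: jax.Array    # prescribed values, used only on fixed DOFs


class DynamicState(NamedTuple):
    """Hypothetical container: leaves that take new values at every time step."""
    u: jax.Array     # positions or displacements
    v: jax.Array     # velocities
    time: jax.Array  # scalar simulation time


@jax.jit
def step(static: StaticData, state: DynamicState, force: jax.Array, dt: float) -> DynamicState:
    """One symplectic Euler step (assumed integrator); returns a new state."""
    free = static.mask_free_dofs
    # Velocity update, zeroed on fixed DOFs
    v = free * (state.v + dt * static.inv_mass * force)
    # Position update, then projection of fixed DOFs onto their prescribed values
    u = free * (state.u + dt * v) + (1.0 - free) * static.fixed_values
    return DynamicState(u=u, v=v, time=state.time + dt)


# Static data is placed on the device once and reused by every call to `step`.
static = jax.device_put(StaticData(
    inv_mass=jnp.ones(4),
    mask_free_dofs=jnp.array([0.0, 1.0, 1.0, 1.0]),
    fixed_values=jnp.zeros(4),
))
state = DynamicState(u=jnp.arange(4.0), v=jnp.zeros(4), time=jnp.array(0.0))
for _ in range(3):
    state = step(static, state, jnp.ones(4), 0.1)
```

The static arrays are passed as ordinary traced arguments rather than through jax.jit's static_argnums: static arguments must be hashable and trigger a recompilation whenever their value changes, which suits small metadata (element type, number of integration points) but not large mesh arrays.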
Setup phase (CPU)
sequenceDiagram
autonumber
actor User as User Script
participant Builder as SceneBuilder
participant Node as SceneNode
participant Mesh as MeshTopology
participant FE as FEData
participant Graph as GraphData
participant Map as Mapping
participant BC as ProjectiveBC
participant Field as PhysicalField
participant Tree as PyTree
note over User,Tree: Goal: Build immutable PyTree scene graph with all static data precomputed on CPU
User->>Builder: create scene
rect rgba(227,242,253,0.25)
note over Builder,Node: A) Build scene graph hierarchy
Builder->>Builder: create root
loop for each node in scene description
Builder->>Node: instantiate SolverNode/MappedNode
Builder->>Node: set parent/children relationships
Node-->>Builder: node created
end
end
rect rgba(232,245,233,0.25)
note over Builder,Graph: B) Precompute mesh & topology data
loop for each SceneNode
Builder->>Node: setup topology
Node->>Mesh: load mesh
Mesh->>Mesh: precompute boundary indices
Mesh->>FE: precompute FE data
FE->>FE: compute jacobian_det at integration points
FE->>FE: compute shape_fn_derivatives
FE-->>Mesh: FiniteElementData ready
Mesh->>Graph: build graph adjacency
Graph->>Graph: compute node-node/edge connectivity
Graph-->>Mesh: GraphData ready
Mesh-->>Node: MeshTopology ready
Node-->>Builder: mesh attached
end
end
rect rgba(255,243,224,0.25)
note over Builder,Map: C) Precompute mappings
loop for each MappedNode
Builder->>Node: setup mapping
Node->>Map: compute mapping to parent
Map->>Map: identify parent_node_indices
Map->>Map: compute weights
Map-->>Node: Mapping ready
Node-->>Builder: mapping attached
end
end
rect rgba(255,235,238,0.25)
note over Builder,BC: D) Build projective constraints
loop for each SolverNode
Builder->>Node: setup boundary conditions
Node->>BC: create projective constraints
BC->>Mesh: query boundary nodes from mesh
Mesh-->>BC: point_data['bc']
BC->>BC: build mask_free_dofs
BC->>BC: set fixed_values
BC-->>Node: ProjectiveBC ready
Node-->>Builder: bc attached
end
end
rect rgba(237,231,246,0.25)
note over Builder,Field: E) Initialize physical fields
loop for each SceneNode
Builder->>Node: setup fields
Node->>Field: create PhysicalField/ConstraintField
Field->>Mesh: precompute field data from mesh
Mesh-->>Field: point_data / cell_data
Field->>Field: initialize static parameters
Field->>Field: allocate dynamic parameters (θ) as zeros
Field-->>Node: PhysicalField ready
Node-->>Builder: fields attached
end
end
rect rgba(255,248,225,0.25)
note over Builder,Tree: F) Initialize dynamic state
loop for each SceneNode
Builder->>Node: create NodeDOFs (u=u0, v=0, a=0)
end
loop for each SolverNode
Builder->>Node: create SolverState (time=0, step_id=0)
end
end
Builder->>Tree: validate PyTree structure
Tree->>Tree: check all static leaves are frozen
Tree->>Tree: check all dynamic leaves are jnp.ndarray
Tree->>Tree: verify no circular references
Tree-->>Builder: PyTree valid
Builder-->>User: return immutable scene PyTree
note over User,Tree: Scene ready for JIT compilation & GPU execution
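Step B of the diagram can be sketched in NumPy, since it runs once on the CPU. This sketch assumes linear (P1) tetrahedra, whose shape-function gradients are constant per cell, so a single integration point per cell suffices; the function names and array layouts are hypothetical, and boundary-index extraction is omitted. It computes jacobian_det and shape_fn_derivatives per cell, plus a unique edge list as one possible form of the graph adjacency.

```python
import numpy as np

# Gradients of the four linear (P1) shape functions on the reference tetrahedron:
# N0 = 1 - xi - eta - zeta, N1 = xi, N2 = eta, N3 = zeta.
# Rows: local nodes; columns: d/dxi, d/deta, d/dzeta.
DN_REF = np.array([[-1.0, -1.0, -1.0],
                   [1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]])

# The six edges of a tetrahedron as pairs of local node indices.
TET_EDGES = np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]])


def precompute_fe_data(points, cells):
    """Hypothetical helper. points: (n_points, 3) floats; cells: (n_cells, 4) ints.

    Returns jacobian_det (n_cells,) and shape_fn_derivatives (n_cells, 4, 3).
    """
    X = points[cells]                            # (n_cells, 4, 3) nodal coordinates
    J = np.einsum("cai,aj->cij", X, DN_REF)      # J[c, i, j] = dx_i / dxi_j
    jacobian_det = np.linalg.det(J)              # equals 6 * signed cell volume
    if np.any(jacobian_det <= 0.0):
        raise ValueError("degenerate or inverted tetrahedron")
    J_inv = np.linalg.inv(J)                     # J_inv[c, j, i] = dxi_j / dx_i
    # Chain rule: dN_a/dx_i = sum_j dN_a/dxi_j * dxi_j/dx_i
    shape_fn_derivatives = np.einsum("aj,cji->cai", DN_REF, J_inv)
    return jacobian_det, shape_fn_derivatives


def edge_connectivity(cells):
    """Hypothetical helper: unique undirected edges (n_edges, 2)."""
    edges = cells[:, TET_EDGES].reshape(-1, 2)   # every cell contributes 6 edges
    return np.unique(np.sort(edges, axis=1), axis=0)


# Two tetrahedra sharing the face (1, 2, 3)
points = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
cells = np.array([[0, 1, 2, 3], [1, 2, 3, 4]])
jacobian_det, shape_fn_derivatives = precompute_fe_data(points, cells)
edges = edge_connectivity(cells)
```

Because the shape functions sum to one, their gradients sum to zero in every cell, which makes a cheap sanity check on the precomputed data.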
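Step C of the diagram can likewise be sketched under an assumed barycentric mapping, in which every child point lies inside a known parent tetrahedron. The child_cell_ids input is a hypothetical stand-in for a point-location search that a real setup would perform; parent_node_indices and weights follow the names in the diagram. The gather-and-weight expression in apply_mapping also works unchanged with jax.numpy at run time.

```python
import numpy as np


def barycentric_mapping(points, cells, child_points, child_cell_ids):
    """Hypothetical helper: embed each child point in a given parent tetrahedron.

    Returns parent_node_indices (n_child, 4) and weights (n_child, 4).
    """
    parent_node_indices = cells[child_cell_ids]          # (n_child, 4)
    X = points[parent_node_indices]                      # (n_child, 4, 3)
    # Columns of T are the edge vectors x1 - x0, x2 - x0, x3 - x0
    T = np.transpose(X[:, 1:, :] - X[:, :1, :], (0, 2, 1))
    rhs = child_points - X[:, 0, :]                      # (n_child, 3)
    lam = np.linalg.solve(T, rhs[..., None])[..., 0]     # weights of nodes 1..3
    weights = np.concatenate([1.0 - lam.sum(axis=1, keepdims=True), lam], axis=1)
    return parent_node_indices, weights


def apply_mapping(u_parent, parent_node_indices, weights):
    """u_child[k] = sum_a weights[k, a] * u_parent[parent_node_indices[k, a]]."""
    return np.einsum("ka,kad->kd", weights, u_parent[parent_node_indices])


points = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
cells = np.array([[0, 1, 2, 3]])
child_points = np.array([[0.25, 0.25, 0.25], [1.0, 0.0, 0.0]])
child_cell_ids = np.array([0, 0])  # assumed known; normally from a point search
parent_node_indices, weights = barycentric_mapping(points, cells, child_points, child_cell_ids)
```

Mapping the parent positions themselves reproduces the child points, since barycentric weights interpolate linear fields exactly.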
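Step F and the final PyTree validation can be sketched with plain NamedTuples, which JAX already treats as PyTree nodes. NodeDOFs, SolverState, init_node_dofs, init_solver_state and validate_scene are hypothetical names; reading "frozen" as "a JAX array or a read-only NumPy array" is an assumption about what the check means, and the circular-reference check is not shown.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class NodeDOFs(NamedTuple):   # hypothetical
    u: jax.Array  # initial value u0
    v: jax.Array  # starts at zero
    a: jax.Array  # starts at zero


class SolverState(NamedTuple):  # hypothetical
    time: jax.Array
    step_id: jax.Array


def init_node_dofs(u0) -> NodeDOFs:
    u = jnp.asarray(u0)
    return NodeDOFs(u=u, v=jnp.zeros_like(u), a=jnp.zeros_like(u))


def init_solver_state() -> SolverState:
    return SolverState(time=jnp.array(0.0), step_id=jnp.array(0, dtype=jnp.int32))


def validate_scene(static, dynamic) -> None:
    """Checks two of the diagram's rules (cycle detection is not shown)."""
    # Every dynamic leaf must be a JAX array so that jit can trace it.
    for leaf in jax.tree_util.tree_leaves(dynamic):
        if not isinstance(leaf, jax.Array):
            raise TypeError(f"dynamic leaf has type {type(leaf).__name__}")
    # Static leaves must be immutable: JAX arrays are; NumPy arrays must be read-only.
    for leaf in jax.tree_util.tree_leaves(static):
        if isinstance(leaf, np.ndarray) and leaf.flags.writeable:
            raise ValueError("static NumPy leaf is writeable")


cells = np.array([[0, 1, 2, 3]])
cells.flags.writeable = False  # freeze the CPU-side array
static = {"cells": cells}
dynamic = {"dofs": init_node_dofs(np.zeros((4, 3))), "solver": init_solver_state()}
validate_scene(static, dynamic)  # passes: no exception raised
```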