Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 20:06:05 +00:00)

Most of the functionality is for the JAX native serialization case. It relies on newly added functionality in xla_extension.refine_polymorphic_shapes that handles the @static_assertion custom calls. As a beneficial side effect, we now also get shape constraint checking for jax2tf graph serialization when the resulting function is executed in graph mode.