rocm_jax/jax/_src/interpreters
George Necula e643f98558 [shape_poly] Reimplement the shape constraint checking using shape assertions.
Most of the functionality is for the JAX native serialization case.
This relies on newly added functionality to xla_extension.refine_polymorphic_shapes
that handles custom calls @static_assertion.

As a beneficial side-effect now we get shape constraint checking for jax2tf
graph serialization when the resulting function is executed in graph mode.
2023-07-19 09:56:33 +03:00
..
2023-06-23 15:12:14 -07:00