George Necula ec8b855fa1 [shape_poly] Add a polymorphic shape refinement MLIR pass accessible to JAX Python.
At the moment, a StableHLO module lowered by jax2tf with polymorphic
shapes can be run only through jax2tf, because the tf.XlaCallModule op
contains the shape refinement logic needed to legalize a StableHLO
module with dynamic shapes to MHLO. Here we expose the shape refinement
MLIR transformation to JAX Python.

For now this is used only in a test in jax_export_test.py.

PiperOrigin-RevId: 537485288
2023-06-02 21:49:20 -07:00