mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 21:16:05 +00:00

Previously, native serialization supported only `polymorphic_shapes` specifications in which each dimension was a simple dimension variable. For example, we could not handle `polymorphic_shapes="2*b"`, because there was no way to recover the value of `b` from the actual shape. (For non-native serialization we already supported some limited equation solving.) This matters, e.g., for the gradient of functions like `jnp.concatenate([x, x])`: for `x` of shape `b`, the output, and hence the cotangent fed to the gradient, has shape `2*b`. This is now possible because in #15258 we moved the computation of the dimension variables into jax_export. Here we bring native serialization up to the same level of support as non-native serialization by reusing the same `shape_poly.prepare_dim_var_env` that non-native serialization uses. After this lands, we will refactor the shape environment to make it cleaner.
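To illustrate the kind of equation solving involved, here is a minimal, hypothetical sketch in plain Python. It is not `shape_poly.prepare_dim_var_env` and uses no JAX API. The helper name `solve_dim_vars` and its input format are assumptions. It takes one specification string per dimension, each either an integer constant or a linear term `k*v + c` in a single dimension variable `v`, and recovers `v` from the actual dimension size by exact division. For a specification `2*b` and an actual shape `(6,)` it binds `b = 3`.

```python
import re
from typing import Dict, Sequence

# Hypothetical format: a linear dimension expression "k*v + c", where the
# coefficient k and the offset c are optional non-negative integer literals.
_LINEAR = re.compile(r"^\s*(?:(\d+)\s*\*\s*)?([A-Za-z_]\w*)\s*(?:\+\s*(\d+))?\s*$")
_CONST = re.compile(r"^\s*(\d+)\s*$")


def solve_dim_vars(specs: Sequence[str], shape: Sequence[int]) -> Dict[str, int]:
    """Recovers dimension-variable values from a concrete shape (sketch only)."""
    if len(specs) != len(shape):
        raise ValueError(f"rank mismatch: {len(specs)} specs vs shape {tuple(shape)}")
    env: Dict[str, int] = {}
    for spec, size in zip(specs, shape):
        const = _CONST.match(spec)
        if const:
            # A constant dimension must match the actual size exactly.
            if int(const.group(1)) != size:
                raise ValueError(f"expected dimension {spec}, got {size}")
            continue
        m = _LINEAR.match(spec)
        if m is None:
            raise ValueError(f"unsupported dimension expression: {spec!r}")
        k = int(m.group(1) or 1)
        var = m.group(2)
        c = int(m.group(3) or 0)
        if k == 0:
            raise ValueError(f"zero coefficient in {spec!r}")
        # Solve k*var + c == size for a positive integer var.
        quotient, remainder = divmod(size - c, k)
        if remainder != 0 or quotient < 1:
            raise ValueError(f"cannot solve {spec!r} == {size} for a positive integer {var}")
        # The same variable may appear in several dimensions; values must agree.
        if env.setdefault(var, quotient) != quotient:
            raise ValueError(f"inconsistent values for {var}: {env[var]} vs {quotient}")
    return env


print(solve_dim_vars(["2*b"], [6]))
print(solve_dim_vars(["b", "2*b"], [3, 6]))
```

Rejecting non-exact divisions and conflicting bindings is one reasonable design choice for such a sketch: a mismatch between the specification and the actual shape becomes an error instead of a silently truncated value.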