Previously for native serialization we could only support polymorphic_shapes
where the specification was a simple dimension variable. E.g., we could not
handle a specification where `polymorphic_shapes="2*b"` because there was
no way to recover the value of `b` from the actual shape. (For non-native
serialization we were supporting some limited equation solving.)
The above is important, e.g., for the gradient of functions like
`jnp.concatenate([x, x])`, where the output shape if `2 *b`.
This is possible because in #15258 we have brought the computation
of the dimension variables into jax_export.
What we do here is to even out the support for native serialization to have
the same power as the non-native one. We do this by reusing the
same `shape_poly.prepare_dim_var_env` that we use for non-native
serialization.
After we land this, we will refactor the shape environment to be cleaner.
pytype cannot tell from the type signature that unique() returns an Array, not a tuple. Add a cast to help it along.
It's possible that a future use of @overload on the definition of jnp.unique might help.
PiperOrigin-RevId: 520389675
I forgot that the default setting is actually in jaxlib:
fbe9a80fdb/xla/python/xla_client.py (L135)
To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
Currently, JAX native serialization produces a module whose main function
takes additional arguments for the values of the dimension variables. These
values are then resolved in the XlaCallModule based on a dim_args_spec
parameter.
We move the code that computes the dimension variables from XlaCallModule to
jax_export following pretty much the same technique. This simplifies
XlaCallModule and especially its API (the dim_args_spec). So far this
is just a refactoring with no semantic changes, but this will allow us
to improve the support for dimension variables that occur in linear
polynomials, e.g., "2*b" rather than just "b".
JAX will aggressively drop module input arguments if they are not
used. This can interfere with shape polymorphism, because it may
result in dropping arguments from which we need to derive the
values of shape variables.
We fix this for now by disabling dropping arguments if there
are dimension variables in the arguments shapes. A more precise
technique would be to force keeping only of arguments that we
need for deriving the dimension variables. However, that would be
a much more involved change, for an uncertain benefit.
We want to allow using native_serialization_platforms even if the native_serialization is False. This is useful for code that is runnable with and without native serialization.
PiperOrigin-RevId: 519649827
Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels.
Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency.