This fixes several bugs in presence of equality constraints where
the left-hand side is just a dimension variable.
First, such constraints were not applied when parsing variables.
Now, with a constraint `a == b` when we parse "a" we obtain `b`.
Second, when we evaluate symbolic dimensions that contain
dimension variables that are constrained to be equal to something
else, we may fail to find the dimension variable in the environment
because the environment construction has applied the constraints.
We fix this by looking up the unknown dimension variable in
the equality constraints.
Fixes: #23437Fixes: #23456
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.
For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.
PiperOrigin-RevId: 647657084
We take the opportunity of a new jax.export package to rename some
of the API entry points:
* `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
because this is more accurate. The dimension variables are global
constants, but so is the platform index. And we need to run
global constant propagation and shape refinement for all of these.
* We rename "serialization version" with "calling convention version".
Hence we now have `Exported.calling_convention_version`,
and the configuration flag is renamed from `--jax-serialization-version`
to `--jax-export-calling-convention-version`. Also,
`jax.export.minimum_supported_serialization_version` is now
`jax.export.minimum_supported_calling_convention_version`.
* We rename `lowering_platforms` to `platforms` both as a field
of `Exported` and as the kwarg to `export.export`.
* We rename `jax.export.default_lowering_platform` to `jax.export.default_export_version`.