Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.
I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.
Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
This is a follow up from #25640 that enabled lowering with
AbstractMesh.
This required adding `num_devices` to `lowering.compiler_args`
because in presence of an AbstractMesh the device_assignment
is not accurate.
With jax.experimental.export gone we can now do some cleanup in the export module.
In particular we remove the `export.args_spec` API, and the `lowering_platforms` arg for `export.export`. These were deprecated in June 2024.
PiperOrigin-RevId: 692398132
This fixes several bugs in presence of equality constraints where
the left-hand side is just a dimension variable.
First, such constraints were not applied when parsing variables.
Now, with a constraint `a == b` when we parse "a" we obtain `b`.
Second, when we evaluate symbolic dimensions that contain
dimension variables that are constrained to be equal to something
else, we may fail to find the dimension variable in the environment
because the environment construction has applied the constraints.
We fix this by looking up the unknown dimension variable in
the equality constraints.
Fixes: #23437Fixes: #23456
This allows lowering of threefry2x32 for GPU even on a machine without GPUs.
For the next 3 weeks, we only use the new custom call implementation if
we are not in "export" mode, and if we use a new jaxlib.
PiperOrigin-RevId: 647657084
We take the opportunity of a new jax.export package to rename some
of the API entry points:
* `Exported.uses_shape_polymorphism` is renamed to `Exported.uses_global_constants`
because this is more accurate. The dimension variables are global
constants, but so is the platform index. And we need to run
global constant propagation and shape refinement for all of these.
* We rename "serialization version" with "calling convention version".
Hence we now have `Exported.calling_convention_version`,
and the configuration flag is renamed from `--jax-serialization-version`
to `--jax-export-calling-convention-version`. Also,
`jax.export.minimum_supported_serialization_version` is now
`jax.export.minimum_supported_calling_convention_version`.
* We rename `lowering_platforms` to `platforms` both as a field
of `Exported` and as the kwarg to `export.export`.
* We rename `jax.export.default_lowering_platform` to `jax.export.default_export_version`.