In presence of ordered effects JAX lowering produces a main
function that takes token
inputs and returns token outputs. Previously, when exporting
such a module, we would wrap the main function with a function
that does not use tokens on inputs and outputs. With this
change we actually leave the token inputs and outputs and
rely on consumers of the exported function to know how to
invoke a function with tokens.
Due to the fact that PJRT does not support passing tokens
as input and output to the top-level function, JAX native
lowering uses dummy bool[0] arrays in lieu of tokens for
the top-level function, and uses stablehlo tokens for the
inner functions. When we export a function for serialization
we want to use stablehlo tokens even at top-level, to enable
calling that function from a larger JAX computation later.
See more details about the calling convention in the
docstring for `export.export`.
We also fix and test multi-platform lowering in presence
of effects.
This introduces serialization version 9, but does not change the
default serialization version. This means that version 9 will not
be used except in tests that specifically override the
serialization version.
This should fix the CI failure with older TPU (the oldest supported TPU should be updated to 20230912 as well).
Tested with:
```
pip install --pre libtpu-nightly==0.1.dev.20230912 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
ENABLE_PJRT_COMPATIBILITY=true python -c "import jax; print(jax.devices());"
ENABLE_PJRT_COMPATIBILITY=true python tests/api_test.py
```
PiperOrigin-RevId: 575268243
In presence of shape polymorphism and multi-platorm lowering
we pass the global values for the dimension variables and the
platform index to all inner functions. At the moment, prior to
compilation we run a shape refinement pass to infer which of
the arguments of a function carry such global values.
This inference can yield false positives, e.g., when a
user-defined function is called with a constant int32 as the first argument.
With this change we do not need to infer anymore the arguments
that carry global constants. This is in preparation for a more
reliable implementation of shape refinement.
Previously we declared the lowering rule for call_exported to be
platform specific. This was correct, but in the case when the
caller function is lowered itself for multiple platforms this results
in multiple copies of the inner called Exported. Now instead we
make the call_exported rule be platform independent and make it
compute the platform index for the called module based on the
platform index in the caller module. This results in a single
copy of the HLO for the called module in the output.
This attribute was used to support shape polymorphism in versions
up to and including version 4. Starting on March 28th 2023 with
JAX version 0.4.6 we stopped using this attribute. We are now
beyond the 6 month backward compatibility version and we drop
support for this attribute.
We also increase the minimum supported serialization version to 5.
See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
PiperOrigin-RevId: 574450204
The longer term goal here is to move away from having the config object as
part of the public API and migrate towards module-level functions instead.
Note that we can preserve the dynamic attribute lookup behavior of the
config object via a module-level `__getattr__`
This ended up accidentally setting up filters for jax_triton. This change additionally
adds an opt-in mechanism for paths, that overrides exclusions. We use this to avoid
treating pallas ops implementations as JAX-internal.
PiperOrigin-RevId: 574167963