[export] Add documentation for jax.export

This commit is contained in:
George Necula 2024-05-31 15:14:09 +03:00
parent ad9f35ae53
commit 105cc9a103
12 changed files with 755 additions and 133 deletions


@ -136,6 +136,7 @@ jobs:
- name: Test documentation
env:
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
JAX_TRACEBACK_FILTERING: "off"
JAX_ARRAY: 1
PY_COLORS: 1
run: |


@ -35,8 +35,6 @@ way. An example:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4
@ -66,6 +64,10 @@ Array(10, dtype=int32, weak_type=True)
```
Note that the lowered objects can be used only in the same process
in which they were lowered. For exporting use cases,
see the {ref}`export` APIs.
See the {mod}`jax.stages` documentation for more details on what functionality
the lowering and compiled functions provide.
@ -139,11 +141,16 @@ module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 :
return %0 : tensor<i32>
}
}
>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)
```
The result of `lower` is not safe to serialize directly for use
in a different process.
See {ref}`export` for additional APIs for this purpose.
Note that `lower` here takes two arguments as usual, but the subsequent compiled
function accepts only the remaining non-static second argument. The static first
argument (value 7) is taken as a constant at lowering time and built into the
@ -194,7 +201,7 @@ Array([[ 1., 5., 9.],
>>> jax.vmap(g_aot)(zs) # doctest: +SKIP
Traceback (most recent call last):
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax.interpreters.batching.BatchTracer'>.
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>
```

docs/export/export.md Normal file

@ -0,0 +1,633 @@
# Exporting and serializing staged-out computations
The {ref}`ahead-of-time-lowering` APIs produce
objects that can be used for debugging or for compilation and
execution in the same process.
Sometimes you want to serialize a lowered JAX function for
compilation and execution in a separate process, perhaps
at a later time. This would allow you to:
* compile and execute the function in another process or machine
without requiring access to the JAX program,
and without having to repeat the staging-out and lowering, e.g.,
in an inference system.
* trace and lower a function on a machine that does not have access
to the accelerator for which you want to later compile and execute
the function.
* archive a snapshot of a JAX function, e.g., to be able to
reproduce your results later. **Note:** check out the [compatibility
guarantees](#compatibility-guarantees) for this use case.
Here is an example:
```python
>>> import re
>>> import numpy as np
>>> import jax
>>> from jax import export
>>> def f(x): return 2 * x * x
>>> exported: export.Exported = export.export(jax.jit(f))(
... jax.ShapeDtypeStruct((), np.float32))
>>> # You can inspect the Exported object
>>> exported.fun_name
'f'
>>> exported.in_avals
(ShapedArray(float32[]),)
>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0))
func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"} loc("x")) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
>>> # And you can serialize the Exported to a bytearray.
>>> serialized: bytearray = exported.serialize()
>>> # The serialized function can later be rehydrated and called from
>>> # another JAX computation, possibly in another process.
>>> rehydrated_exp: export.Exported = export.deserialize(serialized)
>>> rehydrated_exp.in_avals
(ShapedArray(float32[]),)
>>> def callee(y):
... return 3. * rehydrated_exp.call(y * 4.)
>>> callee(1.)
Array(96., dtype=float32)
```
Serialization is broken down into two stages:
1. exporting to produce an {class}`jax.export.Exported` object that contains
the StableHLO for the lowered function along with the metadata necessary to
call it from another JAX function. We have plans to add code to generate
`Exported` objects from TensorFlow, and to use `Exported` objects from
TensorFlow and PyTorch.
2. the actual serialization to a byte array using the flatbuffers format.
See {ref}`jax2tf` for
an alternative serialization to TensorFlow graph that can be used
for interoperation with TensorFlow.
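Putting the two stages together, here is a minimal sketch that persists the
serialized bytes to a file and rehydrates them later, e.g., from a different
process. The file path used here is hypothetical and only for illustration:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

# Stage 1: export to an `Exported` object.
exported = export.export(jax.jit(jnp.sin))(np.arange(4, dtype=np.float32))

# Stage 2: serialize to a byte array and write it out for later use,
# possibly in a separate process.
with open("/tmp/sin_exported.bin", "wb") as file:
  file.write(exported.serialize())

# Later, possibly in another process: read the bytes back and rehydrate.
with open("/tmp/sin_exported.bin", "rb") as file:
  rehydrated = export.deserialize(bytearray(file.read()))
print(rehydrated.call(np.arange(4, dtype=np.float32)))
```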
## Support for reverse-mode AD
Serialization can optionally support higher-order reverse-mode AD. This is done
by serializing the {func}`jax.vjp` of the primal function along with the primal function,
up to a user-specified order (default is 0, meaning that the rehydrated
function cannot be differentiated):
```python
>>> import jax
>>> from jax import export
>>> from typing import Callable
>>> def f(x): return 7 * x * x * x
>>> # Serialize 3 levels of VJP along with the primal function
>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3)
>>> rehydrated_f: Callable = export.deserialize(blob).call
>>> rehydrated_f(0.1) # 7 * 0.1^3
Array(0.007, dtype=float32)
>>> jax.grad(rehydrated_f)(0.1) # 7*3 * 0.1^2
Array(0.21000001, dtype=float32)
>>> jax.grad(jax.grad(rehydrated_f))(0.1) # 7*3*2 * 0.1
Array(4.2, dtype=float32)
>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1) # 7*3*2
Array(42., dtype=float32)
>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: No VJP is available
```
Note that the VJP function is computed lazily while serializing,
when the JAX program is still available.
This means that it respects all features of JAX VJP,
e.g., {func}`jax.custom_vjp` and {func}`jax.remat`.
Note that the rehydrated function does not support any other
transformations, e.g., forward-mode AD (jvp), or {func}`jax.vmap`.
## Compatibility guarantees
You should not use the raw StableHLO that is obtained from just lowering
(`jax.jit(f).lower(1.).compiler_ir()`)
for archival or for compilation in another process, for several reasons.
First, the compilation may use a different version of the compiler, supporting a
different version of StableHLO. The {mod}`jax.export` module takes
care of this by using the
[portable-artifact feature of StableHLO](https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md)
to deal with the possible evolution of the StableHLO opset.
### Compatibility guarantees for custom calls
Second, the raw StableHLO may contain custom calls referencing C++
functions.
JAX uses custom calls for lowering of a small number of primitives,
e.g., linear algebra primitives, sharding annotations, or Pallas kernels.
These do not fall under the compatibility guarantees for StableHLO.
The C++ implementations of these functions change rarely, but they can change.
`jax.export` makes the following export compatibility guarantees:
A JAX exported artifact can be compiled and executed by a compiler and
JAX runtime system that are:
* **up to 6 months newer** than the version of JAX used for exporting
(we say that JAX export offers **6 months backward compatibility**).
This is useful if we want to archive the exported artifact to be compiled and executed later.
* **up to 3 weeks older** than the version of JAX used for exporting
(we say that JAX export offers **3 weeks forward compatibility**).
This is useful if we want to compile and run an exported artifact with a
consumer that was built and deployed before the export, e.g.,
an inference system that is already deployed when the exporting is done.
(The particular compatibility window lengths are the same that JAX
[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model),
and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow).
The terminology “backward compatibility” is from the perspective of the consumer,
e.g., the inference system.)
What **matters is when the exporting and consuming components were built**,
not the time when the exporting and the compilation happen.
For external JAX users, it is
[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned);
what matters is when the jaxlib release was built.
To reduce chances of incompatibility, internal JAX users should:
* **rebuild and redeploy consumer systems as frequently as possible**.
and external users should:
* run the exporting and consumer systems with the same version of jaxlib, whenever possible, and
* export for archival **with the latest released version of jaxlib**.
The compatibility guarantees do not apply if you bypass the `jax.export` APIs
to obtain the StableHLO code.
Only a subset of custom calls are guaranteed stable and have
compatibility guarantees ([see list](https://github.com/search?q=repo%3Agoogle%2Fjax%20_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE&type=code)).
We continuously
add more custom call targets to the allowed list along with backwards
compatibility tests. If you try to serialize
code that invokes other custom call targets you will get an error
during exporting.
If you want to disable this safety check for a specific custom call,
e.g., with target `my_target`, you can add
`export.DisabledSafetyCheck.custom_call("my_target")` to the
`disabled_checks` parameter of the `export` method,
as in the following example:
```python
>>> import jax
>>> from jax import export
>>> from jax import lax
>>> from jax._src.interpreters import mlir
>>> # override the lowering rule for sin to use a custom call `my_new_sin`
>>> _ = mlir.register_lowering(lax.sin_p, lambda ctx, o: mlir.custom_call("my_new_sin", operands=[o], result_types=[o.type]).results)
>>> print(jax.jit(lax.sin).lower(1.).compiler_ir())
module @jit_sin attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.custom_call @my_new_sin(%arg0) {api_version = 2 : i32} : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
}
>>> # If we try to export, we get an error
>>> export.export(jax.jit(lax.sin))(1.) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_sin
>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call`
>>> exp = export.export(
... jax.jit(lax.sin),
... disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_sin")])(1.)
```
## Cross-platform and multi-platform export
JAX lowering is platform specific for a small number of JAX primitives.
By default, the code is lowered and exported for the accelerator
present on the exporting machine:
```python
>>> from jax import export
>>> export.default_lowering_platform()
'cpu'
```
There is a safety check that will raise an error when trying to compile
an `Exported` object on a machine that does not have the accelerator
for which the code was exported.
You can specify explicitly for what platforms the code should be exported.
This allows you to specify a different accelerator than you have
available at export time,
and it even allows you to specify multi-platform export to
obtain an `Exported` object that can be compiled and executed
on multiple platforms.
```python
>>> import jax
>>> from jax import export
>>> from jax import lax
>>> # You can specify the lowering_platform, e.g., `tpu`, `cpu`, `cuda`, `rocm`
>>> # even if the current machine does not have that accelerator.
>>> exp = export.export(jax.jit(lax.cos), lowering_platforms=['tpu'])(1.)
>>> # But you will get an error if you try to compile `exp`
>>> # on a machine that does not have TPUs.
>>> exp.call(1.) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: The exported function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.
>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
>>> # parameter to `export`, e.g., because you have reasons to believe
>>> # that the code lowered will run adequately on the current
>>> # compilation platform (which is the case for `cos` in this
>>> # example):
>>> exp_unsafe = export.export(jax.jit(lax.cos),
... lowering_platforms=['tpu'],
... disabled_checks=[export.DisabledSafetyCheck.platform()])(1.)
>>> exp_unsafe.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)
>>> # and similarly with multi-platform lowering
>>> exp_multi = export.export(jax.jit(lax.cos),
... lowering_platforms=['tpu', 'cpu', 'cuda'])(1.)
>>> exp_multi.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)
```
For multi-platform export, the StableHLO will contain multiple
lowerings but only for those primitives that require it, so the
resulting module size should be only marginally larger than the
size of a module with default export.
As an extreme case, when serializing a module without any
primitives with platform-specific lowering, you will get
the same StableHLO as for the single-platform export.
```python
>>> import jax
>>> from jax import export
>>> import jax.numpy as jnp
>>> # A largish function
>>> def f(x):
... for i in range(1000):
... x = jnp.cos(x)
... return x
>>> exp_single = export.export(jax.jit(f))(1.)
>>> len(exp_single.mlir_module_serialized) # doctest: +SKIP
9220
>>> exp_multi = export.export(jax.jit(f),
... lowering_platforms=["cpu", "tpu", "cuda"])(1.)
>>> len(exp_multi.mlir_module_serialized) # doctest: +SKIP
9282
```
## Shape polymorphic export
When used in JIT mode, JAX will trace and lower a function separately
for each combination of input shapes. When exporting, it is possible
in some cases to use dimension variables for some input dimensions
in order to obtain an exported artifact that can be used with multiple
combinations of input shapes.
See the {ref}`shape_poly` documentation.
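As a rough sketch (see the {ref}`shape_poly` documentation for the full API),
the following uses {func}`jax.export.symbolic_shape` to export a function once
for a symbolic batch dimension `b`; the function `f` below is a made-up example:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

def f(x):  # x: f32[b, 4], where "b" is a dimension variable
  return jnp.sum(x ** 2, axis=0)

# Export once with a symbolic leading dimension "b".
poly_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), np.float32)
exp = export.export(jax.jit(f))(poly_spec)

# The same exported artifact accepts different batch sizes.
print(exp.call(np.ones((3, 4), dtype=np.float32)))
print(exp.call(np.ones((10, 4), dtype=np.float32)))
```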
## Device polymorphic export
An exported artifact may contain sharding annotations for inputs,
outputs and for some intermediates, but these annotations do not refer
directly to the actual physical devices that existed at exporting time.
Instead, the sharding annotations refer to logical devices. This
means that you can compile and run the exported artifact on
physical devices different from the ones used for exporting.
```python
>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> import jax.numpy as jnp
>>> import numpy as np
>>> # Use the first 4 devices for exporting.
>>> export_devices = jax.local_devices()[:4]
>>> export_mesh = Mesh(export_devices, ("a",))
>>> def f(x):
... return x.T
>>> arg = jnp.arange(8 * len(export_devices))
>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
>>> # `exp` knows for how many devices it was exported.
>>> exp.nr_devices
4
>>> # and it knows the shardings for the inputs. These will be applied
>>> # when the exported function is called.
>>> exp.in_shardings_hlo
({devices=[4]<=[4]},)
>>> res1 = exp.call(jax.device_put(arg,
... NamedSharding(export_mesh, P("a"))))
>>> # Check out the first 2 shards of the result
>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
['device=TFRT_CPU_0 index=(slice(0, 8, None),)',
'device=TFRT_CPU_1 index=(slice(8, 16, None),)']
>>> # We can call `exp` with some other 4 devices and another
>>> # mesh with a different shape, as long as the number of devices is
>>> # the same.
>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c"))
>>> res2 = exp.call(jax.device_put(arg,
... NamedSharding(other_mesh, P("b"))))
>>> # Check out the first 2 shards of the result. Notice that the output is
>>> # sharded similarly; this means that the input was resharded according to the
>>> # exp.in_shardings_hlo.
>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]]
['device=TFRT_CPU_2 index=(slice(0, 8, None),)',
'device=TFRT_CPU_3 index=(slice(8, 16, None),)']
```
It is an error to try to invoke an exported artifact with a different number
of devices than it was exported for:
```python
>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> import jax.numpy as jnp
>>> import numpy as np
>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
... return x.T
>>> arg = jnp.arange(4 * len(export_devices))
>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
>>> exp.call(arg) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.
```
There are helper functions to shard the inputs for calling an exported
artifact using a new mesh constructed at the call site:
```python
>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> import jax.numpy as jnp
>>> import numpy as np
>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
... return x.T
>>> arg = jnp.arange(4 * len(export_devices))
>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))
>>> # Shard the arg according to what `exp` expects.
>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
>>> res = exp.call(sharded_arg)
```
As a special facility, if a function was exported for 1 device and if it contains no
sharding annotations, then it can be invoked on an argument of the same shape but sharded
on multiple devices, and the compiler will shard the function appropriately:
```python
>>> import jax
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P
>>> import jax.numpy as jnp
>>> def f(x):
... return jnp.cos(x)
>>> arg = jnp.arange(4)
>>> exp = export.export(jax.jit(f))(arg)
>>> exp.in_avals
(ShapedArray(int32[4]),)
>>> exp.nr_devices
1
>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",))
>>> # Shard the arg according to what `exp` expects.
>>> sharded_arg = jax.device_put(arg,
... NamedSharding(calling_mesh, P("b")))
>>> res = exp.call(sharded_arg)
```
## Module serialization versions
The JAX export support has evolved over time, e.g., to support
effects. In order to support compatibility (see [compatibility guarantees](#compatibility-guarantees))
we maintain a serialization version for each `Exported`.
As of June 2024, all modules are serialized with version 9
(the latest, see [all serialization versions](#serialization-version-numbers)):
```python
>>> import jax.numpy as jnp
>>> from jax import export
>>> exp: export.Exported = export.export(jnp.cos)(1.)
>>> exp.mlir_module_serialization_version
9
```
At any given time, the export APIs may support a range
of serialization versions. You can control which serialization
version to use with the `--jax-serialization-version` flag
or the `JAX_SERIALIZATION_VERSION` environment variable:
```python
>>> import jax.numpy as jnp
>>> from jax import export
>>> (export.minimum_supported_serialization_version, export.maximum_supported_serialization_version)
(9, 9)
>>> from jax._src import config
>>> with config.jax_serialization_version(9):
... exp = export.export(jnp.cos)(1.)
... exp.mlir_module_serialization_version
9
```
We reserve the right to remove support for
generating or consuming serialization versions older than 6 months.
### Module calling convention
The `Exported.mlir_module` has a `main` function that takes an optional first
platform index argument if the module supports multiple platforms
(`len(lowering_platforms) > 1`), followed by the token arguments corresponding
to the ordered effects, followed by the kept array
arguments (corresponding to `module_kept_var_idx` and `in_avals`).
The platform index is an i32 or i64 scalar encoding the index of the current
compilation platform into the `lowering_platforms` sequence.
Inner functions use a different calling convention: an optional
platform index argument, optional dimension variable arguments
(scalar tensors of type i32 or i64),
followed by optional token arguments (in presence of ordered effects),
followed by the regular array arguments.
The dimension arguments correspond to the dimension variables appearing in
the `args_avals`, in sorted order of their names.
Consider the lowering of a function with one array argument of type
`f32[w, 2 * h]`, where `w` and `h` are two dimension variables.
Assume that we use multi-platform lowering, and we have
one ordered effect. The `main` function will be as follows:
```
func public main(
platform_index: i32 {jax.global_constant="_platform_index"},
token_in: token,
arg: f32[?, ?]) {
arg_w = hlo.get_dimension_size(arg, 0)
dim1 = hlo.get_dimension_size(arg, 1)
arg_h = hlo.floordiv(dim1, 2)
call _check_shape_assertions(arg) # See below
token = new_token()
token_out, res = call _wrapped_jax_export_main(platform_index,
arg_h,
arg_w,
token_in,
arg)
return token_out, res
}
```
The actual computation is in `_wrapped_jax_export_main`, which also takes
the values of the `h` and `w` dimension variables.
The signature of `_wrapped_jax_export_main` is:
```
func private _wrapped_jax_export_main(
platform_index: i32 {jax.global_constant="_platform_index"},
arg_h: i32 {jax.global_constant="h"},
arg_w: i32 {jax.global_constant="w"},
arg_token: stablehlo.token {jax.token=True},
arg: f32[?, ?]) -> (stablehlo.token, ...)
```
Prior to serialization version 9 the calling convention for effects was
different: the `main` function did not take or return a token. Instead
it created dummy tokens of type `i1[0]` and passed them to
`_wrapped_jax_export_main`, which created real tokens internally
to pass to the inner functions. The inner functions use real
tokens (both before and after serialization version 9).
Also starting with serialization version 9, function arguments that contain
the platform index or the dimension variable values have a
`jax.global_constant` string attribute whose value is the name of the
global constant, either `_platform_index` or a dimension variable name.
The global constant name may be empty if it is not known.
Some global constant computations use inner functions, e.g., for
`floor_divide`. The arguments of such functions have a `jax.global_constant`
attribute for all arguments, meaning that the result of the function is
also a global constant.
Note that `main` contains a call to `_check_shape_assertions`.
JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
have values >= 1. We must check these constraints when we invoke the
module. We use a special custom call `@shape_assertion` that takes
a boolean first operand, a string `error_message` attribute that may contain
format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
scalar operands corresponding to the format specifiers.
```
func private _check_shape_assertions(arg: f32[?, ?]) {
# Check that w is >= 1
arg_w = hlo.get_dimension_size(arg, 0)
custom_call @shape_assertion(arg_w >= 1, arg_w,
error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
# Check that dim1 is even
dim1 = hlo.get_dimension_size(arg, 1)
custom_call @shape_assertion(dim1 % 2 == 0, dim1,
error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
# Check that h >= 1
arg_h = hlo.floordiv(dim1, 2)
custom_call @shape_assertion(arg_h >= 1, arg_h,
error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}")
```
### Serialization version numbers
We list here a history of the serialization version numbers:
* Version 1 used MHLO & CHLO to serialize the code, not supported anymore.
* Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported
anymore.
* Version 3 supports platform checking and multiple platforms.
Used from February 2023. Not supported anymore.
* Version 4 supports StableHLO with compatibility guarantees.
This is the earliest version at the time of the JAX native serialization
launch.
Used in JAX from March 15, 2023 (cl/516885716). Starting with
March 28th, 2023 we stopped using `dim_args_spec` (cl/520033493).
The support for this version was dropped on
October 17th, 2023 (cl/573858283).
* Version 5 adds support for `call_tf_graph`. This is currently used
for some specialized use cases. Used in JAX from May 3rd, 2023
(cl/529106145).
* Version 6 adds support for the `disabled_checks` attribute. This version
mandates a non-empty `platforms` attribute. Supported by XlaCallModule
since June 7th, 2023 and available in JAX since
June 13th, 2023 (JAX 0.4.13).
* Version 7 adds support for `stablehlo.shape_assertion` operations and
for `shape_assertions` specified in `disabled_checks`.
See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
since July 12th, 2023 (cl/547482522),
available in JAX serialization since July 20th, 2023 (JAX 0.4.14),
and the default since August 12th, 2023 (JAX 0.4.15).
* Version 8 adds support for the `jax.uses_shape_polymorphism` module
attribute and enables the shape refinement pass only when the
attribute is present. Supported by XlaCallModule since July 21st, 2023
(cl/549973693), available in JAX since July 26th, 2023 (JAX 0.4.14),
and the default since October 21st, 2023 (JAX 0.4.20).
* Version 9 adds support for effects.
See the docstring for `export.Exported` for the precise calling convention.
In this serialization version we also tag the platform index and the
dimension variables arguments with `jax.global_constant` attributes.
Supported by XlaCallModule since October 27th, 2023,
available in JAX since October 20th, 2023 (JAX 0.4.20),
and the default since February 1st, 2024 (JAX 0.4.24).
This is the only supported version as of 27th of March, 2024.

docs/export/index.rst Normal file

@ -0,0 +1,13 @@
.. _export:
Exporting and serialization
=============================
.. toctree::
:caption: Guides
:maxdepth: 2
export
shape_poly
jax2tf

docs/export/jax2tf.md Normal file

@ -0,0 +1,5 @@
(jax2tf)=
## Interoperation with TensorFlow
TO DO: move here the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).


@ -0,0 +1,5 @@
(shape_poly)=
## Shape polymorphism
TO DO: populate from the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion).

docs/jax.export.rst Normal file

@ -0,0 +1,36 @@
``jax.export`` module
=====================
.. automodule:: jax.export
Classes
-------
.. autosummary::
:toctree: _autosummary
Exported
DisabledSafetyCheck
Functions
---------
.. autosummary::
:toctree: _autosummary
export
deserialize
minimum_supported_serialization_version
maximum_supported_serialization_version
default_lowering_platform
Functions related to shape polymorphism
---------------------------------------
.. autosummary::
:toctree: _autosummary
symbolic_shape
symbolic_args_specs
is_symbolic_dim
SymbolicScope


@ -27,6 +27,7 @@ Subpackages
jax.tree
jax.tree_util
jax.typing
jax.export
jax.extend
jax.example_libraries
jax.experimental


@ -32,6 +32,7 @@ or deployed codebases.
:caption: Run Time
aot
export/index
errors
transfer_guard


@ -63,28 +63,35 @@ Shape = jax._src.core.Shape
LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue]
HloSharding = xla_client.HloSharding
# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
# for a description of the different versions.
"""The minimum supported serialization version.
See https://jax.readthedocs.io/en/latest/export.html#module-serialization-versions
"""
minimum_supported_serialization_version = 9
"""The maximum supported serialization version.
See https://jax.readthedocs.io/en/latest/export.html#module-serialization-versions
"""
maximum_supported_serialization_version = 9
class DisabledSafetyCheck:
"""A safety check should be skipped on (de)serialization.
"""A safety check that should be skipped on (de)serialization.
Most of these checks are performed on serialization, but some are deferred to
deserialization. The list of disabled checks is attached to the serialization,
e.g., as a sequence of string attributes to `jax_export.Exported` or of
e.g., as a sequence of string attributes to `jax.export.Exported` or of
`tf.XlaCallModuleOp`.
You can disable more deserialization safety checks by passing
`TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`.
When using jax2tf, you can disable more deserialization safety checks
by passing `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`.
"""
_impl: str
@classmethod
def platform(cls) -> DisabledSafetyCheck:
"""Allows the execution platform to differ from the serialization platform.
"""Allows the compilation platform to differ from the export platform.
Has effect only on deserialization.
"""
@ -102,7 +109,7 @@ class DisabledSafetyCheck:
@classmethod
def shape_assertions(cls) -> DisabledSafetyCheck:
"""A noop. DEPRECATED.
"""DEPRECATED: A noop.
Was used previously to allow invocations with shapes that do not meet the
constraints. Has no effect anymore, shape assertions cannot be disabled.
@ -149,7 +156,7 @@ class Exported:
out_tree: a PyTreeDef describing the result of the lowered JAX function.
out_avals: the flat tuple of output abstract values. May contain dimension
expressions in the shapes, with dimension variables among those in
`in_avals.
`in_avals`.
in_shardings_hlo: the flattened input shardings, a sequence as long
as `in_avals`. `None` means unspecified sharding.
Note that these do not include the mesh or the actual devices used in
@ -162,16 +169,17 @@ class Exported:
into sharding specification that can be used with JAX APIs.
nr_devices: the number of devices that the module has been lowered for.
lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
'cuda', 'rocm'. See below for the calling convention for when
'cuda', 'rocm'. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention
for the calling convention for when
there are multiple lowering platforms.
ordered_effects: the ordered effects present in the serialized module.
This is present from serialization version 9. See below for the
calling convention in presence of ordered effects.
This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention
for the calling convention in presence of ordered effects.
unordered_effects: the unordered effects present in the serialized module.
This is present from serialization version 9.
mlir_module_serialized: the serialized lowered VHLO module.
mlir_module_serialization_version: a version number for the serialized module.
See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.
See more versioning details at https://jax.readthedocs.io/en/latest/export.html#module-serialization-versions.
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
must be passed to the module. The other arguments have been dropped
because they are not used.
@ -189,100 +197,7 @@ class Exported:
for each primal output. It returns a tuple with the cotangents
corresponding to the flattened primal inputs.
Calling convention for the exported module (for latest supported version):
The `mlir_module` has a `main` function that takes an optional first
platform index argument if the module supports multiple platforms
(`len(lowering_platforms) > 1`), followed by the token arguments corresponding
to the ordered effects, followed by the kept array
arguments (corresponding to `module_kept_var_idx` and `in_avals`).
The platform index is a i32 or i64 scalar encoding the index of the current
compilation platform into the `lowering_platforms` sequence.
Inner functions use a different calling convention: an optional
platform index argument, optional dimension variable arguments
(scalar tensors of type i32 or i64),
followed by optional token arguments (in presence of ordered effects),
followed by the regular array arguments.
The dimension arguments correspond to the dimension variables appearing in
the `args_avals`, in sorted order of their names.
Consider the lowering of a function with one array argument of type "f32[w,
2 * h]", where "w" and "h" are two dimension variables.
Assume that we use multi-platform lowering, and we have
one ordered effect. The `main` function will be as follows:
func public main(
platform_index: i32 {jax.global_constant="_platform_index"},
token_in: token,
arg: f32[?, ?]) {
arg_w = hlo.get_dimension_size(arg, 0)
dim1 = hlo.get_dimension_size(arg, 1)
arg_h = hlo.floordiv(dim1, 2)
call _check_shape_assertions(arg) # See below
token = new_token()
token_out, res = call _wrapped_jax_export_main(platform_index,
arg_h,
arg_w,
token_in,
arg)
return token_out, res
}
The actual computation is in `_wrapped_jax_export_main`, taking also
the values of `h` and `w` dimension variables.
The signature of the `_wrapped_jax_export_main` is:
func private _wrapped_jax_export_main(
platform_index: i32 {jax.global_constant="_platform_index"},
arg_h: i32 {jax.global_constant="h"},
arg_w: i32 {jax.global_constant="w"},
arg_token: stablehlo.token {jax.token=True},
arg: f32[?, ?]) -> (stablehlo.token, ...)
Prior to serialization version 9 the calling convention for effects is
different: the `main` function does not take or return a token. Instead
the function creates dummy tokens of type `i1[0]` and passes them to the
`_wrapped_jax_export_main`. The `_wrapped_jax_export_main`
takes dummy tokens of type `i1[0]` and will create internally real
tokens to pass to the inner functions. The inner functions use real
tokens (both before and after serialization version 9)
Also starting with serialization version 9, function arguments that contain
the platform index or the dimension variable values have a
`jax.global_constant` string attribute whose value is the name of the
global constant, either `_platform_index` or a dimension variable name.
The global constant name may be empty if it is not known.
Some global constant computations use inner functions, e.g., for
`floor_divide`. The arguments of such functions have a `jax.global_constant`
attribute for all attributes, meaning that the result of the function is
also a global constant.
Note that `main` contains a call to `_check_shape_assertions.
JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
have values >= 1. We must check these constraints when we invoke the
module. We use a special custom call `@shape_assertion` that takes
a boolean first operand, a string `error_message` attribute that may contain
format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
scalar operands corresponding to the format specifiers.
func private _check_shape_assertions(arg: f32[?, ?]) {
# Check that w is >= 1
arg_w = hlo.get_dimension_size(arg, 0)
custom_call @shape_assertion(arg_w >= 1, arg_w,
error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
# Check that dim1 is even
dim1 = hlo.get_dimension_size(arg, 1)
custom_call @shape_assertion(dim1 % 2 == 0, dim1,
error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
# Check that h >= 1
arg_h = hlo.floordiv(dim1, 2)
custom_call @shape_assertion(arg_h >= 1, arg_h,
error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}")
If we `call_exported` with this module we perform these checks
statically (in `call_exported_abstract_eval`).
See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module-calling-convention).
"""
fun_name: str
in_tree: tree_util.PyTreeDef
@ -417,7 +332,9 @@ def deserialize(blob: bytearray) -> Exported:
def default_lowering_platform() -> str:
"""Retrieves the default lowering platform for the exporting machine.
"""Retrieves the default lowering platform.
One of: `tpu`, `cpu`, `cuda`, `rocm`.
"""
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
return xb.canonicalize_platform(jax.default_backend())
@ -458,6 +375,7 @@ def export_back_compat(
Note: this function exists only for internal usage by jax2tf and for
backwards compatibility with jax.experimental.export. Use
`jax.export` instead.
See https://jax.readthedocs.io/en/latest/export.html
Args:
fun_jax: the function to lower and serialize.
@ -466,8 +384,8 @@ def export_back_compat(
'cuda', 'rocm'. If more than one platform is specified, then
the lowered code takes an argument specifying the platform.
If None, then use the default JAX backend.
The calling convention for multiple platforms is explained in the
`jax_export.Exported` docstring.
The calling convention for multiple platforms is explained
at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.
disabled_checks: the safety checks to disable. See docstring
of `DisabledSafetyCheck`.
@ -547,32 +465,33 @@ def export(
"""Exports a JAX function for persistent serialization.
Args:
fun_jit: the function to export. Should be the result of `jit`.
fun_jit: the function to export. Should be the result of `jax.jit`.
lowering_platforms:
Optional sequence containing a subset of 'tpu', 'cpu',
'cuda', 'rocm'. If more than one platform is specified, then
the lowered code takes an argument specifying the platform.
If None, then use the default JAX backend.
The calling convention for multiple platforms is explained in the
`jax_export.Exported` docstring.
disabled_checks: the safety checks to disable. See docstring
of `DisabledSafetyCheck`.
The calling convention for multiple platforms is explained at
https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.
disabled_checks: the safety checks to disable. See the documentation
of `jax.export.DisabledSafetyCheck`.
Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`,
or values with `.shape` and `.dtype` attributes, and returns an
`Exported`.
Usage:
>>> from jax import export
>>> exported: export.Exported = export.export(jnp.sin)(
... np.arange(4, dtype=np.float32))
# You can inspect the Exported object
>>>
>>> # You can inspect the Exported object
>>> exported.in_avals
(ShapedArray(float32[4]),)
>>> blob: bytearray = exported.serialize()
# The serialized bytes are safe to use in a separate process
>>>
>>> # The serialized bytes are safe to use in a separate process
>>> rehydrated: export.Exported = export.deserialize(blob)
>>> rehydrated.fun_name
'sin'
@ -776,11 +695,10 @@ def _wrap_main_func(
) -> ir.Module:
"""Wraps the lowered module with a new "main" handling dimension arguments.
See calling convention documentation for `jax_export.Exported`.
See calling convention documentation https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.
Args:
module: the HLO module as obtained from lowering. See the calling convention
for inner functions in `jax_export.Exported`.
module: the HLO module as obtained from lowering.
args_avals_flat: the avals for all the arguments of the lowered function,
which correspond to the array arguments of the `module`.
args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error
@ -1343,7 +1261,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
for s in exported.in_shardings_hlo + exported.out_shardings_hlo)):
err_msg = "the module contains non-replicated sharding annotations."
if err_msg:
raise NotImplementedError(
raise ValueError(
f"Exported module {exported.fun_name} was lowered for "
f"{exported.nr_devices} devices and is called in a context with "
f"{num_devices} devices. This is disallowed because: {err_msg}"


@ -982,7 +982,7 @@ class SymbolicScope:
Holds the constraints on symbolic expressions.
See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints)
See [the README](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints)
for more details.
Args:
@ -1194,6 +1194,8 @@ def _convertible_to_poly(p: DimSize) -> bool:
return isinstance(p, _DimExpr) or _convertible_to_int(p)
def is_symbolic_dim(p: DimSize) -> bool:
"""Checks if a dimension is symbolic.
"""
return isinstance(p, _DimExpr)
def is_poly_dim(p: DimSize) -> bool:
@ -1856,7 +1858,7 @@ class ShapeConstraints:
def shape_assertions(self, eval: CachingShapeEvaluator) -> None:
"""Computes the shape assertions for the set of constraints.
See jax_export._wrap_main_func docstring.
See jax_export.Exported docstring.
"""
# We want to report the errors in the same order as `check_statically`.
# So, we process them in order, in case some fail statically, and we


@ -965,12 +965,12 @@ class JaxExportTest(jtu.JaxTestCase):
# Test error reporting
with self.assertRaisesRegex(
NotImplementedError,
ValueError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
_ = exp.call(a)
with self.assertRaisesRegex(
NotImplementedError,
ValueError,
"Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
_ = jax.jit(
@ -1046,7 +1046,7 @@ class JaxExportTest(jtu.JaxTestCase):
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))
with self.assertRaisesRegex(
NotImplementedError,
ValueError,
"Exported module .* was lowered for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* module contains "
"non-replicated sharding annotations"):
@ -1092,7 +1092,7 @@ class JaxExportTest(jtu.JaxTestCase):
b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))
with self.assertRaisesRegex(
NotImplementedError,
ValueError,
"Exported module .* was lowered for 1 devices and is called in a "
f"context with {jax.local_device_count()} devices.* module contains "
"non-replicated sharding annotations"):