diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 566ecaca7..3efdc0c2b 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -136,6 +136,7 @@ jobs: - name: Test documentation env: XLA_FLAGS: "--xla_force_host_platform_device_count=8" + JAX_TRACEBACK_FILTERING: "off" JAX_ARRAY: 1 PY_COLORS: 1 run: | diff --git a/docs/aot.md b/docs/aot.md index e1420f702..8615d7513 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -35,8 +35,6 @@ way. An example: ```python >>> import jax ->>> import jax.numpy as jnp ->>> import numpy as np >>> def f(x, y): return 2 * x + y >>> x, y = 3, 4 @@ -66,6 +64,10 @@ Array(10, dtype=int32, weak_type=True) ``` +Note that the lowered objects can be used only in the same process +in which they were lowered. For exporting use cases, +see the {ref}`export` APIs. + See the {mod}`jax.stages` documentation for more details on what functionality the lowering and compiled functions provide. @@ -139,11 +141,16 @@ module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : return %0 : tensor } } + >>> lowered_with_x.compile()(5) Array(19, dtype=int32, weak_type=True) ``` +The result of `lower` is not safe to serialize directly for use +in a different process. +See {ref}`export` for additional APIs for this purpose. + Note that `lower` here takes two arguments as usual, but the subsequent compiled function accepts only the remaining non-static second argument. The static first argument (value 7) is taken as a constant at lowering time and built into the @@ -194,7 +201,7 @@ Array([[ 1., 5., 9.], >>> jax.vmap(g_aot)(zs) # doctest: +SKIP Traceback (most recent call last): -TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type . +TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type ``` diff --git a/docs/export/export.md b/docs/export/export.md new file mode 100644 index 000000000..815d88556 --- /dev/null +++ b/docs/export/export.md @@ -0,0 +1,633 @@ +# Exporting and serializing staged-out computations + +The {ref}`ahead-of-time-lowering` APIs produce +objects that can be used for debugging or for compilation and +execution in the same process. +Sometimes you want to serialize a lowered JAX function for +compilation and execution in a separate process, perhaps +at a later time. This would allow you to: + + * compile and execute the function in another process or machine + without requiring access to the JAX program, + and without having to repeat the staging-out and lowering, e.g., + in an inference system. + * trace and lower a function on a machine that does not have access + to the accelerator for which you want to later compile and execute + the function. + * archive a snapshot of a JAX function, e.g., to be able to + reproduce later your results. **Note:** check out the [compatibility + guarantees](#compatibility-guarantees) for this use case. + +Here is an example: + +```python +>>> import re +>>> import numpy as np +>>> import jax +>>> from jax import export + +>>> def f(x): return 2 * x * x + + +>>> exported: export.Exported = export.export(jax.jit(f))( +... 
jax.ShapeDtypeStruct((), np.float32)) + +>>> # You can inspect the Exported object +>>> exported.fun_name +'f' + +>>> exported.in_avals +(ShapedArray(float32[]),) + +>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0)) + func.func public @main(%arg0: tensor {mhlo.layout_mode = "default"} loc("x")) -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + +>>> # And you can serialize the Exported to a bytearray. +>>> serialized: bytearray = exported.serialize() + +>>> # The serialized function can later be rehydrated and called from +>>> # another JAX computation, possibly in another process. +>>> rehydrated_exp: export.Exported = export.deserialize(serialized) +>>> rehydrated_exp.in_avals +(ShapedArray(float32[]),) + +>>> def callee(y): +... return 3. * rehydrated_exp.call(y * 4.) + +>>> callee(1.) +Array(96., dtype=float32) + +``` + +Serialization is broken down into two stages: + 1. exporting to produce an {class}`jax.export.Exported` object that contains + the StableHLO for the lowered function along with the metadata necessary to + call it from another JAX function. We have plans to add code to generate + `Exported` objects from TensorFlow, and to use `Exported` objects from + TensorFlow and PyTorch. + 2. the actual serialization to a byte array using the flatbuffers format. + See {ref}`jax2tf` for + an alternative serialization to TensorFlow graph that can be used + for interoperation with TensorFlow. + +## Support for reverse-mode AD + +Serialization can optionally support higher-order reverse-mode AD. This is done +by serializing the {func}`jax.vjp` of the primal function along with the primal function, +up to a user-specified order (default is 0, meaning that the rehydrated +function cannot be differentiated): + +```python +>>> import jax +>>> from jax import export +>>> from typing import Callable + +>>> def f(x): return 7 * x * x * x + +>>> # Serialize 3 levels of VJP along with the primal function +>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3) +>>> rehydrated_f: Callable = export.deserialize(blob).call + +>>> rehydrated_f(0.1) # 7 * 0.1^3 +Array(0.007, dtype=float32) + +>>> jax.grad(rehydrated_f)(0.1) # 7*3 * 0.1^2 +Array(0.21000001, dtype=float32) + +>>> jax.grad(jax.grad(rehydrated_f))(0.1) # 7*3*2 * 0.1 +Array(4.2, dtype=float32) + +>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1) # 7*3*2 +Array(42., dtype=float32) + +>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): +ValueError: No VJP is available + +``` + +Note that the VJP function is computed lazily while serializing, +when the JAX program is still available. +This means that it respects all features of JAX VJP, +e.g., {func}`jax.custom_vjp` and {func}`jax.remat`. + +Note that the rehydrated function does not support any other +transformations, e.g., forward-mode AD (jvp), or {func}`jax.vmap`. + +## Compatibility guarantees + +You should not use the raw StableHLO that is obtained from just lowering +(`jax.jit(f).lower(1.).compiler_ir()`) +for archival and for compilation in another process, for several reasons. + +First, the compilation may use a different version of the compiler, supporting a +different version of StableHLO. The {class}`jax.export` module takes +care of this by using the +[portable-artifact feature of StableHLO](https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md) +to deal with the possible evolution of the StableHLO opset. 
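+
+In practice this means that, for archival, you should persist the bytes
+produced by `Exported.serialize()` (which embed a versioned, portable
+StableHLO artifact) rather than the textual IR. A minimal sketch of such an
+archival round trip (the file name is just an illustration; any byte store
+works):
+
+```python
+import pathlib
+
+import jax
+import numpy as np
+from jax import export
+
+def f(x):
+  return 2 * x * x
+
+# Export and serialize, e.g., on the machine where the JAX program lives.
+exp = export.export(jax.jit(f))(jax.ShapeDtypeStruct((), np.float32))
+pathlib.Path("f_exported.bin").write_bytes(exp.serialize())
+
+# Later, possibly in a different process, rehydrate and call.
+blob = bytearray(pathlib.Path("f_exported.bin").read_bytes())
+rehydrated = export.deserialize(blob)
+rehydrated.call(np.float32(3.))   # Array(18., dtype=float32)
+```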
+ +### Compatibility guarantees for custom calls + +Second, the raw StableHLO may contain custom calls referencing C++ +functions. +JAX uses custom calls for lowering of a small number of primitives, +e.g., linear algebra primitives, sharding annotations, or Pallas kernels. +These do not fall under the compatibility guarantees for StableHLO. +The C++ implementations of these functions change rarely, but they can change. + +`jax.export` makes the following export compatibility guarantees: +A JAX exported artifact can be compiled and executed by a compiler and +JAX runtime system that are: + + * **up to 6 months newer** than the version of JAX used for exporting + (we say that JAX export offers **6 months backward compatibility**). + This is useful if we want to archive the exported artifact to be compiled and executed later. + * **up to 3 weeks older** than the version of JAX used for exporting + (we say that JAX export offers **3 weeks forward compatibility**). + This is useful if we want to compile and run an exported artifact with a + consumer that was built and deployed before the export, e.g., + an inference system that is already deployed when the exporting is done. + +(The particular compatibility window lengths are the same that JAX +[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model), +and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow). +The terminology “backward compatibility” is from the perspective of the consumer, +e.g., the inference system.) + +What **matters is when the exporting and consuming components were built**, +not the time when the exporting and the compilation happen. +For external JAX users, it is +[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned); +what matters is when the jaxlib release was built. + +To reduce chances of incompatibility, internal JAX users should: + * **rebuild and redeploy consumer systems as frequently as possible**. + +and external users should: + * run the exporting and consumer systems with the same version of jaxlib, whenever possible, and + * export for archival **with the latest released version of jaxlib**. + +The compatibility guarantees do not apply if you bypass the `jax.export` APIs +to obtain the StableHLO code. + +Only a subset of custom calls are guaranteed stable and have +compatibility guarantees ([see list](https://github.com/search?q=repo%3Agoogle%2Fjax%20_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE&type=code)). +We continuously +add more custom call targets to the allowed list along with backwards +compatibility tests. If you try to serialize +code that invokes other custom call targets you will get an error +during exporting. 
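+
+To see which custom call targets an exported artifact actually relies on,
+e.g., before archiving it, one option (a sketch, not an official API) is to
+scan the StableHLO text returned by `Exported.mlir_module()`. The exact
+targets you see depend on the platform and the JAX version:
+
+```python
+import re
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax import export
+
+# Linear algebra primitives lower to custom calls on CPU.
+exp = export.export(jax.jit(jnp.linalg.inv))(np.eye(3, dtype=np.float32))
+
+# Collect the custom call targets mentioned in the StableHLO text.
+targets = sorted(set(re.findall(r"custom_call @([\w.]+)", exp.mlir_module())))
+print(targets)   # e.g., LAPACK getrf/getri targets on CPU
+```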
+
+If you want to disable this safety check for a specific custom call,
+e.g., with target `my_target`, you can add
+`export.DisabledSafetyCheck.custom_call("my_target")` to the
+`disabled_checks` parameter of the `export` method,
+as in the following example:
+
+```python
+>>> import jax
+>>> from jax import export
+>>> from jax import lax
+>>> from jax._src.interpreters import mlir
+
+>>> # override the lowering rule for sin to use a custom call `my_new_sin`
+>>> _ = mlir.register_lowering(lax.sin_p, lambda ctx, o: mlir.custom_call("my_new_sin", operands=[o], result_types=[o.type]).results)
+>>> print(jax.jit(lax.sin).lower(1.).compiler_ir())
+module @jit_sin attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
+    %0 = stablehlo.custom_call @my_new_sin(%arg0) {api_version = 2 : i32} : (tensor<f32>) -> tensor<f32>
+    return %0 : tensor<f32>
+  }
+}
+
+>>> # If we try to export, we get an error
+>>> export.export(jax.jit(lax.sin))(1.)  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_sin
+
+>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call`
+>>> exp = export.export(
+...    jax.jit(lax.sin),
+...    disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_sin")])(1.)
+
+```
+
+## Cross-platform and multi-platform export
+
+JAX lowering is platform specific for a small number of JAX primitives.
+By default, the code is lowered and exported for the accelerator
+present on the exporting machine:
+
+```python
+>>> from jax import export
+>>> export.default_lowering_platform()
+'cpu'
+
+```
+
+There is a safety check that will raise an error when trying to compile
+an `Exported` object on a machine that does not have the accelerator
+for which the code was exported.
+
+You can specify explicitly for what platforms the code should be exported.
+This allows you to specify a different accelerator than the one you have
+available at export time,
+and it even allows you to specify multi-platform export to
+obtain an `Exported` object that can be compiled and executed
+on multiple platforms.
+
+```python
+>>> import jax
+>>> from jax import export
+>>> from jax import lax
+
+>>> # You can specify the lowering_platforms, e.g., `tpu`, `cpu`, `cuda`, `rocm`,
+>>> # even if the current machine does not have that accelerator.
+>>> exp = export.export(jax.jit(lax.cos), lowering_platforms=['tpu'])(1.)
+
+>>> # But you will get an error if you try to compile `exp`
+>>> # on a machine that does not have TPUs.
+>>> exp.call(1.)  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ValueError: The exported function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.
+
+>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
+>>> # parameter to `export`, e.g., because you have reasons to believe
+>>> # that the code lowered will run adequately on the current
+>>> # compilation platform (which is the case for `cos` in this
+>>> # example):
+>>> exp_unsafe = export.export(jax.jit(lax.cos),
+...                            lowering_platforms=['tpu'],
+...                            disabled_checks=[export.DisabledSafetyCheck.platform()])(1.)
+
+>>> exp_unsafe.call(1.)
+Array(0.5403023, dtype=float32, weak_type=True)
+
+>>> # and similarly with multi-platform lowering
+>>> exp_multi = export.export(jax.jit(lax.cos),
+...                           lowering_platforms=['tpu', 'cpu', 'cuda'])(1.)
+>>> exp_multi.call(1.)
+Array(0.5403023, dtype=float32, weak_type=True)
+
+```
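+
+An `Exported` also records the platforms it was lowered for (the
+`lowering_platforms` field described in the {class}`jax.export.Exported`
+docstring), so a consumer can check an artifact against the current backend
+before compiling it. A small sketch, continuing the example above:
+
+```python
+# Continuing the `exp_multi` example above (a sketch):
+exp_multi.lowering_platforms      # ('tpu', 'cpu', 'cuda')
+
+# E.g., verify that the artifact covers the backend we are about to run on.
+assert jax.default_backend() in exp_multi.lowering_platforms
+```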
+
+For multi-platform export, the StableHLO will contain multiple
+lowerings but only for those primitives that require it, so the
+resulting module size should be only marginally larger than the
+size of a module with default export.
+As an extreme case, when serializing a module without any
+primitives with platform-specific lowering, you will get
+the same StableHLO as for the single-platform export.
+
+```python
+>>> import jax
+>>> import jax.numpy as jnp
+>>> from jax import export
+>>> from jax import lax
+>>> # A largish function
+>>> def f(x):
+...   for i in range(1000):
+...     x = jnp.cos(x)
+...   return x
+
+>>> exp_single = export.export(jax.jit(f))(1.)
+>>> len(exp_single.mlir_module_serialized)  # doctest: +SKIP
+9220
+
+>>> exp_multi = export.export(jax.jit(f),
+...                           lowering_platforms=["cpu", "tpu", "cuda"])(1.)
+>>> len(exp_multi.mlir_module_serialized)  # doctest: +SKIP
+9282
+
+```
+
+## Shape polymorphic export
+
+When used in JIT mode, JAX will trace and lower a function separately
+for each combination of input shapes. When exporting, it is possible
+in some cases to use dimension variables for some input dimensions
+in order to obtain an exported artifact that can be used with multiple
+combinations of input shapes.
+
+See the {ref}`shape_poly` documentation.
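+
+As a quick illustration (a minimal sketch; the full API, including symbolic
+constraints, is described in the {ref}`shape_poly` guide), you can build the
+argument specification with {func}`jax.export.symbolic_shape` and export once
+for all batch sizes:
+
+```python
+>>> import jax
+>>> import numpy as np
+>>> from jax import export
+
+>>> def f(x):  # x: f32[b, 4]
+...   return x.sum(axis=0)
+
+>>> # "b" is a dimension variable standing for any batch size.
+>>> x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), np.float32)
+>>> exp = export.export(jax.jit(f))(x_spec)
+>>> exp.in_avals  # doctest: +SKIP
+(ShapedArray(float32[b,4]),)
+
+>>> # The same artifact accepts any concrete batch size.
+>>> exp.call(np.ones((3, 4), dtype=np.float32))  # doctest: +SKIP
+Array([3., 3., 3., 3.], dtype=float32)
+
+```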
+
+## Device polymorphic export
+
+An exported artifact may contain sharding annotations for inputs,
+outputs and for some intermediates, but these annotations do not refer
+directly to the actual physical devices that existed at exporting time.
+Instead, the sharding annotations refer to logical devices. This
+means that you can compile and run the exported artifacts on
+physical devices different from those that were used for exporting.
+
+```python
+>>> import jax
+>>> import jax.numpy as jnp
+>>> import numpy as np
+>>> from jax import export
+>>> from jax.sharding import Mesh, NamedSharding
+>>> from jax.sharding import PartitionSpec as P
+
+>>> # Use the first 4 devices for exporting.
+>>> export_devices = jax.local_devices()[:4]
+>>> export_mesh = Mesh(export_devices, ("a",))
+>>> def f(x):
+...   return x.T
+
+>>> arg = jnp.arange(8 * len(export_devices))
+>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+
+>>> # `exp` knows for how many devices it was exported.
+>>> exp.nr_devices
+4
+
+>>> # and it knows the shardings for the inputs. These will be applied
+>>> # when the exported function is called.
+>>> exp.in_shardings_hlo
+({devices=[4]<=[4]},)
+
+>>> res1 = exp.call(jax.device_put(arg,
+...                                NamedSharding(export_mesh, P("a"))))
+
+>>> # Check out the first 2 shards of the result
+>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
+['device=TFRT_CPU_0 index=(slice(0, 8, None),)',
+ 'device=TFRT_CPU_1 index=(slice(8, 16, None),)']
+
+>>> # We can call `exp` with some other 4 devices and another
+>>> # mesh with a different shape, as long as the number of devices is
+>>> # the same.
+>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c"))
+>>> res2 = exp.call(jax.device_put(arg,
+...                                NamedSharding(other_mesh, P("b"))))
+
+>>> # Check out the first 2 shards of the result. Notice that the output is
+>>> # sharded similarly; this means that the input was resharded according to
+>>> # exp.in_shardings.
+>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]]
+['device=TFRT_CPU_2 index=(slice(0, 8, None),)',
+ 'device=TFRT_CPU_3 index=(slice(8, 16, None),)']
+
+```
+
+It is an error to try to invoke an exported artifact with a different number
+of devices than it was exported for:
+
+```python
+>>> import jax
+>>> import jax.numpy as jnp
+>>> import numpy as np
+>>> from jax import export
+>>> from jax.sharding import Mesh, NamedSharding
+>>> from jax.sharding import PartitionSpec as P
+
+>>> export_devices = jax.local_devices()
+>>> export_mesh = Mesh(np.array(export_devices), ("a",))
+>>> def f(x):
+...   return x.T
+
+>>> arg = jnp.arange(4 * len(export_devices))
+>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+
+>>> exp.call(arg)  # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.
+
+```
+
+There are helper functions to shard the inputs for calling an exported
+artifact using a new mesh constructed at the call site:
+
+```python
+>>> import jax
+>>> import jax.numpy as jnp
+>>> import numpy as np
+>>> from jax import export
+>>> from jax.sharding import Mesh, NamedSharding
+>>> from jax.sharding import PartitionSpec as P
+
+>>> export_devices = jax.local_devices()
+>>> export_mesh = Mesh(np.array(export_devices), ("a",))
+>>> def f(x):
+...   return x.T
+
+>>> arg = jnp.arange(4 * len(export_devices))
+>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)
+
+>>> # Prepare the mesh for calling `exp`.
+>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))
+
+>>> # Shard the arg according to what `exp` expects.
+>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
+>>> res = exp.call(sharded_arg)
+
+```
+
+As a special facility, if a function was exported for 1 device and if it contains no
+sharding annotations, then it can be invoked on an argument of the same shape but sharded
+on multiple devices, and the compiler will shard the function appropriately:
+
+```python
+>>> import jax
+>>> import jax.numpy as jnp
+>>> from jax import export
+>>> from jax.sharding import Mesh, NamedSharding
+>>> from jax.sharding import PartitionSpec as P
+
+>>> def f(x):
+...   return jnp.cos(x)
+
+>>> arg = jnp.arange(4)
+>>> exp = export.export(jax.jit(f))(arg)
+>>> exp.in_avals
+(ShapedArray(int32[4]),)
+
+>>> exp.nr_devices
+1
+
+>>> # Prepare the mesh for calling `exp`.
+>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",))
+
+>>> # Shard the arg according to what `exp` expects.
+>>> sharded_arg = jax.device_put(arg,
+...                              NamedSharding(calling_mesh, P("b")))
+>>> res = exp.call(sharded_arg)
+
+```
+
+## Module serialization versions
+
+The JAX export support has evolved over time, e.g., to support
+effects. In order to support compatibility (see [compatibility guarantees](#compatibility-guarantees))
+we maintain a serialization version for each `Exported`.
+As of June 2024, all modules are serialized with version 9
+(the latest, see [all serialization versions](#serialization-version-numbers)):
+
+```python
+>>> import jax.numpy as jnp
+>>> from jax import export
+>>> exp: export.Exported = export.export(jnp.cos)(1.)
+>>> exp.mlir_module_serialization_version
+9
+
+```
+
+At any given time, the export APIs may support a range
+of serialization versions. You can control which serialization
+version to use with the `--jax-serialization-version` flag
+or the `JAX_SERIALIZATION_VERSION` environment variable:
+
+```python
+>>> import jax.numpy as jnp
+>>> from jax import export
+>>> (export.minimum_supported_serialization_version, export.maximum_supported_serialization_version)
+(9, 9)
+
+>>> from jax._src import config
+>>> with config.jax_serialization_version(9):
+...   exp = export.export(jnp.cos)(1.)
+...   exp.mlir_module_serialization_version
+9
+
+```
+
+We reserve the right to remove support for
+generating or consuming serialization versions older than 6 months.
+
+### Module calling convention
+
+The `Exported.mlir_module` has a `main` function that takes an optional first
+platform index argument if the module supports multiple platforms
+(`len(lowering_platforms) > 1`), followed by the token arguments corresponding
+to the ordered effects, followed by the kept array
+arguments (corresponding to `module_kept_var_idx` and `in_avals`).
+The platform index is an i32 or i64 scalar encoding the index of the current
+compilation platform into the `lowering_platforms` sequence.
+
+Inner functions use a different calling convention: an optional
+platform index argument, optional dimension variable arguments
+(scalar tensors of type i32 or i64),
+followed by optional token arguments (in presence of ordered effects),
+followed by the regular array arguments.
+The dimension arguments correspond to the dimension variables appearing in
+the `args_avals`, in sorted order of their names.
+
+Consider the lowering of a function with one array argument of type
+`f32[w, 2 * h]`, where `w` and `h` are two dimension variables.
+Assume that we use multi-platform lowering, and we have
+one ordered effect. The `main` function will be as follows:
+
+```
+   func public main(
+         platform_index: i32 {jax.global_constant="_platform_index"},
+         token_in: token,
+         arg: f32[?, ?]) {
+      arg_w = hlo.get_dimension_size(arg, 0)
+      dim1 = hlo.get_dimension_size(arg, 1)
+      arg_h = hlo.floordiv(dim1, 2)
+      call _check_shape_assertions(arg)  # See below
+      token = new_token()
+      token_out, res = call _wrapped_jax_export_main(platform_index,
+                                                     arg_h,
+                                                     arg_w,
+                                                     token_in,
+                                                     arg)
+      return token_out, res
+   }
+```
+
+The actual computation is in `_wrapped_jax_export_main`, which also takes
+the values of the `h` and `w` dimension variables.
+
+The signature of the `_wrapped_jax_export_main` is:
+
+```
+   func private _wrapped_jax_export_main(
+         platform_index: i32 {jax.global_constant="_platform_index"},
+         arg_h: i32 {jax.global_constant="h"},
+         arg_w: i32 {jax.global_constant="w"},
+         arg_token: stablehlo.token {jax.token=True},
+         arg: f32[?, ?]) -> (stablehlo.token, ...)
+```
+
+Prior to serialization version 9 the calling convention for effects was
+different: the `main` function did not take or return a token. Instead
+the function created dummy tokens of type `i1[0]` and passed them to
+`_wrapped_jax_export_main`, which took the dummy tokens and created real
+tokens internally to pass to the inner functions. The inner functions use real
+tokens (both before and after serialization version 9).
+
+Also starting with serialization version 9, function arguments that contain
+the platform index or the dimension variable values have a
+`jax.global_constant` string attribute whose value is the name of the
+global constant, either `_platform_index` or a dimension variable name.
+The global constant name may be empty if it is not known.
+Some global constant computations use inner functions, e.g., for
+`floor_divide`. The arguments of such functions have a `jax.global_constant`
+attribute for all arguments, meaning that the result of the function is
+also a global constant.
+
+Note that `main` contains a call to `_check_shape_assertions`.
+JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
+have values >= 1. We must check these constraints when we invoke the
+module. We use a special custom call `@shape_assertion` that takes
+a boolean first operand, a string `error_message` attribute that may contain
+format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
+scalar operands corresponding to the format specifiers.
+
+```
+   func private _check_shape_assertions(arg: f32[?, ?]) {
+     # Check that w is >= 1
+     arg_w = hlo.get_dimension_size(arg, 0)
+     custom_call @shape_assertion(arg_w >= 1, arg_w,
+        error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
+     # Check that dim1 is even
+     dim1 = hlo.get_dimension_size(arg, 1)
+     custom_call @shape_assertion(dim1 % 2 == 0, dim1,
+        error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
+     # Check that h >= 1
+     arg_h = hlo.floordiv(dim1, 2)
+     custom_call @shape_assertion(arg_h >= 1, arg_h,
+        error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")
+   }
+```
+
+### Serialization version numbers
+
+We list here a history of the serialization version numbers:
+
+  * Version 1 used MHLO & CHLO to serialize the code, not supported anymore.
+  * Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported
+    anymore.
+  * Version 3 supports platform checking and multiple platforms.
+    Used from February 2023. Not supported anymore.
+  * Version 4 supports StableHLO with compatibility guarantees.
+    This is the earliest version at the time of the JAX native serialization
+    launch.
+    Used in JAX from March 15, 2023 (cl/516885716). Starting with
+    March 28th, 2023 we stopped using `dim_args_spec` (cl/520033493).
+    The support for this version was dropped on
+    October 17th, 2023 (cl/573858283).
+  * Version 5 adds support for `call_tf_graph`. This is currently used
+    for some specialized use cases. Used in JAX from May 3rd, 2023
+    (cl/529106145).
+  * Version 6 adds support for the `disabled_checks` attribute. This version
+    mandates a non-empty `platforms` attribute. Supported by XlaCallModule
+    since June 7th, 2023 and available in JAX since
+    June 13th, 2023 (JAX 0.4.13).
+  * Version 7 adds support for `stablehlo.shape_assertion` operations and
+    for `shape_assertions` specified in `disabled_checks`.
+    See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
+    since July 12th, 2023 (cl/547482522),
+    available in JAX serialization since July 20th, 2023 (JAX 0.4.14),
+    and the default since August 12th, 2023 (JAX 0.4.15).
+  * Version 8 adds support for the `jax.uses_shape_polymorphism` module
+    attribute and enables the shape refinement pass only when the
+    attribute is present. Supported by XlaCallModule since July 21st, 2023
+    (cl/549973693), available in JAX since July 26th, 2023 (JAX 0.4.14),
+    and the default since October 21st, 2023 (JAX 0.4.20).
+  * Version 9 adds support for effects.
+    See the docstring for `export.Exported` for the precise calling convention.
+ In this serialization version we also tag the platform index and the + dimension variables arguments with `jax.global_constant` attributes. + Supported by XlaCallModule since October 27th, 2023, + available in JAX since October 20th, 2023 (JAX 0.4.20), + and the default since February 1st, 2024 (JAX 0.4.24). + This is the only supported version as of 27th of March, 2024. + + + + + diff --git a/docs/export/index.rst b/docs/export/index.rst new file mode 100644 index 000000000..24cf2716c --- /dev/null +++ b/docs/export/index.rst @@ -0,0 +1,13 @@ +.. _export: + +Exporting and serialization +============================= + +.. toctree:: + :caption: Guides + :maxdepth: 2 + + export + shape_poly + + jax2tf diff --git a/docs/export/jax2tf.md b/docs/export/jax2tf.md new file mode 100644 index 000000000..ae67963e2 --- /dev/null +++ b/docs/export/jax2tf.md @@ -0,0 +1,5 @@ +(jax2tf)= + +## Interoperation with TensorFlow + +TO DO: move here the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md new file mode 100644 index 000000000..931e8e160 --- /dev/null +++ b/docs/export/shape_poly.md @@ -0,0 +1,5 @@ +(shape_poly)= + +## Shape polymorphism + +TO DO: populate from the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion). diff --git a/docs/jax.export.rst b/docs/jax.export.rst new file mode 100644 index 000000000..59729b7a7 --- /dev/null +++ b/docs/jax.export.rst @@ -0,0 +1,36 @@ +``jax.export`` module +===================== + +.. automodule:: jax.export + +Classes +------- + +.. autosummary:: + :toctree: _autosummary + + Exported + DisabledSafetyCheck + +Functions +--------- + +.. autosummary:: + :toctree: _autosummary + + export + deserialize + minimum_supported_serialization_version + maximum_supported_serialization_version + default_lowering_platform + +Functions related to shape polymorphism +--------------------------------------- + +.. autosummary:: + :toctree: _autosummary + + symbolic_shape + symbolic_args_specs + is_symbolic_dim + SymbolicScope diff --git a/docs/jax.rst b/docs/jax.rst index 9979fb464..b112490a0 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -27,6 +27,7 @@ Subpackages jax.tree jax.tree_util jax.typing + jax.export jax.extend jax.example_libraries jax.experimental diff --git a/docs/user_guides.rst b/docs/user_guides.rst index c57609c17..f46d6b027 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -32,6 +32,7 @@ or deployed codebases. :caption: Run Time aot + export/index errors transfer_guard diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 03a476e62..433eac526 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -63,28 +63,35 @@ Shape = jax._src.core.Shape LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue] HloSharding = xla_client.HloSharding -# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions -# for a description of the different versions. +"""The minimum supported serialization version. + +See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention#module-serialization-versions +""" minimum_supported_serialization_version = 9 + +"""The maximum supported serialization version. 
+ +See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention#module-serialization-versions +""" maximum_supported_serialization_version = 9 class DisabledSafetyCheck: - """A safety check should be skipped on (de)serialization. + """A safety check that should be skipped on (de)serialization. Most of these checks are performed on serialization, but some are deferred to deserialization. The list of disabled checks is attached to the serialization, - e.g., as a sequence of string attributes to `jax_export.Exported` or of + e.g., as a sequence of string attributes to `jax.export.Exported` or of `tf.XlaCallModuleOp`. - You can disable more deserialization safety checks by passing - `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. + When using jax2tf, you can disable more deserialization safety checks + by passing `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`. """ _impl: str @classmethod def platform(cls) -> DisabledSafetyCheck: - """Allows the execution platform to differ from the serialization platform. + """Allows the compilation platform to differ from the export platform. Has effect only on deserialization. """ @@ -102,7 +109,7 @@ class DisabledSafetyCheck: @classmethod def shape_assertions(cls) -> DisabledSafetyCheck: - """A noop. DEPRECATED. + """DEPRECATED: A noop. Was used previously to allow invocations with shapes that do not meet the constraints. Has no effect anymore, shape assertions cannot be disabled. @@ -149,7 +156,7 @@ class Exported: out_tree: a PyTreeDef describing the result of the lowered JAX function. out_avals: the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in - `in_avals. + `in_avals`. in_shardings_hlo: the flattened input shardings, a sequence as long as `in_avals`. `None` means unspecified sharding. Note that these do not include the mesh or the actual devices used in @@ -162,16 +169,17 @@ class Exported: into sharding specification that can be used with JAX APIs. nr_devices: the number of devices that the module has been lowered for. lowering_platforms: a tuple containing at least one of 'tpu', 'cpu', - 'cuda', 'rocm'. See below for the calling convention for when + 'cuda', 'rocm'. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention + for the calling convention for when there are multiple lowering platforms. ordered_effects: the ordered effects present in the serialized module. - This is present from serialization version 9. See below for the - calling convention in presence of ordered effects. + This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention + for the calling convention in presence of ordered effects. unordered_effects: the unordered effects present in the serialized module. This is present from serialization version 9. mlir_module_serialized: the serialized lowered VHLO module. mlir_module_serialization_version: a version number for the serialized module. - See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions. + See more versioning details at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention#module-serialization-versions. module_kept_var_idx: the sorted indices of the arguments among `in_avals` that must be passed to the module. The other arguments have been dropped because they are not used. 
@@ -189,100 +197,7 @@ class Exported: for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs. - Calling convention for the exported module (for latest supported version): - - The `mlir_module` has a `main` function that takes an optional first - platform index argument if the module supports multiple platforms - (`len(lowering_platforms) > 1`), followed by the token arguments corresponding - to the ordered effects, followed by the kept array - arguments (corresponding to `module_kept_var_idx` and `in_avals`). - The platform index is a i32 or i64 scalar encoding the index of the current - compilation platform into the `lowering_platforms` sequence. - - Inner functions use a different calling convention: an optional - platform index argument, optional dimension variable arguments - (scalar tensors of type i32 or i64), - followed by optional token arguments (in presence of ordered effects), - followed by the regular array arguments. - The dimension arguments correspond to the dimension variables appearing in - the `args_avals`, in sorted order of their names. - - Consider the lowering of a function with one array argument of type "f32[w, - 2 * h]", where "w" and "h" are two dimension variables. - Assume that we use multi-platform lowering, and we have - one ordered effect. The `main` function will be as follows: - - func public main( - platform_index: i32 {jax.global_constant="_platform_index"}, - token_in: token, - arg: f32[?, ?]) { - arg_w = hlo.get_dimension_size(arg, 0) - dim1 = hlo.get_dimension_size(arg, 1) - arg_h = hlo.floordiv(dim1, 2) - call _check_shape_assertions(arg) # See below - token = new_token() - token_out, res = call _wrapped_jax_export_main(platform_index, - arg_h, - arg_w, - token_in, - arg) - return token_out, res - } - - The actual computation is in `_wrapped_jax_export_main`, taking also - the values of `h` and `w` dimension variables. - - The signature of the `_wrapped_jax_export_main` is: - - func private _wrapped_jax_export_main( - platform_index: i32 {jax.global_constant="_platform_index"}, - arg_h: i32 {jax.global_constant="h"}, - arg_w: i32 {jax.global_constant="w"}, - arg_token: stablehlo.token {jax.token=True}, - arg: f32[?, ?]) -> (stablehlo.token, ...) - - Prior to serialization version 9 the calling convention for effects is - different: the `main` function does not take or return a token. Instead - the function creates dummy tokens of type `i1[0]` and passes them to the - `_wrapped_jax_export_main`. The `_wrapped_jax_export_main` - takes dummy tokens of type `i1[0]` and will create internally real - tokens to pass to the inner functions. The inner functions use real - tokens (both before and after serialization version 9) - - Also starting with serialization version 9, function arguments that contain - the platform index or the dimension variable values have a - `jax.global_constant` string attribute whose value is the name of the - global constant, either `_platform_index` or a dimension variable name. - The global constant name may be empty if it is not known. - Some global constant computations use inner functions, e.g., for - `floor_divide`. The arguments of such functions have a `jax.global_constant` - attribute for all attributes, meaning that the result of the function is - also a global constant. - - Note that `main` contains a call to `_check_shape_assertions. - JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h` - have values >= 1. 
We must check these constraints when we invoke the - module. We use a special custom call `@shape_assertion` that takes - a boolean first operand, a string `error_message` attribute that may contain - format specifiers `{0}`, `{1}`, ..., and a variadic number of integer - scalar operands corresponding to the format specifiers. - - func private _check_shape_assertions(arg: f32[?, ?]) { - # Check that w is >= 1 - arg_w = hlo.get_dimension_size(arg, 0) - custom_call @shape_assertion(arg_w >= 1, arg_w, - error_message="Dimension variable 'w' must have integer value >= 1. Found {0}") - # Check that dim1 is even - dim1 = hlo.get_dimension_size(arg, 1) - custom_call @shape_assertion(dim1 % 2 == 0, dim1, - error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}") - # Check that h >= 1 - arg_h = hlo.floordiv(dim1, 2) - custom_call @shape_assertion(arg_h >= 1, arg_h, - error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}") - - If we `call_exported` with this module we perform these checks - statically (in `call_exported_abstract_eval`). + See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module_calling_convention). """ fun_name: str in_tree: tree_util.PyTreeDef @@ -417,7 +332,9 @@ def deserialize(blob: bytearray) -> Exported: def default_lowering_platform() -> str: - """Retrieves the default lowering platform for the exporting machine. + """Retrieves the default lowering platform. + + One of: `tpu`, `cpu`, `cuda`, `rocm`. """ # Canonicalize to turn 'gpu' into 'cuda' or 'rocm' return xb.canonicalize_platform(jax.default_backend()) @@ -458,6 +375,7 @@ def export_back_compat( Note: this function exists only for internal usage by jax2tf and for backwards compatibility with jax.experimental.export. Use `jax.export` instead. + See https://jax.readthedocs.io/en/latest/export.html Args: fun_jax: the function to lower and serialize. @@ -466,8 +384,8 @@ def export_back_compat( 'cuda', 'rocm'. If more than one platform is specified, then the lowered code takes an argument specifying the platform. If None, then use the default JAX backend. - The calling convention for multiple platforms is explained in the - `jax_export.Exported` docstring. + The calling convention for multiple platforms is explained + at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. disabled_checks: the safety checks to disable. See docstring of `DisabledSafetyCheck`. @@ -547,32 +465,33 @@ def export( """Exports a JAX function for persistent serialization. Args: - fun_jit: the function to export. Should be the result of `jit`. + fun_jit: the function to export. Should be the result of `jax.jit`. lowering_platforms: Optional sequence containing a subset of 'tpu', 'cpu', 'cuda', 'rocm'. If more than one platform is specified, then the lowered code takes an argument specifying the platform. If None, then use the default JAX backend. - The calling convention for multiple platforms is explained in the - `jax_export.Exported` docstring. - disabled_checks: the safety checks to disable. See docstring - of `DisabledSafetyCheck`. + The calling convention for multiple platforms is explained at + https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. + disabled_checks: the safety checks to disable. See documentation for + of `jax.export.DisabledSafetyCheck`. 
- Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct, + Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`, or values with `.shape` and `.dtype` attributes, and returns an `Exported`. Usage: + >>> from jax import export >>> exported: export.Exported = export.export(jnp.sin)( ... np.arange(4, dtype=np.float32)) - - # You can inspect the Exported object + >>> + >>> # You can inspect the Exported object >>> exported.in_avals (ShapedArray(float32[4]),) >>> blob: bytearray = exported.serialize() - - # The serialized bytes are safe to use in a separate process + >>> + >>> # The serialized bytes are safe to use in a separate process >>> rehydrated: export.Exported = export.deserialize(blob) >>> rehydrated.fun_name 'sin' @@ -776,11 +695,10 @@ def _wrap_main_func( ) -> ir.Module: """Wraps the lowered module with a new "main" handling dimension arguments. - See calling convention documentation for `jax_export.Exported`. + See calling convention documentation https://jax.readthedocs.io/en/latest/export.html#module-calling-convention. Args: - module: the HLO module as obtained from lowering. See the calling convention - for inner functions in `jax_export.Exported`. + module: the HLO module as obtained from lowering. args_avals_flat: the avals for all the arguments of the lowered function, which correspond to the array arguments of the `module`. args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error @@ -1343,7 +1261,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, for s in exported.in_shardings_hlo + exported.out_shardings_hlo)): err_msg = "the module contains non-replicated sharding annotations." if err_msg: - raise NotImplementedError( + raise ValueError( f"Exported module {exported.fun_name} was lowered for " f"{exported.nr_devices} devices and is called in a context with " f"{num_devices} devices. This is disallowed because: {err_msg}" diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index dd0ae856b..bb9b2ac50 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -982,7 +982,7 @@ class SymbolicScope: Holds the constraints on symbolic expressions. - See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints) + See [the README](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints) for more details. Args: @@ -1194,6 +1194,8 @@ def _convertible_to_poly(p: DimSize) -> bool: return isinstance(p, _DimExpr) or _convertible_to_int(p) def is_symbolic_dim(p: DimSize) -> bool: + """Checks if a dimension is symbolic. + """ return isinstance(p, _DimExpr) def is_poly_dim(p: DimSize) -> bool: @@ -1856,7 +1858,7 @@ class ShapeConstraints: def shape_assertions(self, eval: CachingShapeEvaluator) -> None: """Computes the shape assertions for the set of constraints. - See jax_export._wrap_main_func docstring. + See jax_export.Exported docstring. """ # We want to report the errors in the same order as `check_statically`. 
# So, we process them in order, in case some fail statically, and we diff --git a/tests/export_test.py b/tests/export_test.py index cb2d7bdc4..f86f9034d 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -965,12 +965,12 @@ class JaxExportTest(jtu.JaxTestCase): # Test error reporting with self.assertRaisesRegex( - NotImplementedError, + ValueError, "Exported module .* was lowered for 2 devices and is called in a context with 1 device"): _ = exp.call(a) with self.assertRaisesRegex( - NotImplementedError, + ValueError, "Exported module .* was lowered for 2 devices and is called in a context with 1 device"): mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",)) _ = jax.jit( @@ -1046,7 +1046,7 @@ class JaxExportTest(jtu.JaxTestCase): b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) with self.assertRaisesRegex( - NotImplementedError, + ValueError, "Exported module .* was lowered for 1 devices and is called in a " f"context with {jax.local_device_count()} devices.* module contains " "non-replicated sharding annotations"): @@ -1092,7 +1092,7 @@ class JaxExportTest(jtu.JaxTestCase): b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i"))) with self.assertRaisesRegex( - NotImplementedError, + ValueError, "Exported module .* was lowered for 1 devices and is called in a " f"context with {jax.local_device_count()} devices.* module contains " "non-replicated sharding annotations"):