[export] Add documentation for jax.export

parent ad9f35ae53
commit 105cc9a103

.github/workflows/ci-build.yaml (vendored) | 1
@@ -136,6 +136,7 @@ jobs:
      - name: Test documentation
        env:
          XLA_FLAGS: "--xla_force_host_platform_device_count=8"
          JAX_TRACEBACK_FILTERING: "off"
          JAX_ARRAY: 1
          PY_COLORS: 1
        run: |
docs/aot.md | 13
@@ -35,8 +35,6 @@ way. An example:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4
@@ -66,6 +64,10 @@ Array(10, dtype=int32, weak_type=True)

```

Note that the lowered objects can be used only in the same process
in which they were lowered. For exporting use cases,
see the {ref}`export` APIs.

See the {mod}`jax.stages` documentation for more details on what functionality
the lowering and compiled functions provide.
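As a minimal sketch of the lower-then-compile workflow described above (reusing `f`, `x`, and `y` from the example):

```python
import jax

def f(x, y): return 2 * x + y
x, y = 3, 4

lowered = jax.jit(f).lower(x, y)   # stage out and lower; no compilation yet
compiled = lowered.compile()       # compile in the same process
print(compiled(x, y))              # Array(10, dtype=int32, weak_type=True)
```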
@@ -139,11 +141,16 @@ module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 :
    return %0 : tensor<i32>
  }
}

>>> lowered_with_x.compile()(5)
Array(19, dtype=int32, weak_type=True)

```

The result of `lower` is not safe to serialize directly for use
in a different process.
See {ref}`export` for additional APIs for this purpose.

Note that `lower` here takes two arguments as usual, but the subsequent compiled
function accepts only the remaining non-static second argument. The static first
argument (value 7) is taken as a constant at lowering time and built into the
@@ -194,7 +201,7 @@ Array([[ 1., 5., 9.],

>>> jax.vmap(g_aot)(zs)  # doctest: +SKIP
Traceback (most recent call last):
-TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax.interpreters.batching.BatchTracer'>.
+TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>

```
docs/export/export.md (new file) | 633

@@ -0,0 +1,633 @@
# Exporting and serializing staged-out computations

The {ref}`ahead-of-time-lowering` APIs produce
objects that can be used for debugging or for compilation and
execution in the same process.
Sometimes you want to serialize a lowered JAX function for
compilation and execution in a separate process, perhaps
at a later time. This would allow you to:

  * compile and execute the function in another process or machine
    without requiring access to the JAX program,
    and without having to repeat the staging-out and lowering, e.g.,
    in an inference system.
  * trace and lower a function on a machine that does not have access
    to the accelerator for which you want to later compile and execute
    the function.
  * archive a snapshot of a JAX function, e.g., to be able to
    reproduce your results later. **Note:** check out the [compatibility
    guarantees](#compatibility-guarantees) for this use case.
Here is an example:

```python
>>> import re
>>> import numpy as np
>>> import jax
>>> from jax import export

>>> def f(x): return 2 * x * x

>>> exported: export.Exported = export.export(jax.jit(f))(
...    jax.ShapeDtypeStruct((), np.float32))

>>> # You can inspect the Exported object
>>> exported.fun_name
'f'

>>> exported.in_avals
(ShapedArray(float32[]),)

>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0))
  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"} loc("x")) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {

>>> # And you can serialize the Exported to a bytearray.
>>> serialized: bytearray = exported.serialize()

>>> # The serialized function can later be rehydrated and called from
>>> # another JAX computation, possibly in another process.
>>> rehydrated_exp: export.Exported = export.deserialize(serialized)
>>> rehydrated_exp.in_avals
(ShapedArray(float32[]),)

>>> def callee(y):
...  return 3. * rehydrated_exp.call(y * 4.)

>>> callee(1.)
Array(96., dtype=float32)

```
Serialization is broken down into two stages:
1. exporting to produce an {class}`jax.export.Exported` object that contains
   the StableHLO for the lowered function along with the metadata necessary to
   call it from another JAX function. We have plans to add code to generate
   `Exported` objects from TensorFlow, and to use `Exported` objects from
   TensorFlow and PyTorch.
2. the actual serialization to a byte array using the flatbuffers format.
   See {ref}`jax2tf` for
   an alternative serialization to a TensorFlow graph that can be used
   for interoperation with TensorFlow.
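Because `serialize` produces a plain `bytearray` (and `deserialize` consumes one), persisting an exported artifact is ordinary file I/O. A minimal sketch, with an illustrative file name:

```python
import jax
import jax.numpy as jnp
from jax import export

exp = export.export(jax.jit(jnp.sin))(jax.ShapeDtypeStruct((4,), jnp.float32))

# Write the serialized artifact to disk ...
with open("/tmp/sin.jax_exported", "wb") as f:  # illustrative path
    f.write(exp.serialize())

# ... and rehydrate it later, possibly in a different process.
with open("/tmp/sin.jax_exported", "rb") as f:
    rehydrated = export.deserialize(bytearray(f.read()))

print(rehydrated.call(jnp.arange(4.0)))
```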
## Support for reverse-mode AD

Serialization can optionally support higher-order reverse-mode AD. This is done
by serializing the {func}`jax.vjp` of the primal function along with the primal function,
up to a user-specified order (the default is 0, meaning that the rehydrated
function cannot be differentiated):
```python
>>> import jax
>>> from jax import export
>>> from typing import Callable

>>> def f(x): return 7 * x * x * x

>>> # Serialize 3 levels of VJP along with the primal function
>>> blob: bytearray = export.export(jax.jit(f))(1.).serialize(vjp_order=3)
>>> rehydrated_f: Callable = export.deserialize(blob).call

>>> rehydrated_f(0.1)  # 7 * 0.1^3
Array(0.007, dtype=float32)

>>> jax.grad(rehydrated_f)(0.1)  # 7*3 * 0.1^2
Array(0.21000001, dtype=float32)

>>> jax.grad(jax.grad(rehydrated_f))(0.1)  # 7*3*2 * 0.1
Array(4.2, dtype=float32)

>>> jax.grad(jax.grad(jax.grad(rehydrated_f)))(0.1)  # 7*3*2
Array(42., dtype=float32)

>>> jax.grad(jax.grad(jax.grad(jax.grad(rehydrated_f))))(0.1)  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: No VJP is available

```
Note that the VJP function is computed lazily while serializing,
when the JAX program is still available.
This means that it respects all features of JAX VJP,
e.g., {func}`jax.custom_vjp` and {func}`jax.remat`.

Note that the rehydrated function does not support any other
transformations, e.g., forward-mode AD (jvp), or {func}`jax.vmap`.
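For instance, continuing the example above, a sketch of what happens if we try to batch the rehydrated function (the exact error raised is an implementation detail):

```python
>>> import jax
>>> import jax.numpy as jnp
>>> try:
...   jax.vmap(rehydrated_f)(jnp.arange(4.))
... except Exception:
...   print("vmap of a rehydrated function is not supported")
vmap of a rehydrated function is not supported
```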
## Compatibility guarantees

You should not use the raw StableHLO that is obtained from just lowering
(`jax.jit(f).lower(1.).compiler_ir()`)
for archival and for compilation in another process, for several reasons.

First, the compilation may use a different version of the compiler, supporting a
different version of StableHLO. The {mod}`jax.export` module takes
care of this by using the
[portable-artifact feature of StableHLO](https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md)
to deal with the possible evolution of the StableHLO opset.
### Compatibility guarantees for custom calls

Second, the raw StableHLO may contain custom calls referencing C++
functions.
JAX uses custom calls for lowering a small number of primitives,
e.g., linear algebra primitives, sharding annotations, or Pallas kernels.
These do not fall under the compatibility guarantees for StableHLO.
The C++ implementations of these functions change rarely, but they can change.
`jax.export` makes the following export compatibility guarantees:
A JAX exported artifact can be compiled and executed by a compiler and
JAX runtime system that are:

  * **up to 6 months newer** than the version of JAX used for exporting
    (we say that JAX export offers **6 months backward compatibility**).
    This is useful if we want to archive the exported artifact to be compiled and executed later.
  * **up to 3 weeks older** than the version of JAX used for exporting
    (we say that JAX export offers **3 weeks forward compatibility**).
    This is useful if we want to compile and run an exported artifact with a
    consumer that was built and deployed before the export, e.g.,
    an inference system that is already deployed when the exporting is done.

(The particular compatibility window lengths are the same ones that JAX
[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model),
and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow).
The terminology “backward compatibility” is from the perspective of the consumer,
e.g., the inference system.)
What **matters is when the exporting and consuming components were built**,
not the time when the exporting and the compilation happen.
For external JAX users, it is
[possible to run JAX and jaxlib at different versions](https://jax.readthedocs.io/en/latest/jep/9419-jax-versioning.html#how-are-jax-and-jaxlib-versioned);
what matters is when the jaxlib release was built.

To reduce the chance of incompatibility, internal JAX users should:
  * **rebuild and redeploy consumer systems as frequently as possible**,

and external users should:
  * run the exporting and consumer systems with the same version of jaxlib, whenever possible, and
  * export for archival **with the latest released version of jaxlib**.

The compatibility guarantees do not apply if you bypass the `jax.export` APIs
to obtain the StableHLO code.
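To make the boundary concrete, here is a minimal sketch of the two ways to obtain StableHLO (the variable names are illustrative); only the first is covered by the guarantees above:

```python
import jax
import jax.numpy as jnp
from jax import export

f = jax.jit(lambda x: jnp.sin(x) * 2.)

# Covered: a serialized `Exported` artifact produced via jax.export.
covered: bytearray = export.export(f)(1.).serialize()

# Not covered: raw StableHLO obtained by bypassing the jax.export APIs.
not_covered = str(f.lower(1.).compiler_ir(dialect="stablehlo"))
```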
Only a subset of custom calls are guaranteed stable and have
compatibility guarantees ([see list](https://github.com/search?q=repo%3Agoogle%2Fjax%20_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE&type=code)).
We continuously
add more custom call targets to the allowed list, along with backwards
compatibility tests. If you try to serialize
code that invokes other custom call targets, you will get an error
during exporting.

If you want to disable this safety check for a specific custom call,
e.g., one with target `my_target`, you can add
`export.DisabledSafetyCheck.custom_call("my_target")` to the
`disabled_checks` parameter of the `export` method,
as in the following example:
```python
>>> import jax
>>> from jax import export
>>> from jax import lax
>>> from jax._src.interpreters import mlir

>>> # override the lowering rule for sin to use a custom call `my_new_sin`
>>> _ = mlir.register_lowering(lax.sin_p, lambda ctx, o: mlir.custom_call("my_new_sin", operands=[o], result_types=[o.type]).results)
>>> print(jax.jit(lax.sin).lower(1.).compiler_ir())
module @jit_sin attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.custom_call @my_new_sin(%arg0) {api_version = 2 : i32} : (tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
}

>>> # If we try to export, we get an error
>>> export.export(jax.jit(lax.sin))(1.)  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: Cannot serialize code with custom calls whose targets have no compatibility guarantees: my_new_sin

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.custom_call`
>>> exp = export.export(
...    jax.jit(lax.sin),
...    disabled_checks=[export.DisabledSafetyCheck.custom_call("my_new_sin")])(1.)

```
## Cross-platform and multi-platform export

JAX lowering is platform specific for a small number of JAX primitives.
By default, the code is lowered and exported for the accelerator
present on the exporting machine:
```python
>>> from jax import export
>>> export.default_lowering_platform()
'cpu'

```
There is a safety check that will raise an error when trying to compile
an `Exported` object on a machine that does not have the accelerator
for which the code was exported.

You can specify explicitly for what platforms the code should be exported.
This allows you to specify a different accelerator than the one you have
available at export time,
and it even allows you to specify multi-platform export to
obtain an `Exported` object that can be compiled and executed
on multiple platforms.
```python
>>> import jax
>>> from jax import export
>>> from jax import lax

>>> # You can specify the lowering_platforms, e.g., `tpu`, `cpu`, `cuda`, `rocm`,
>>> # even if the current machine does not have that accelerator.
>>> exp = export.export(jax.jit(lax.cos), lowering_platforms=['tpu'])(1.)

>>> # But you will get an error if you try to compile `exp`
>>> # on a machine that does not have TPUs.
>>> exp.call(1.)  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: The exported function 'cos' was lowered for platforms '('tpu',)' but it is used on '('cpu',)'.

>>> # We can avoid the error if we pass a `DisabledSafetyCheck.platform`
>>> # parameter to `export`, e.g., because you have reasons to believe
>>> # that the code lowered will run adequately on the current
>>> # compilation platform (which is the case for `cos` in this
>>> # example):
>>> exp_unsafe = export.export(jax.jit(lax.cos),
...    lowering_platforms=['tpu'],
...    disabled_checks=[export.DisabledSafetyCheck.platform()])(1.)

>>> exp_unsafe.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)

>>> # And similarly with multi-platform lowering:
>>> exp_multi = export.export(jax.jit(lax.cos),
...    lowering_platforms=['tpu', 'cpu', 'cuda'])(1.)
>>> exp_multi.call(1.)
Array(0.5403023, dtype=float32, weak_type=True)

```
For multi-platform export, the StableHLO will contain multiple
lowerings, but only for those primitives that require it, so the
resulting module size should be only marginally larger than the
size of a module with default export.
As an extreme case, when serializing a module without any
primitives with platform-specific lowering, you will get
the same StableHLO as for the single-platform export.
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import export
>>> from jax import lax
>>> # A largish function
>>> def f(x):
...   for i in range(1000):
...     x = jnp.cos(x)
...   return x

>>> exp_single = export.export(jax.jit(f))(1.)
>>> len(exp_single.mlir_module_serialized)  # doctest: +SKIP
9220

>>> exp_multi = export.export(jax.jit(f),
...    lowering_platforms=["cpu", "tpu", "cuda"])(1.)
>>> len(exp_multi.mlir_module_serialized)  # doctest: +SKIP
9282

```
## Shape polymorphic export

When used in JIT mode, JAX will trace and lower a function separately
for each combination of input shapes. When exporting, it is possible
in some cases to use dimension variables for some input dimensions
in order to obtain an exported artifact that can be used with multiple
combinations of input shapes.

See the {ref}`shape_poly` documentation.
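A minimal sketch of the idea, using `export.symbolic_shape` to introduce a dimension variable `b` (the printed aval is indicative):

```python
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import export

>>> def f(x):  # x: f32[b, 4], for any batch size b
...   return jnp.sum(x, axis=0)

>>> # One export covers all batch sizes.
>>> args = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.float32)
>>> exp = export.export(jax.jit(f))(args)
>>> exp.in_avals
(ShapedArray(float32[b,4]),)
```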
## Device polymorphic export

An exported artifact may contain sharding annotations for inputs,
outputs and some intermediates, but these annotations do not refer
directly to the actual physical devices that existed at export time.
Instead, the sharding annotations refer to logical devices. This
means that you can compile and run the exported artifact on
physical devices different from the ones used for exporting.
```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> # Use the first 4 devices for exporting.
>>> export_devices = jax.local_devices()[:4]
>>> export_mesh = Mesh(export_devices, ("a",))
>>> def f(x):
...   return x.T

>>> arg = jnp.arange(8 * len(export_devices))
>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)

>>> # `exp` knows for how many devices it was exported.
>>> exp.nr_devices
4

>>> # and it knows the shardings for the inputs. These will be applied
>>> # when the exported function is called.
>>> exp.in_shardings_hlo
({devices=[4]<=[4]},)

>>> res1 = exp.call(jax.device_put(arg,
...                                NamedSharding(export_mesh, P("a"))))

>>> # Check out the first 2 shards of the result
>>> [f"device={s.device} index={s.index}" for s in res1.addressable_shards[:2]]
['device=TFRT_CPU_0 index=(slice(0, 8, None),)',
 'device=TFRT_CPU_1 index=(slice(8, 16, None),)']

>>> # We can call `exp` with some other 4 devices and another
>>> # mesh with a different shape, as long as the number of devices is
>>> # the same.
>>> other_mesh = Mesh(np.array(jax.local_devices()[2:6]).reshape((2, 2)), ("b", "c"))
>>> res2 = exp.call(jax.device_put(arg,
...                                NamedSharding(other_mesh, P("b"))))

>>> # Check out the first 2 shards of the result. Notice that the output is
>>> # sharded similarly; this means that the input was resharded according to
>>> # exp.in_shardings.
>>> [f"device={s.device} index={s.index}" for s in res2.addressable_shards[:2]]
['device=TFRT_CPU_2 index=(slice(0, 8, None),)',
 'device=TFRT_CPU_3 index=(slice(8, 16, None),)']

```
It is an error to try to invoke an exported artifact with a different number
of devices than it was exported for:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T

>>> arg = jnp.arange(4 * len(export_devices))
>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)

>>> exp.call(arg)  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
ValueError: Exported module f was lowered for 8 devices and is called in a context with 1 devices. This is disallowed because: the module was lowered for more than 1 device.

```
There are helper functions to shard the inputs for calling an exported
artifact using a new mesh constructed at the call site:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> export_devices = jax.local_devices()
>>> export_mesh = Mesh(np.array(export_devices), ("a",))
>>> def f(x):
...   return x.T

>>> arg = jnp.arange(4 * len(export_devices))
>>> exp = export.export(jax.jit(f, in_shardings=(NamedSharding(export_mesh, P("a")),)))(arg)

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(np.array(export_devices[::-1]), ("b",))

>>> # Shard the arg according to what `exp` expects.
>>> sharded_arg = jax.device_put(arg, exp.in_shardings_jax(calling_mesh)[0])
>>> res = exp.call(sharded_arg)

```
As a special facility, if a function was exported for 1 device and if it contains no
sharding annotations, then it can be invoked on an argument of the same shape but sharded
on multiple devices, and the compiler will shard the function appropriately:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import export
>>> from jax.sharding import Mesh, NamedSharding
>>> from jax.sharding import PartitionSpec as P

>>> def f(x):
...   return jnp.cos(x)

>>> arg = jnp.arange(4)
>>> exp = export.export(jax.jit(f))(arg)
>>> exp.in_avals
(ShapedArray(int32[4]),)

>>> exp.nr_devices
1

>>> # Prepare the mesh for calling `exp`.
>>> calling_mesh = Mesh(jax.local_devices()[:4], ("b",))

>>> # Shard the arg according to what `exp` expects.
>>> sharded_arg = jax.device_put(arg,
...                              NamedSharding(calling_mesh, P("b")))
>>> res = exp.call(sharded_arg)

```
## Module serialization versions

The JAX export support has evolved over time, e.g., to support
effects. In order to support compatibility (see [compatibility guarantees](#compatibility-guarantees))
we maintain a serialization version for each `Exported`.
As of June 2024, all modules are serialized with version 9
(the latest; see [all serialization versions](#serialization-version-numbers)):

```python
>>> import jax.numpy as jnp
>>> from jax import export
>>> exp: export.Exported = export.export(jnp.cos)(1.)
>>> exp.mlir_module_serialization_version
9

```
At any given time, the export APIs may support a range
of serialization versions. You can control which serialization
version to use with the `--jax-serialization-version` flag
or the `JAX_SERIALIZATION_VERSION` environment variable:

```python
>>> from jax import export
>>> (export.minimum_supported_serialization_version, export.maximum_supported_serialization_version)
(9, 9)

>>> from jax._src import config
>>> with config.jax_serialization_version(9):
...   exp = export.export(jnp.cos)(1.)
...   exp.mlir_module_serialization_version
9

```

We reserve the right to remove support for
generating or consuming serialization versions older than 6 months.
### Module calling convention

The `Exported.mlir_module` has a `main` function that takes an optional first
platform index argument if the module supports multiple platforms
(`len(lowering_platforms) > 1`), followed by the token arguments corresponding
to the ordered effects, followed by the kept array
arguments (corresponding to `module_kept_var_idx` and `in_avals`).
The platform index is an i32 or i64 scalar encoding the index of the current
compilation platform into the `lowering_platforms` sequence.

Inner functions use a different calling convention: an optional
platform index argument, optional dimension variable arguments
(scalar tensors of type i32 or i64),
followed by optional token arguments (in presence of ordered effects),
followed by the regular array arguments.
The dimension arguments correspond to the dimension variables appearing in
the `args_avals`, in sorted order of their names.

Consider the lowering of a function with one array argument of type
`f32[w, 2 * h]`, where `w` and `h` are two dimension variables.
Assume that we use multi-platform lowering, and we have
one ordered effect. The `main` function will be as follows:
```
func public main(
      platform_index: i32 {jax.global_constant="_platform_index"},
      token_in: token,
      arg: f32[?, ?]) {
   arg_w = hlo.get_dimension_size(arg, 0)
   dim1 = hlo.get_dimension_size(arg, 1)
   arg_h = hlo.floordiv(dim1, 2)
   call _check_shape_assertions(arg)  # See below
   token = new_token()
   token_out, res = call _wrapped_jax_export_main(platform_index,
                                                  arg_h,
                                                  arg_w,
                                                  token_in,
                                                  arg)
   return token_out, res
}
```
The actual computation is in `_wrapped_jax_export_main`, which also takes
the values of the `h` and `w` dimension variables.

The signature of `_wrapped_jax_export_main` is:

```
func private _wrapped_jax_export_main(
    platform_index: i32 {jax.global_constant="_platform_index"},
    arg_h: i32 {jax.global_constant="h"},
    arg_w: i32 {jax.global_constant="w"},
    arg_token: stablehlo.token {jax.token=True},
    arg: f32[?, ?]) -> (stablehlo.token, ...)
```
Prior to serialization version 9 the calling convention for effects was
different: the `main` function did not take or return a token. Instead
the function created dummy tokens of type `i1[0]` and passed them to
`_wrapped_jax_export_main`. The `_wrapped_jax_export_main`
took dummy tokens of type `i1[0]` and created internally real
tokens to pass to the inner functions. The inner functions use real
tokens (both before and after serialization version 9).

Also starting with serialization version 9, function arguments that contain
the platform index or the dimension variable values have a
`jax.global_constant` string attribute whose value is the name of the
global constant, either `_platform_index` or a dimension variable name.
The global constant name may be empty if it is not known.
Some global constant computations use inner functions, e.g., for
`floor_divide`. The arguments of such functions have a `jax.global_constant`
attribute for all arguments, meaning that the result of the function is
also a global constant.
Note that `main` contains a call to `_check_shape_assertions`.
JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
have values >= 1. We must check these constraints when we invoke the
module. We use a special custom call `@shape_assertion` that takes
a boolean first operand, a string `error_message` attribute that may contain
format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
scalar operands corresponding to the format specifiers.
```
func private _check_shape_assertions(arg: f32[?, ?]) {
  # Check that w is >= 1
  arg_w = hlo.get_dimension_size(arg, 0)
  custom_call @shape_assertion(arg_w >= 1, arg_w,
      error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")

  # Check that dim1 is even
  dim1 = hlo.get_dimension_size(arg, 1)
  custom_call @shape_assertion(dim1 % 2 == 0, dim1,
      error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")

  # Check that h >= 1
  arg_h = hlo.floordiv(dim1, 2)
  custom_call @shape_assertion(arg_h >= 1, arg_h,
      error_message="Dimension variable 'h' must have integer value >= 1. Found {0}")
}
```
### Serialization version numbers

We list here a history of the serialization version numbers:

  * Version 1 used MHLO & CHLO to serialize the code, not supported anymore.
  * Version 2 supports StableHLO & CHLO. Used from October 2022. Not supported
    anymore.
  * Version 3 supports platform checking and multiple platforms.
    Used from February 2023. Not supported anymore.
  * Version 4 supports StableHLO with compatibility guarantees.
    This is the earliest version at the time of the JAX native serialization
    launch.
    Used in JAX from March 15, 2023 (cl/516885716). Starting with
    March 28th, 2023 we stopped using `dim_args_spec` (cl/520033493).
    The support for this version was dropped on
    October 17th, 2023 (cl/573858283).
  * Version 5 adds support for `call_tf_graph`. This is currently used
    for some specialized use cases. Used in JAX from May 3rd, 2023
    (cl/529106145).
  * Version 6 adds support for the `disabled_checks` attribute. This version
    mandates a non-empty `platforms` attribute. Supported by XlaCallModule
    since June 7th, 2023 and available in JAX since
    June 13th, 2023 (JAX 0.4.13).
  * Version 7 adds support for `stablehlo.shape_assertion` operations and
    for `shape_assertions` specified in `disabled_checks`.
    See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule
    since July 12th, 2023 (cl/547482522),
    available in JAX serialization since July 20th, 2023 (JAX 0.4.14),
    and the default since August 12th, 2023 (JAX 0.4.15).
  * Version 8 adds support for the `jax.uses_shape_polymorphism` module
    attribute and enables the shape refinement pass only when the
    attribute is present. Supported by XlaCallModule since July 21st, 2023
    (cl/549973693), available in JAX since July 26th, 2023 (JAX 0.4.14),
    and the default since October 21st, 2023 (JAX 0.4.20).
  * Version 9 adds support for effects.
    See the docstring for `export.Exported` for the precise calling convention.
    In this serialization version we also tag the platform index and the
    dimension variable arguments with `jax.global_constant` attributes.
    Supported by XlaCallModule since October 27th, 2023,
    available in JAX since October 20th, 2023 (JAX 0.4.20),
    and the default since February 1st, 2024 (JAX 0.4.24).
    This is the only supported version as of March 27th, 2024.
docs/export/index.rst (new file) | 13

@@ -0,0 +1,13 @@
.. _export:

Exporting and serialization
===========================

.. toctree::
   :caption: Guides
   :maxdepth: 2

   export
   shape_poly
   jax2tf
docs/export/jax2tf.md (new file) | 5

@@ -0,0 +1,5 @@
(jax2tf)=

## Interoperation with TensorFlow

TO DO: move the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) here.
docs/export/shape_poly.md (new file) | 5

@@ -0,0 +1,5 @@
(shape_poly)=

## Shape polymorphism

TO DO: populate from the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion).
docs/jax.export.rst (new file) | 36

@@ -0,0 +1,36 @@
``jax.export`` module
=====================

.. automodule:: jax.export

Classes
-------

.. autosummary::
   :toctree: _autosummary

   Exported
   DisabledSafetyCheck

Functions
---------

.. autosummary::
   :toctree: _autosummary

   export
   deserialize
   minimum_supported_serialization_version
   maximum_supported_serialization_version
   default_lowering_platform

Functions related to shape polymorphism
---------------------------------------

.. autosummary::
   :toctree: _autosummary

   symbolic_shape
   symbolic_args_specs
   is_symbolic_dim
   SymbolicScope
@@ -27,6 +27,7 @@ Subpackages

    jax.tree
    jax.tree_util
    jax.typing
+   jax.export
    jax.extend
    jax.example_libraries
    jax.experimental
@@ -32,6 +32,7 @@ or deployed codebases.
    :caption: Run Time

    aot
+   export/index
    errors
    transfer_guard
@@ -63,28 +63,35 @@ Shape = jax._src.core.Shape
LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue]
HloSharding = xla_client.HloSharding

-# See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions
-# for a description of the different versions.
+"""The minimum supported serialization version.
+
+See https://jax.readthedocs.io/en/latest/export.html#module-serialization-versions
+"""
minimum_supported_serialization_version = 9
+
+"""The maximum supported serialization version.
+
+See https://jax.readthedocs.io/en/latest/export.html#module-serialization-versions
+"""
maximum_supported_serialization_version = 9
class DisabledSafetyCheck:
-  """A safety check should be skipped on (de)serialization.
+  """A safety check that should be skipped on (de)serialization.

  Most of these checks are performed on serialization, but some are deferred to
  deserialization. The list of disabled checks is attached to the serialization,
-  e.g., as a sequence of string attributes to `jax_export.Exported` or of
+  e.g., as a sequence of string attributes to `jax.export.Exported` or of
  `tf.XlaCallModuleOp`.

-  You can disable more deserialization safety checks by passing
-  `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`.
+  When using jax2tf, you can disable more deserialization safety checks
+  by passing `TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform`.
  """
  _impl: str

  @classmethod
  def platform(cls) -> DisabledSafetyCheck:
-    """Allows the execution platform to differ from the serialization platform.
+    """Allows the compilation platform to differ from the export platform.

    Has effect only on deserialization.
    """
@@ -102,7 +109,7 @@ class DisabledSafetyCheck:

  @classmethod
  def shape_assertions(cls) -> DisabledSafetyCheck:
-    """A noop. DEPRECATED.
+    """DEPRECATED: A noop.

    Was used previously to allow invocations with shapes that do not meet the
    constraints. Has no effect anymore, shape assertions cannot be disabled.
@@ -149,7 +156,7 @@ class Exported:
    out_tree: a PyTreeDef describing the result of the lowered JAX function.
    out_avals: the flat tuple of output abstract values. May contain dimension
      expressions in the shapes, with dimension variables among those in
-      `in_avals.
+      `in_avals`.
    in_shardings_hlo: the flattened input shardings, a sequence as long
      as `in_avals`. `None` means unspecified sharding.
      Note that these do not include the mesh or the actual devices used in
@@ -162,16 +169,17 @@ class Exported:
      into sharding specification that can be used with JAX APIs.
    nr_devices: the number of devices that the module has been lowered for.
    lowering_platforms: a tuple containing at least one of 'tpu', 'cpu',
-      'cuda', 'rocm'. See below for the calling convention for when
+      'cuda', 'rocm'. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention
+      for the calling convention for when
      there are multiple lowering platforms.
    ordered_effects: the ordered effects present in the serialized module.
-      This is present from serialization version 9. See below for the
-      calling convention in presence of ordered effects.
+      This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export.html#module-calling-convention
+      for the calling convention in presence of ordered effects.
    unordered_effects: the unordered effects present in the serialized module.
      This is present from serialization version 9.
    mlir_module_serialized: the serialized lowered VHLO module.
    mlir_module_serialization_version: a version number for the serialized module.
-      See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.
+      See more versioning details at https://jax.readthedocs.io/en/latest/export.html#module-serialization-versions.
    module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
      must be passed to the module. The other arguments have been dropped
      because they are not used.
@@ -189,100 +197,7 @@ class Exported:
    for each primal output. It returns a tuple with the cotangents
    corresponding to the flattened primal inputs.

-  Calling convention for the exported module (for latest supported version):
-
-  The `mlir_module` has a `main` function that takes an optional first
-  platform index argument if the module supports multiple platforms
-  (`len(lowering_platforms) > 1`), followed by the token arguments corresponding
-  to the ordered effects, followed by the kept array
-  arguments (corresponding to `module_kept_var_idx` and `in_avals`).
-  The platform index is a i32 or i64 scalar encoding the index of the current
-  compilation platform into the `lowering_platforms` sequence.
-
-  Inner functions use a different calling convention: an optional
-  platform index argument, optional dimension variable arguments
-  (scalar tensors of type i32 or i64),
-  followed by optional token arguments (in presence of ordered effects),
-  followed by the regular array arguments.
-  The dimension arguments correspond to the dimension variables appearing in
-  the `args_avals`, in sorted order of their names.
-
-  Consider the lowering of a function with one array argument of type "f32[w,
-  2 * h]", where "w" and "h" are two dimension variables.
-  Assume that we use multi-platform lowering, and we have
-  one ordered effect. The `main` function will be as follows:
-
-      func public main(
-            platform_index: i32 {jax.global_constant="_platform_index"},
-            token_in: token,
-            arg: f32[?, ?]) {
-         arg_w = hlo.get_dimension_size(arg, 0)
-         dim1 = hlo.get_dimension_size(arg, 1)
-         arg_h = hlo.floordiv(dim1, 2)
-         call _check_shape_assertions(arg)  # See below
-         token = new_token()
-         token_out, res = call _wrapped_jax_export_main(platform_index,
-                                                        arg_h,
-                                                        arg_w,
-                                                        token_in,
-                                                        arg)
-         return token_out, res
-      }
-
-  The actual computation is in `_wrapped_jax_export_main`, taking also
-  the values of `h` and `w` dimension variables.
-
-  The signature of the `_wrapped_jax_export_main` is:
-
-      func private _wrapped_jax_export_main(
-          platform_index: i32 {jax.global_constant="_platform_index"},
-          arg_h: i32 {jax.global_constant="h"},
-          arg_w: i32 {jax.global_constant="w"},
-          arg_token: stablehlo.token {jax.token=True},
-          arg: f32[?, ?]) -> (stablehlo.token, ...)
-
-  Prior to serialization version 9 the calling convention for effects is
-  different: the `main` function does not take or return a token. Instead
-  the function creates dummy tokens of type `i1[0]` and passes them to the
-  `_wrapped_jax_export_main`. The `_wrapped_jax_export_main`
-  takes dummy tokens of type `i1[0]` and will create internally real
-  tokens to pass to the inner functions. The inner functions use real
-  tokens (both before and after serialization version 9)
-
-  Also starting with serialization version 9, function arguments that contain
-  the platform index or the dimension variable values have a
-  `jax.global_constant` string attribute whose value is the name of the
-  global constant, either `_platform_index` or a dimension variable name.
-  The global constant name may be empty if it is not known.
-  Some global constant computations use inner functions, e.g., for
-  `floor_divide`. The arguments of such functions have a `jax.global_constant`
-  attribute for all attributes, meaning that the result of the function is
-  also a global constant.
-
-  Note that `main` contains a call to `_check_shape_assertions.
-  JAX tracing assumes that `arg.shape[1]` is even, and that both `w` and `h`
-  have values >= 1. We must check these constraints when we invoke the
-  module. We use a special custom call `@shape_assertion` that takes
-  a boolean first operand, a string `error_message` attribute that may contain
-  format specifiers `{0}`, `{1}`, ..., and a variadic number of integer
-  scalar operands corresponding to the format specifiers.
-
-      func private _check_shape_assertions(arg: f32[?, ?]) {
-        # Check that w is >= 1
-        arg_w = hlo.get_dimension_size(arg, 0)
-        custom_call @shape_assertion(arg_w >= 1, arg_w,
-          error_message="Dimension variable 'w' must have integer value >= 1. Found {0}")
-        # Check that dim1 is even
-        dim1 = hlo.get_dimension_size(arg, 1)
-        custom_call @shape_assertion(dim1 % 2 == 0, dim1,
-          error_message="Dimension variable 'h' must have integer value >= 1. Found non-zero remainder {0}")
-        # Check that h >= 1
-        arg_h = hlo.floordiv(dim1, 2)
-        custom_call @shape_assertion(arg_h >= 1, arg_h,
-          error_message=""Dimension variable 'h' must have integer value >= 1. Found {0}")
-
-  If we `call_exported` with this module we perform these checks
-  statically (in `call_exported_abstract_eval`).
+  See a [description of the calling convention for the `mlir_module`](https://jax.readthedocs.io/en/latest/export.html#module-calling-convention).
  """
  fun_name: str
  in_tree: tree_util.PyTreeDef
@@ -417,7 +332,9 @@ def deserialize(blob: bytearray) -> Exported:


def default_lowering_platform() -> str:
-  """Retrieves the default lowering platform for the exporting machine.
+  """Retrieves the default lowering platform.
+
+  One of: `tpu`, `cpu`, `cuda`, `rocm`.
  """
  # Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
  return xb.canonicalize_platform(jax.default_backend())
@@ -458,6 +375,7 @@ def export_back_compat(
  Note: this function exists only for internal usage by jax2tf and for
  backwards compatibility with jax.experimental.export. Use
  `jax.export` instead.
+  See https://jax.readthedocs.io/en/latest/export.html

  Args:
    fun_jax: the function to lower and serialize.
@@ -466,8 +384,8 @@ def export_back_compat(
      'cuda', 'rocm'. If more than one platform is specified, then
      the lowered code takes an argument specifying the platform.
      If None, then use the default JAX backend.
-      The calling convention for multiple platforms is explained in the
-      `jax_export.Exported` docstring.
+      The calling convention for multiple platforms is explained
+      at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.
    disabled_checks: the safety checks to disable. See docstring
      of `DisabledSafetyCheck`.
@@ -547,32 +465,33 @@ def export(
  """Exports a JAX function for persistent serialization.

  Args:
-    fun_jit: the function to export. Should be the result of `jit`.
+    fun_jit: the function to export. Should be the result of `jax.jit`.
    lowering_platforms:
      Optional sequence containing a subset of 'tpu', 'cpu',
      'cuda', 'rocm'. If more than one platform is specified, then
      the lowered code takes an argument specifying the platform.
      If None, then use the default JAX backend.
-      The calling convention for multiple platforms is explained in the
-      `jax_export.Exported` docstring.
-    disabled_checks: the safety checks to disable. See docstring
-      of `DisabledSafetyCheck`.
+      The calling convention for multiple platforms is explained at
+      https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.
+    disabled_checks: the safety checks to disable. See the documentation
+      of `jax.export.DisabledSafetyCheck`.

-  Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
+  Returns: a function that takes args and kwargs pytrees of {class}`jax.ShapeDtypeStruct`,
    or values with `.shape` and `.dtype` attributes, and returns an
    `Exported`.

  Usage:

  >>> from jax import export
  >>> exported: export.Exported = export.export(jnp.sin)(
  ...    np.arange(4, dtype=np.float32))
-
-  # You can inspect the Exported object
+  >>>
+  >>> # You can inspect the Exported object
  >>> exported.in_avals
  (ShapedArray(float32[4]),)
  >>> blob: bytearray = exported.serialize()
-
-  # The serialized bytes are safe to use in a separate process
+  >>>
+  >>> # The serialized bytes are safe to use in a separate process
  >>> rehydrated: export.Exported = export.deserialize(blob)
  >>> rehydrated.fun_name
  'sin'
@@ -776,11 +695,10 @@ def _wrap_main_func(
) -> ir.Module:
  """Wraps the lowered module with a new "main" handling dimension arguments.

-  See calling convention documentation for `jax_export.Exported`.
+  See the calling convention documentation at https://jax.readthedocs.io/en/latest/export.html#module-calling-convention.

  Args:
-    module: the HLO module as obtained from lowering. See the calling convention
-      for inner functions in `jax_export.Exported`.
+    module: the HLO module as obtained from lowering.
    args_avals_flat: the avals for all the arguments of the lowered function,
      which correspond to the array arguments of the `module`.
    args_kwargs_tree: the PyTreeDef corresponding to `(args, kwargs)`, for error
@@ -1343,7 +1261,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
               for s in exported.in_shardings_hlo + exported.out_shardings_hlo)):
    err_msg = "the module contains non-replicated sharding annotations."
  if err_msg:
-    raise NotImplementedError(
+    raise ValueError(
        f"Exported module {exported.fun_name} was lowered for "
        f"{exported.nr_devices} devices and is called in a context with "
        f"{num_devices} devices. This is disallowed because: {err_msg}"
@@ -982,7 +982,7 @@ class SymbolicScope:

  Holds the constraints on symbolic expressions.

-  See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints)
+  See [the README](https://jax.readthedocs.io/en/latest/export/shape_poly.html#user-specified-symbolic-constraints)
  for more details.

  Args:
@@ -1194,6 +1194,8 @@ def _convertible_to_poly(p: DimSize) -> bool:
  return isinstance(p, _DimExpr) or _convertible_to_int(p)

+def is_symbolic_dim(p: DimSize) -> bool:
+  """Checks if a dimension is symbolic.
+  """
+  return isinstance(p, _DimExpr)

def is_poly_dim(p: DimSize) -> bool:
@@ -1856,7 +1858,7 @@ class ShapeConstraints:
  def shape_assertions(self, eval: CachingShapeEvaluator) -> None:
    """Computes the shape assertions for the set of constraints.

-    See jax_export._wrap_main_func docstring.
+    See jax_export.Exported docstring.
    """
    # We want to report the errors in the same order as `check_statically`.
    # So, we process them in order, in case some fail statically, and we
@@ -965,12 +965,12 @@ class JaxExportTest(jtu.JaxTestCase):

    # Test error reporting
    with self.assertRaisesRegex(
-        NotImplementedError,
+        ValueError,
        "Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
      _ = exp.call(a)

    with self.assertRaisesRegex(
-        NotImplementedError,
+        ValueError,
        "Exported module .* was lowered for 2 devices and is called in a context with 1 device"):
      mesh1 = Mesh(jax.devices()[0:1], axis_names=("x",))
      _ = jax.jit(
@@ -1046,7 +1046,7 @@ class JaxExportTest(jtu.JaxTestCase):
    b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))

    with self.assertRaisesRegex(
-        NotImplementedError,
+        ValueError,
        "Exported module .* was lowered for 1 devices and is called in a "
        f"context with {jax.local_device_count()} devices.* module contains "
        "non-replicated sharding annotations"):
@@ -1092,7 +1092,7 @@ class JaxExportTest(jtu.JaxTestCase):
    b = jax.device_put(a, jax.sharding.NamedSharding(run_mesh, P("i")))

    with self.assertRaisesRegex(
-        NotImplementedError,
+        ValueError,
        "Exported module .* was lowered for 1 devices and is called in a "
        f"context with {jax.local_device_count()} devices.* module contains "
        "non-replicated sharding annotations"):