1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-20 13:56:07 +00:00

[export] Improved the documentation.

In particular added the docstring for `Exported.call` method,
and fixed the formatting for `Exported.in_shardings_jax`.
This commit is contained in:
George Necula 2024-12-08 17:39:24 +01:00
parent ad00ee1e06
commit cc73c50c41
3 changed files with 53 additions and 34 deletions

@ -689,22 +689,21 @@ minimization phase.
### Doctests
JAX uses pytest in doctest mode to test the code examples within the documentation.
You can run this using
You can find the up-to-date command to run doctests in
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml).
E.g., you can run:
```
pytest docs
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
```
Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
function docstrings will run correctly. You can run this locally using, for example:
```
pytest --doctest-modules jax/_src/numpy/lax_numpy.py
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest --doctest-modules jax/_src/numpy/lax_numpy.py
```
Keep in mind that there are several files that are marked to be skipped when the
doctest command is run on the full package; you can see the details in
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml)
## Type checking

@ -14,8 +14,11 @@ Classes
.. autosummary::
:toctree: _autosummary
Exported
DisabledSafetyCheck
.. autoclass:: Exported
:members:
.. autoclass:: DisabledSafetyCheck
:members:
Functions
---------

@ -203,6 +203,7 @@ class Exported:
_get_vjp: Callable[[Exported], Exported] | None
def mlir_module(self) -> str:
"""A string representation of the `mlir_module_serialized`."""
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
def __str__(self):
@ -211,8 +212,8 @@ class Exported:
return f"Exported(fun_name={self.fun_name}, ...)"
def in_shardings_jax(
self,
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
self,
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
"""Creates Shardings corresponding to self.in_shardings_hlo.
The Exported object stores `in_shardings_hlo` as HloShardings, which are
@ -221,30 +222,31 @@ class Exported:
`jax.device_put`.
Example usage:
>>> from jax import export
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
... )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
# Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
>>>
# Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
... exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
>>> from jax import export
>>> # Prepare the exported object:
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
... )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
>>> # Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
>>> # Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
... exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
"""
return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh)
for s in self.in_shardings_hlo)
@ -252,7 +254,7 @@ class Exported:
def out_shardings_jax(
self,
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
"""Creates Shardings corresponding to self.out_shardings_hlo.
"""Creates Shardings corresponding to `self.out_shardings_hlo`.
See documentation for in_shardings_jax.
"""
@ -289,6 +291,21 @@ class Exported:
return serialize(self, vjp_order=vjp_order)
def call(self, *args, **kwargs):
"""Call an exported function from a JAX program.
Args:
args: the positional arguments to pass to the exported function. This
should be a pytree of arrays with the same pytree structure as the
arguments for which the function was exported.
kwargs: the keyword arguments to pass to the exported function.
Returns: a pytree of result array, with the same structure as the
results of the exported function.
The invocation supports reverse-mode AD, and all the features supported
by exporting: shape polymorphism, multi-platform, device polymorphism.
See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html).
"""
return call_exported(self)(*args, **kwargs)