mirror of
https://github.com/ROCm/jax.git
synced 2025-04-20 13:56:07 +00:00
[export] Improved the documentation.
In particular added the docstring for `Exported.call` method, and fixed the formatting for `Exported.in_shardings_jax`.
This commit is contained in:
parent
ad00ee1e06
commit
cc73c50c41
@ -689,22 +689,21 @@ minimization phase.
|
||||
### Doctests
|
||||
|
||||
JAX uses pytest in doctest mode to test the code examples within the documentation.
|
||||
You can run this using
|
||||
You can find the up-to-date command to run doctests in
|
||||
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml).
|
||||
E.g., you can run:
|
||||
|
||||
```
|
||||
pytest docs
|
||||
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
|
||||
```
|
||||
|
||||
Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
|
||||
function docstrings will run correctly. You can run this locally using, for example:
|
||||
|
||||
```
|
||||
pytest --doctest-modules jax/_src/numpy/lax_numpy.py
|
||||
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest --doctest-modules jax/_src/numpy/lax_numpy.py
|
||||
```
|
||||
|
||||
Keep in mind that there are several files that are marked to be skipped when the
|
||||
doctest command is run on the full package; you can see the details in
|
||||
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml).
|
||||
|
||||
## Type checking
|
||||
|
||||
|
@ -14,8 +14,11 @@ Classes
|
||||
.. autosummary::
|
||||
:toctree: _autosummary
|
||||
|
||||
Exported
|
||||
DisabledSafetyCheck
|
||||
.. autoclass:: Exported
|
||||
:members:
|
||||
|
||||
.. autoclass:: DisabledSafetyCheck
|
||||
:members:
|
||||
|
||||
Functions
|
||||
---------
|
||||
|
@ -203,6 +203,7 @@ class Exported:
|
||||
_get_vjp: Callable[[Exported], Exported] | None
|
||||
|
||||
def mlir_module(self) -> str:
|
||||
"""A string representation of the `mlir_module_serialized`."""
|
||||
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
|
||||
|
||||
def __str__(self):
|
||||
@ -211,8 +212,8 @@ class Exported:
|
||||
return f"Exported(fun_name={self.fun_name}, ...)"
|
||||
|
||||
def in_shardings_jax(
|
||||
self,
|
||||
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
|
||||
self,
|
||||
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
|
||||
"""Creates Shardings corresponding to self.in_shardings_hlo.
|
||||
|
||||
The Exported object stores `in_shardings_hlo` as HloShardings, which are
|
||||
@ -221,30 +222,31 @@ class Exported:
|
||||
`jax.device_put`.
|
||||
|
||||
Example usage:
|
||||
>>> from jax import export
|
||||
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
|
||||
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
|
||||
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
|
||||
... )(np.arange(jax.device_count()))
|
||||
>>> exp.in_shardings_hlo
|
||||
({devices=[8]<=[8]},)
|
||||
|
||||
# Create a mesh for running the exported object
|
||||
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
|
||||
>>>
|
||||
# Put the args and kwargs on the appropriate devices
|
||||
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
|
||||
... exp.in_shardings_jax(run_mesh)[0])
|
||||
>>> res = exp.call(run_arg)
|
||||
>>> res.addressable_shards
|
||||
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
|
||||
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
|
||||
Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
|
||||
Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
|
||||
Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
|
||||
Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
|
||||
Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
|
||||
Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
|
||||
>>> from jax import export
|
||||
>>> # Prepare the exported object:
|
||||
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
|
||||
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
|
||||
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
|
||||
... )(np.arange(jax.device_count()))
|
||||
>>> exp.in_shardings_hlo
|
||||
({devices=[8]<=[8]},)
|
||||
>>> # Create a mesh for running the exported object
|
||||
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
|
||||
>>> # Put the args and kwargs on the appropriate devices
|
||||
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
|
||||
... exp.in_shardings_jax(run_mesh)[0])
|
||||
>>> res = exp.call(run_arg)
|
||||
>>> res.addressable_shards
|
||||
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
|
||||
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
|
||||
Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
|
||||
Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
|
||||
Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
|
||||
Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
|
||||
Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
|
||||
Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
|
||||
|
||||
"""
|
||||
return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh)
|
||||
for s in self.in_shardings_hlo)
|
||||
@ -252,7 +254,7 @@ class Exported:
|
||||
def out_shardings_jax(
|
||||
self,
|
||||
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
|
||||
"""Creates Shardings corresponding to self.out_shardings_hlo.
|
||||
"""Creates Shardings corresponding to `self.out_shardings_hlo`.
|
||||
|
||||
See documentation for in_shardings_jax.
|
||||
"""
|
||||
@ -289,6 +291,21 @@ class Exported:
|
||||
return serialize(self, vjp_order=vjp_order)
|
||||
|
||||
def call(self, *args, **kwargs):
|
||||
"""Call an exported function from a JAX program.
|
||||
|
||||
Args:
|
||||
args: the positional arguments to pass to the exported function. This
|
||||
should be a pytree of arrays with the same pytree structure as the
|
||||
arguments for which the function was exported.
|
||||
kwargs: the keyword arguments to pass to the exported function.
|
||||
|
||||
Returns: a pytree of result arrays, with the same structure as the
|
||||
results of the exported function.
|
||||
|
||||
The invocation supports reverse-mode AD, and all the features supported
|
||||
by exporting: shape polymorphism, multi-platform, device polymorphism.
|
||||
See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html).
|
||||
"""
|
||||
return call_exported(self)(*args, **kwargs)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user