Adding @custom_partitioning to jax.experimental API doc

parent 6a69b5a16c
commit 59c71250ee

docs/jax.experimental.custom_partitioning.rst (new file, 9 lines)
@@ -0,0 +1,9 @@
``jax.experimental.custom_partitioning`` module
===============================================

.. automodule:: jax.experimental.custom_partitioning

API
---

.. autofunction:: custom_partitioning
@@ -21,6 +21,7 @@ Experimental Modules
    jax.experimental.pjit
    jax.experimental.sparse
    jax.experimental.jet
    jax.experimental.custom_partitioning

Experimental APIs
-----------------
@@ -130,39 +130,190 @@ def _default_propagate_user_shardings(sharding, shape):
class custom_partitioning:
  """Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules.
  Usage
  -----

  .. code-block:: python

    @custom_partitioning
    def f(*args):
      return ...

    def propagate_user_sharding(sharding, shape):
      '''Update the sharding of the op from a user's sharding.'''

    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
      def lower_fn(*args):
        ... builds computation on per-device shapes ...
      # result_sharding and arg_shardings may optionally be modified and the
      # partitioner will insert collectives to reshape.
      return lower_fn, result_sharding, arg_shardings

    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
      '''Compute the result sharding from the sharding of the operands.'''

    f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
  The args to ``def_partition`` are as follows:

  * ``propagate_user_sharding``: Callable which takes the sharding of a user (in the DAG)
    and returns a suggestion for a new ``NamedSharding``. The default
    implementation simply returns the suggested sharding.
  * ``partition``: Callable which takes the SPMD suggested partition shapes and
    partition specs and returns a per-shard lowering function and the final
    input and output sharding specs (the SPMD partitioner will repartition the
    inputs to match).
  * ``infer_sharding_from_operands``: Callable which computes an output ``NamedSharding``
    from the ``NamedSharding`` chosen for each argument.
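
  As a minimal sketch of these callbacks (hypothetical pass-through implementations, not a
  definitive recipe), ``propagate_user_sharding`` can accept the suggested sharding unchanged,
  mirroring the default described above, and ``infer_sharding_from_operands`` can simply reuse
  the sharding of the first operand when the result has the same rank:

  .. code-block:: python

    def propagate_user_sharding(sharding, shape):
      # Accept the consumer's suggested sharding as-is (the documented default behavior).
      return sharding

    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
      # Assumes the result shares the rank/layout of the first operand,
      # so its sharding can be reused directly.
      return arg_shardings[0]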
  Example
  -------

  Assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes the
  discrete Fourier transform of an N-dimensional input along its last dimension, and is
  batched along the first N-1 dimensions. By default, however, it ignores the sharding of
  the input and gathers the input onto all devices. Because ``jax.numpy.fft.fft`` is batched
  along the first N-1 dimensions, this is unnecessary. We will instead create a new
  ``my_fft`` op that leaves the sharding along the first N-1 dimensions untouched and only
  gathers the input along the last dimension when needed.
  .. code-block:: python

    import jax
    from jax._src.sharding import NamedSharding
    from jax.experimental.custom_partitioning import custom_partitioning
    from jax.experimental.pjit import pjit
    from jax.sharding import PartitionSpec as P
    from jax.experimental.maps import Mesh
    from jax.numpy.fft import fft
    import numpy as np

    # For an N-D input, keeps the sharding along the first N-1 dimensions
    # but replicates along the last dimension.
    def supported_sharding(sharding, shape):
      rank = len(shape.shape)
      max_shared_dims = min(len(sharding.spec), rank - 1)
      names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
      return NamedSharding(sharding.mesh, P(*names))

    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
      return fft, \
             supported_sharding(arg_shardings[0], arg_shapes[0]), \
             [supported_sharding(arg_shardings[0], arg_shapes[0])]

    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
      return supported_sharding(arg_shardings[0], arg_shapes[0])

    @custom_partitioning
    def my_fft(x):
      return fft(x)

    my_fft.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition)
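
  As a side check (hypothetical and not part of the original example), the effect of
  ``supported_sharding`` can be verified directly. The sketch below uses
  ``jax.core.ShapedArray`` as a stand-in for the aval-like objects the partitioner passes in;
  the expected specs follow from the function body above.

  .. code-block:: python

    from jax.core import ShapedArray

    mesh = Mesh(np.array(jax.devices()), ('x',))
    sharded_first_axis = NamedSharding(mesh, P('x'))

    # 2-D input: the sharding of the leading (batch) axis is kept and the
    # FFT axis is replicated -> PartitionSpec('x', None).
    print(supported_sharding(sharded_first_axis, ShapedArray((32, 1024), np.complex64)).spec)

    # 1-D input: the only axis is the FFT axis, so it is fully replicated -> PartitionSpec(None).
    print(supported_sharding(sharded_first_axis, ShapedArray((1024,), np.complex64)).spec)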
  Now create a 2D array sharded along the first axis, pass it through ``my_fft``, and note
  that it is still sharded as expected and identical to the output of ``fft``, even though
  ``fft`` internally gathers (replicates) its input onto all devices.
  >>> with Mesh(np.array(jax.devices()), ('x',)):
  ...   x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
  ...   y = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P('x'))(x)
  ...   pjit_my_fft = pjit(my_fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
  ...   pjit_fft = pjit(fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
  ...   print(pjit_my_fft(y))
  ...   print(pjit_fft(y))
  .. code-block::

    # my_fft
    [[-38.840824 +0.j -40.649452 +11.845365j
     ...
     -1.6937828 +0.8402481j 15.999859 -4.0156755j]]

    # jax.numpy.fft.fft
    [[-38.840824 +0.j -40.649452 +11.845365j
     ...
     -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
  If we dump the HLO using ``XLA_FLAGS="--xla_dump_to=$(pwd)"``, we see that ``pjit_fft``
  compiles to:

  .. code-block::
    HloModule pjit_fft, entry_computation_layout={(c64[16384,1024]{1,0})->c64[16384,1024]{1,0}}

    fused_computation {
      ...
      ROOT dynamic-slice.2 = c64[16384,1024]{1,0} dynamic-slice(param_1, multiply.3, constant_7), dynamic_slice_sizes={16384,1024}, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
    }

    ENTRY main.8_spmd {
      param = c64[16384,1024]{1,0} parameter(0), sharding={devices=[2,1]0,1}
      all-gather = c64[32768,1024]{1,0} all-gather(param), channel_id=1, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
      fft.1 = c64[32768,1024]{1,0} fft(all-gather), fft_type=FFT, fft_length={1024}, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
      partition-id = u32[] partition-id(), metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
      ROOT fusion = c64[16384,1024]{1,0} fusion(fft.1, partition-id), kind=kLoop, calls=fused_computation, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
    }
  The ``all-gather`` before the FFT and the ``dynamic-slice`` after it are both clearly
  visible: the input is gathered onto all devices prior to the FFT and sliced back afterwards.

  ``pjit_my_fft``, on the other hand, simply compiles to:

  .. code-block::
    HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(c64[16384,1024]{1,0})->c64[16384,1024]{1,0}}

    ENTRY main.5_spmd {
      param = c64[16384,1024]{1,0} parameter(0), sharding={devices=[2,1]0,1}
      ROOT fft.0 = c64[16384,1024]{1,0} fft(param), fft_type=FFT, fft_length={1024}, metadata={op_name="jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=41}
    }

  Here, no unnecessary gathering or resharding takes place.
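
  The same optimized, post-partitioning HLO can also be inspected programmatically rather
  than by dumping files. This is a sketch only, assuming a JAX version whose
  ahead-of-time API exposes ``.lower(...).compile().as_text()`` on ``pjit``-wrapped
  functions; it reuses ``y`` from the example above and must run inside the same ``Mesh``
  context:

  .. code-block:: python

    with Mesh(np.array(jax.devices()), ('x',)):
      pjit_my_fft = pjit(my_fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
      # Lower and compile for the sharded input, then print the optimized HLO module.
      print(pjit_my_fft.lower(y).compile().as_text())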
  Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional
  arrays:
  >>> with Mesh(np.array(jax.devices()), ('x',)):
  ...   x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
  ...   y = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P('x'))(x)
  ...   pjit_my_fft = pjit(my_fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
  ...   pjit_fft = pjit(fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
  ...   print(pjit_my_fft(y))
  ...   print(pjit_fft(y))
  .. code-block::

    # my_fft
    [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j
      ... 1422.4502 +7271.4297j -405.84033 -3042.983j
      -3012.4963 -4287.6343j]

    # jax.numpy.fft.fft
    [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j
      ... 1422.4502 +7271.4297j -405.84033 -3042.983j
      -3012.4963 -4287.6343j]
  In this case, the HLO of ``my_fft`` does show an all-gather and a dynamic-slice, since the
  last dimension is the dimension along which FFTs are calculated:

  .. code-block::
    HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(c64[16777216]{0})->c64[16777216]{0}}

    fused_computation {
      ...
      ROOT dynamic-slice.2 = c64[16777216]{0} dynamic-slice(param_1, multiply.3), dynamic_slice_sizes={16777216}, metadata={op_name="pjit(<unnamed wrapped function>)/jit(main)/custom_partitioning[partition=<function partition at 0x7f73a1fbc820> propagate_user_sharding=<function _default_propagate_user_shardings at 0x7f73a1fbc550> infer_sharding_from_operands=<function infer_sharding_from_operands at 0x7f73a1fbc8b0> in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
    }

    ENTRY main.5_spmd {
      param = c64[16777216]{0} parameter(0), sharding={devices=[2]0,1}
      all-gather = c64[33554432]{0} all-gather(param), channel_id=1, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(<unnamed wrapped function>)/jit(main)/custom_partitioning[partition=<function partition at 0x7f73a1fbc820> propagate_user_sharding=<function _default_propagate_user_shardings at 0x7f73a1fbc550> infer_sharding_from_operands=<function infer_sharding_from_operands at 0x7f73a1fbc8b0> in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
      fft.0 = c64[33554432]{0} fft(all-gather), fft_type=FFT, fft_length={33554432}, metadata={op_name="jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(33554432,)]" source_file="doc.py" source_line=51}
      partition-id = u32[] partition-id(), metadata={op_name="pjit(<unnamed wrapped function>)/jit(main)/custom_partitioning[partition=<function partition at 0x7f73a1fbc820> propagate_user_sharding=<function _default_propagate_user_shardings at 0x7f73a1fbc550> infer_sharding_from_operands=<function infer_sharding_from_operands at 0x7f73a1fbc8b0> in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
      ROOT fusion = c64[16777216]{0} fusion(fft.0, partition-id), kind=kLoop, calls=fused_computation, metadata={op_name="pjit(<unnamed wrapped function>)/jit(main)/custom_partitioning[partition=<function partition at 0x7f73a1fbc820> propagate_user_sharding=<function _default_propagate_user_shardings at 0x7f73a1fbc550> infer_sharding_from_operands=<function infer_sharding_from_operands at 0x7f73a1fbc8b0> in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
    }
"""
|
||||
|
||||
def __init__(self, fun):
|
||||
|