Adding @custom_partitioning to jax.experimental API doc

This commit is contained in:
Leopold Cambier 2023-01-19 15:53:01 -08:00
parent 6a69b5a16c
commit 59c71250ee
3 changed files with 188 additions and 27 deletions

@@ -0,0 +1,9 @@
``jax.experimental.custom_partitioning`` module
===============================================

.. automodule:: jax.experimental.custom_partitioning

API
---

.. autofunction:: custom_partitioning

@@ -21,6 +21,7 @@ Experimental Modules
jax.experimental.pjit
jax.experimental.sparse
jax.experimental.jet
jax.experimental.custom_partitioning
Experimental APIs
-----------------

@@ -130,39 +130,190 @@ def _default_propagate_user_shardings(sharding, shape):
class custom_partitioning:
"""Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules.
Usage
-----

.. code-block:: python

    @custom_partitioning
    def f(*args):
        return ...

    def propagate_user_sharding(sharding, shape):
        '''Update the sharding of the op from a user's sharding.'''

    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
        def lower_fn(*args):
            ... builds computation on per-device shapes ...
        # result_sharding and arg_shardings may optionally be modified and the
        # partitioner will insert collectives to reshape.
        return lower_fn, result_sharding, arg_shardings

    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
        '''Compute the result sharding from the sharding of the operands.'''

    f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)

The args to ``def_partition`` are as follows:

* ``propagate_user_sharding``: Callable which takes the sharding of a user (in the DAG)
  and returns a suggestion for a new ``NamedSharding``. The default implementation is
  just to return the suggested sharding (see the sketch after this list).
* ``partition``: Callable which takes the SPMD suggested partition shapes and
  partition specs and returns a per-shard lowering function and the final
  input and output sharding specs (the SPMD partitioner will repartition the
  inputs to match).
* ``infer_sharding_from_operands``: Callable which computes an output ``NamedSharding``
  from the ``NamedSharding`` chosen for each argument.
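For illustration, a minimal ``propagate_user_sharding`` matching the documented default behavior
(simply accept whatever sharding the op's user suggests) could look as follows; the signature
follows the Usage sketch above:

.. code-block:: python

    def propagate_user_sharding(sharding, shape):
        # Return the user's suggested sharding unchanged; this mirrors what
        # the default implementation does.
        return sharding

Note that the example below passes only ``partition`` and ``infer_sharding_from_operands`` to
``def_partition`` and relies on this default for ``propagate_user_sharding``.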
Example
-------

Assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes the discrete
Fourier transform of an N-dimensional input along the last dimension, and is batched along the
first N-1 dimensions. By default, however, it ignores the sharding of the input and gathers the
input on all devices. Since ``jax.numpy.fft.fft`` is batched along the first N-1 dimensions, this
gathering is unnecessary. We will therefore create a new ``my_fft`` op that leaves the sharding
along the first N-1 dimensions untouched, and gathers the input along the last dimension only if
needed.
.. code-block:: python

    import jax
    from jax._src.sharding import NamedSharding
    from jax.experimental.custom_partitioning import custom_partitioning
    from jax.experimental.pjit import pjit
    from jax.sharding import PartitionSpec as P
    from jax.experimental.maps import Mesh
    from jax.numpy.fft import fft
    import numpy as np

    # For an N-D input, keep the sharding along the first N-1 dimensions
    # but replicate along the last dimension.
    def supported_sharding(sharding, shape):
        rank = len(shape.shape)
        max_shared_dims = min(len(sharding.spec), rank-1)
        names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
        return NamedSharding(sharding.mesh, P(*names))

    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
        return fft, \
               supported_sharding(arg_shardings[0], arg_shapes[0]), \
               [supported_sharding(arg_shardings[0], arg_shapes[0])]

    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
        return supported_sharding(arg_shardings[0], arg_shapes[0])

    @custom_partitioning
    def my_fft(x):
        return fft(x)

    my_fft.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition)
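As an aside, here is a small check of what ``supported_sharding`` returns (a sketch reusing the
imports and the ``supported_sharding`` helper defined above; the ``jax.ShapeDtypeStruct`` values
are only stand-ins for the shape information the partitioner would pass in):

.. code-block:: python

    # Assume a one-dimensional mesh named 'x' spanning all available devices.
    mesh = Mesh(np.array(jax.devices()), ('x',))

    sharded_2d = NamedSharding(mesh, P('x', None))  # 2-D input sharded along the first axis
    sharded_1d = NamedSharding(mesh, P('x'))        # 1-D input sharded along its only axis

    # Rank 2: one leading dimension may stay sharded -> PartitionSpec('x', None),
    # so no data movement is needed before the FFT.
    print(supported_sharding(sharded_2d, jax.ShapeDtypeStruct((32, 32), np.complex64)).spec)

    # Rank 1: no dimension may stay sharded -> PartitionSpec(None,),
    # so the input has to be gathered before the FFT.
    print(supported_sharding(sharded_1d, jax.ShapeDtypeStruct((32,), np.complex64)).spec)

This is why, further below, the 1-dimensional case still shows an ``all-gather`` in the HLO.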
Now create a 2D array sharded along the first axis, pass it through ``my_fft``, and notice how it
is still sharded as expected and identical to the output of ``fft``. Internally, however, ``fft``
gathers the input and computes its result replicated on every device, as the HLO dump below shows.
>>> with Mesh(np.array(jax.devices()), ('x',)):
... x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
... y = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P('x'))(x)
... pjit_my_fft = pjit(my_fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
... pjit_fft = pjit(fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
... print(pjit_my_fft(y))
... print(pjit_fft(y))
.. code-block::

    # my_fft
    [[-38.840824 +0.j -40.649452 +11.845365j
      ...
      -1.6937828 +0.8402481j 15.999859 -4.0156755j]]

    # jax.numpy.fft.fft
    [[-38.840824 +0.j -40.649452 +11.845365j
      ...
      -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
If we dump the HLO using ``XLA_FLAGS="--xla_dump_to=$(pwd)"``, we see that ``pjit_fft`` compiles to:
.. code-block::

    HloModule pjit_fft, entry_computation_layout={(c64[16384,1024]{1,0})->c64[16384,1024]{1,0}}

    fused_computation {
      ...
      ROOT dynamic-slice.2 = c64[16384,1024]{1,0} dynamic-slice(param_1, multiply.3, constant_7), dynamic_slice_sizes={16384,1024}, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
    }

    ENTRY main.8_spmd {
      param = c64[16384,1024]{1,0} parameter(0), sharding={devices=[2,1]0,1}
      all-gather = c64[32768,1024]{1,0} all-gather(param), channel_id=1, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
      fft.1 = c64[32768,1024]{1,0} fft(all-gather), fft_type=FFT, fft_length={1024}, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
      partition-id = u32[] partition-id(), metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
      ROOT fusion = c64[16384,1024]{1,0} fusion(fft.1, partition-id), kind=kLoop, calls=fused_computation, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
    }
The ``all-gather`` before the FFT and the dynamic-slice after it are both clearly visible: the
input is gathered on all devices prior to the FFT, and sliced back to a shard afterwards.
``pjit_my_fft``, on the other hand, simply compiles to:
.. code-block::

    HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(c64[16384,1024]{1,0})->c64[16384,1024]{1,0}}

    ENTRY main.5_spmd {
      param = c64[16384,1024]{1,0} parameter(0), sharding={devices=[2,1]0,1}
      ROOT fft.0 = c64[16384,1024]{1,0} fft(param), fft_type=FFT, fft_length={1024}, metadata={op_name="jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=41}
    }

where the FFT runs directly on each shard and no unnecessary data movement takes place.
Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional arrays.
>>> with Mesh(np.array(jax.devices()), ('x',)):
... x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
... y = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P('x'))(x)
... pjit_my_fft = pjit(my_fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
... pjit_fft = pjit(fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
... print(pjit_my_fft(y))
... print(pjit_fft(y))
.. code-block::

    # my_fft
    [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j
      ... 1422.4502 +7271.4297j -405.84033 -3042.983j
      -3012.4963 -4287.6343j]

    # jax.numpy.fft.fft
    [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j
      ... 1422.4502 +7271.4297j -405.84033 -3042.983j
      -3012.4963 -4287.6343j]
In this case, the HLO of ``my_fft`` does show an all-gather and a dynamic-slice, since the last
dimension is the dimension along which FFTs are calculated:
.. code-block::

    HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(c64[16777216]{0})->c64[16777216]{0}}

    fused_computation {
      ...
      ROOT dynamic-slice.2 = c64[16777216]{0} dynamic-slice(param_1, multiply.3), dynamic_slice_sizes={16777216}, metadata={op_name="pjit(<unnamed wrapped function>)/jit(main)/custom_partitioning[partition=<function partition at 0x7f73a1fbc820> propagate_user_sharding=<function _default_propagate_user_shardings at 0x7f73a1fbc550> infer_sharding_from_operands=<function infer_sharding_from_operands at 0x7f73a1fbc8b0> in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
    }

    ENTRY main.5_spmd {
      param = c64[16777216]{0} parameter(0), sharding={devices=[2]0,1}
      all-gather = c64[33554432]{0} all-gather(param), channel_id=1, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(<unnamed wrapped function>)/jit(main)/custom_partitioning[partition=<function partition at 0x7f73a1fbc820> propagate_user_sharding=<function _default_propagate_user_shardings at 0x7f73a1fbc550> infer_sharding_from_operands=<function infer_sharding_from_operands at 0x7f73a1fbc8b0> in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
      fft.0 = c64[33554432]{0} fft(all-gather), fft_type=FFT, fft_length={33554432}, metadata={op_name="jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(33554432,)]" source_file="doc.py" source_line=51}
      partition-id = u32[] partition-id(), metadata={op_name="pjit(<unnamed wrapped function>)/jit(main)/custom_partitioning[partition=<function partition at 0x7f73a1fbc820> propagate_user_sharding=<function _default_propagate_user_shardings at 0x7f73a1fbc550> infer_sharding_from_operands=<function infer_sharding_from_operands at 0x7f73a1fbc8b0> in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
      ROOT fusion = c64[16777216]{0} fusion(fft.0, partition-id), kind=kLoop, calls=fused_computation, metadata={op_name="pjit(<unnamed wrapped function>)/jit(main)/custom_partitioning[partition=<function partition at 0x7f73a1fbc820> propagate_user_sharding=<function _default_propagate_user_shardings at 0x7f73a1fbc550> infer_sharding_from_operands=<function infer_sharding_from_operands at 0x7f73a1fbc8b0> in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
    }
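As a quicker sanity check than dumping HLO, the sharding of the results can also be inspected
directly. A minimal sketch, assuming the ``pjit_my_fft``, ``pjit_fft`` and ``y`` values from the
examples above, and a recent JAX version in which ``pjit`` outputs are ``jax.Array`` values with a
``.sharding`` attribute:

.. code-block:: python

    with Mesh(np.array(jax.devices()), ('x',)):
        out_my_fft = pjit_my_fft(y)
        out_fft = pjit_fft(y)

        # Both outputs report the sharding requested by out_axis_resources=P('x');
        # the difference between the two lies only in how much data movement the
        # compiled computation needed to get there.
        print(out_my_fft.sharding)
        print(out_fft.sharding)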
"""
def __init__(self, fun):