diff --git a/docs/jax.experimental.custom_partitioning.rst b/docs/jax.experimental.custom_partitioning.rst
new file mode 100644
index 000000000..41948e12e
--- /dev/null
+++ b/docs/jax.experimental.custom_partitioning.rst
@@ -0,0 +1,9 @@
+``jax.experimental.custom_partitioning`` module
+===============================================
+
+.. automodule:: jax.experimental.custom_partitioning
+
+API
+---
+
+.. autofunction:: custom_partitioning
diff --git a/docs/jax.experimental.rst b/docs/jax.experimental.rst
index da51168db..2cbb364f6 100644
--- a/docs/jax.experimental.rst
+++ b/docs/jax.experimental.rst
@@ -21,6 +21,7 @@ Experimental Modules
     jax.experimental.pjit
     jax.experimental.sparse
     jax.experimental.jet
+    jax.experimental.custom_partitioning
 
 Experimental APIs
 -----------------
diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py
index 6e5ff35c3..479fa3ea9 100644
--- a/jax/experimental/custom_partitioning.py
+++ b/jax/experimental/custom_partitioning.py
@@ -130,39 +130,190 @@ def _default_propagate_user_shardings(sharding, shape):
 class custom_partitioning:
   """Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules.
 
-  Usage:
-  ```
-  @custom_partitioning
-  def f(*args):
-    return ...
+  Usage
+  -----
-  def propagate_user_sharding(sharding, shape):
-    '''Update the sharding of the op from a user's sharding.'''
+  .. code-block:: python
-  def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
-    def lower_fn(*args):
-      ... builds computation on per-device shapes ...
-      # result_sharding and arg_shardings may optionally be modified and the
-      # partitioner will insert collectives to reshape.
-    return lower_fn, result_sharding, arg_shardings
+    @custom_partitioning
+    def f(*args):
+      return ...
-  def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
-    '''Compute the result sharding from the sharding of the operands.'''
+    def propagate_user_sharding(sharding, shape):
+      '''Update the sharding of the op from a user's sharding.'''
-  f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
-  ```
+    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
+      def lower_fn(*args):
+        ... builds computation on per-device shapes ...
+        # result_sharding and arg_shardings may optionally be modified and the
+        # partitioner will insert collectives to reshape.
+      return lower_fn, result_sharding, arg_shardings
-  The args to def_partition are as follows:
+    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
+      '''Compute the result sharding from the sharding of the operands.'''
+
+    f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
+
+  The args to ``def_partition`` are as follows:
+
+  * ``propagate_user_sharding``: Callable which takes the sharding of a user (in the DAG)
+    and returns a suggestion for a new ``NamedSharding``. The default
+    implementation simply returns the suggested sharding (see the sketch after
+    this list).
+  * ``partition``: Callable which takes the SPMD suggested partition shapes and
+    partition specs and returns a per-shard lowering function and the final
+    input and output sharding specs (the SPMD partitioner will repartition the
+    inputs to match).
+  * ``infer_sharding_from_operands``: Callable which computes an output ``NamedSharding``
+    from the ``NamedSharding`` chosen for each argument.
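+
+  For reference, the default behavior of ``propagate_user_sharding`` described above
+  amounts to the minimal sketch below (the callback name is arbitrary; only the
+  signature matters):
+
+  .. code-block:: python
+
+    # Minimal sketch of the default: accept whatever sharding a user of the op
+    # suggests, without modification.
+    def propagate_user_sharding(sharding, shape):
+      return sharding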
+
+  Example
+  -------
+
+  Assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes
+  the discrete Fourier transform of an N-dimensional input along the last dimension,
+  and is batched along the first N-1 dimensions. By default, however, it will ignore
+  the sharding of the input and gather the input on all devices. Since
+  ``jax.numpy.fft.fft`` is batched along the first N-1 dimensions, this gathering is
+  unnecessary. We will create a new ``my_fft`` op that, instead, does not alter the
+  sharding along the first N-1 dimensions, and only gathers the input along the last
+  dimension if needed.
+
+  .. code-block:: python
+
+    import jax
+    from jax._src.sharding import NamedSharding
+    from jax.experimental.custom_partitioning import custom_partitioning
+    from jax.experimental.pjit import pjit
+    from jax.sharding import PartitionSpec as P
+    from jax.experimental.maps import Mesh
+    from jax.numpy.fft import fft
+    import numpy as np
+
+    # For an N-D input, keep the sharding along the first N-1 dimensions
+    # but replicate along the last dimension.
+    def supported_sharding(sharding, shape):
+      rank = len(shape.shape)
+      max_shared_dims = min(len(sharding.spec), rank-1)
+      names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
+      return NamedSharding(sharding.mesh, P(*names))
+
+    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
+      return fft, \
+             supported_sharding(arg_shardings[0], arg_shapes[0]), \
+             [supported_sharding(arg_shardings[0], arg_shapes[0])]
+
+    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
+      return supported_sharding(arg_shardings[0], arg_shapes[0])
+
+    @custom_partitioning
+    def my_fft(x):
+      return fft(x)
+
+    my_fft.def_partition(
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition)
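+
+  The example above registers only ``partition`` and ``infer_sharding_from_operands``.
+  A ``propagate_user_sharding`` callback could be supplied as well; the sketch below is
+  hypothetical (it is not required for this example) and assumes that ``shape`` exposes
+  a ``.shape`` attribute like the entries of ``arg_shapes``, and that ``def_partition``
+  accepts the argument by keyword, mirroring the other two:
+
+  .. code-block:: python
+
+    # Hypothetical third callback: constrain any sharding suggested by a user
+    # of my_fft to the supported pattern, i.e. replicated on the last dimension.
+    def propagate_user_sharding(sharding, shape):
+      return supported_sharding(sharding, shape)
+
+    my_fft.def_partition(
+        infer_sharding_from_operands=infer_sharding_from_operands,
+        partition=partition,
+        propagate_user_sharding=propagate_user_sharding)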
+
+  Now create a 2D array sharded along the first axis, pass it through ``my_fft``,
+  and notice that its output is still sharded as expected and numerically identical
+  to the output of ``fft``. Internally, however, ``fft`` gathers its input onto
+  every device, as the HLO dumps below show.
+
+  >>> with Mesh(np.array(jax.devices()), ('x',)):
+  ...   x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
+  ...   y = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P('x'))(x)
+  ...   pjit_my_fft = pjit(my_fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
+  ...   pjit_fft = pjit(fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
+  ...   print(pjit_my_fft(y))
+  ...   print(pjit_fft(y))
+
+  .. code-block::
+
+    # my_fft
+    [[-38.840824 +0.j -40.649452 +11.845365j
+     ...
+      -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
+
+    # jax.numpy.fft.fft
+    [[-38.840824 +0.j -40.649452 +11.845365j
+     ...
+      -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
+
+  If we dump the HLO using ``XLA_FLAGS="--xla_dump_to=$(pwd)"``, we see that
+  ``pjit_fft`` compiles to
+
+  .. code-block::
+
+    HloModule pjit_fft, entry_computation_layout={(c64[16384,1024]{1,0})->c64[16384,1024]{1,0}}
+
+    fused_computation {
+      ...
+      ROOT dynamic-slice.2 = c64[16384,1024]{1,0} dynamic-slice(param_1, multiply.3, constant_7), dynamic_slice_sizes={16384,1024}, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
+    }
+
+    ENTRY main.8_spmd {
+      param = c64[16384,1024]{1,0} parameter(0), sharding={devices=[2,1]0,1}
+      all-gather = c64[32768,1024]{1,0} all-gather(param), channel_id=1, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
+      fft.1 = c64[32768,1024]{1,0} fft(all-gather), fft_type=FFT, fft_length={1024}, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
+      partition-id = u32[] partition-id(), metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
+      ROOT fusion = c64[16384,1024]{1,0} fusion(fft.1, partition-id), kind=kLoop, calls=fused_computation, metadata={op_name="pjit(fft)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=42}
+    }
+
+  Here, the ``all-gather`` before the FFT and the ``dynamic-slice`` after it are both
+  clearly visible: the input is gathered on all devices prior to the FFT and sliced
+  back afterwards.
+
+  ``pjit_my_fft``, on the other hand, simply compiles to
+
+  .. code-block::
+
+    HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(c64[16384,1024]{1,0})->c64[16384,1024]{1,0}}
+
+    ENTRY main.5_spmd {
+      param = c64[16384,1024]{1,0} parameter(0), sharding={devices=[2,1]0,1}
+      ROOT fft.0 = c64[16384,1024]{1,0} fft(param), fft_type=FFT, fft_length={1024}, metadata={op_name="jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="doc.py" source_line=41}
+    }
+
+  where no unnecessary gathering or slicing takes place.
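+
+  As an alternative to dumping files via ``XLA_FLAGS``, the compiled (post-partitioning)
+  HLO can also be inspected programmatically. A rough, version-dependent sketch using the
+  ahead-of-time lowering APIs (assumes the same mesh, ``y`` and sharding specs as above):
+
+  .. code-block:: python
+
+    # Sketch: print the optimized, SPMD-partitioned HLO without dumping files.
+    with Mesh(np.array(jax.devices()), ('x',)):
+      lowered = pjit(my_fft, in_axis_resources=P('x'), out_axis_resources=P('x')).lower(y)
+      print(lowered.compile().as_text())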
+
+  Because of the logic in ``supported_sharding``, ``my_fft`` also works on
+  1-dimensional arrays.
+
+  >>> with Mesh(np.array(jax.devices()), ('x',)):
+  ...   x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
+  ...   y = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P('x'))(x)
+  ...   pjit_my_fft = pjit(my_fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
+  ...   pjit_fft = pjit(fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
+  ...   print(pjit_my_fft(y))
+  ...   print(pjit_fft(y))
+
+  .. code-block::
+
+    # my_fft
+    [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j
+     ... 1422.4502 +7271.4297j -405.84033 -3042.983j
+     -3012.4963 -4287.6343j]
+
+    # jax.numpy.fft.fft
+    [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j
+     ... 1422.4502 +7271.4297j -405.84033 -3042.983j
+     -3012.4963 -4287.6343j]
+
+  In this case, the HLO of ``my_fft`` does show an all-gather and dynamic-slice, since
+  the last dimension is the dimension along which FFTs are calculated.
+
+  .. code-block::
+
+    HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(c64[16777216]{0})->c64[16777216]{0}}
+
+    fused_computation {
+      ...
+      ROOT dynamic-slice.2 = c64[16777216]{0} dynamic-slice(param_1, multiply.3), dynamic_slice_sizes={16777216}, metadata={op_name="pjit()/jit(main)/custom_partitioning[partition= propagate_user_sharding= infer_sharding_from_operands= in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
+    }
+
+    ENTRY main.5_spmd {
+      param = c64[16777216]{0} parameter(0), sharding={devices=[2]0,1}
+      all-gather = c64[33554432]{0} all-gather(param), channel_id=1, replica_groups={{0,1}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit()/jit(main)/custom_partitioning[partition= propagate_user_sharding= infer_sharding_from_operands= in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
+      fft.0 = c64[33554432]{0} fft(all-gather), fft_type=FFT, fft_length={33554432}, metadata={op_name="jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(33554432,)]" source_file="doc.py" source_line=51}
+      partition-id = u32[] partition-id(), metadata={op_name="pjit()/jit(main)/custom_partitioning[partition= propagate_user_sharding= infer_sharding_from_operands= in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
+      ROOT fusion = c64[16777216]{0} fusion(fft.0, partition-id), kind=kLoop, calls=fused_computation, metadata={op_name="pjit()/jit(main)/custom_partitioning[partition= propagate_user_sharding= infer_sharding_from_operands= in_tree=PyTreeDef((*,)) out_tree=PyTreeDef(*)]" source_file="doc.py" source_line=51}
+    }
-  propagate_user_sharding: Callable which takes the sharding of a user (in the dag)
-    and returns a suggestion for a new NamedSharding. The default
-    implementation is just to return the suggested sharding.
-  partition: Callable which takes the SPMD suggested partition shapes and
-    partition specs and returns a per-shard lowering function and the final
-    input and output sharding specs (the SPMD partitioner will repartition the
-    inputs to match).
-  infer_sharding_from_operands: Callable which computes an output
-    NamedSharding from the NamedSharding chosen for each argument.
   """
 
   def __init__(self, fun):