# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import inspect
from jax import core
from jax import tree_util
from jax._src import linear_util as lu
from jax.experimental import pjit
from jax.errors import UnexpectedTracerError
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import ir
import jax.interpreters.pxla as pxla
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax._src import custom_api_util
from jax._src.lib import xla_client as xc
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.api_util import argnums_partial

import weakref

def _resolve_kwargs(fun, args, kwargs):
  ba = inspect.signature(fun).bind(*args, **kwargs)
  ba.apply_defaults()
  if ba.kwargs:
    raise TypeError("keyword arguments could not be resolved to positions")
  else:
    return ba.args


class _ShardingCallbackInfo:

  def __init__(self, propagate_user_sharding, partition, to_mesh_pspec_sharding,
               infer_sharding_from_operands, module_context, mesh, static_args):
    self.propagate_user_sharding = propagate_user_sharding
    self.partition = partition
    self.to_mesh_pspec_sharding = to_mesh_pspec_sharding
    self.infer_sharding_from_operands = infer_sharding_from_operands
    self.module_context = module_context
    self.mesh = mesh
    self.static_args = static_args


_sharding_callbacks = weakref.WeakValueDictionary()  # type: ignore

_CUSTOM_PARTITIONING_CALL_NAME = "CustomSPMDPartitioning"


def _to_jax_shape(s):
  return jax.core.ShapedArray(s.dimensions(), s.numpy_dtype())

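# The three module-level functions below are the Python callbacks handed to
# XLA via `xc.register_custom_call_partitioner` at the bottom of this file.
# Each one receives `backend_string` (the key stored in the custom call's
# `backend_config`) and uses it to look up the corresponding
# `_ShardingCallbackInfo` in `_sharding_callbacks`.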
def _custom_partitioning_propagate_user_sharding(sharding, shape, backend_string):
  info = _sharding_callbacks[backend_string]
  if info.propagate_user_sharding is None:
    return sharding
  return info.propagate_user_sharding(*info.static_args, sharding, shape)

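# Partition callback invoked by the SPMD partitioner. It converts the XLA
# shapes and shardings it receives to their JAX equivalents, calls the
# user-supplied `partition` function (with any static arguments first), and
# then traces the returned `lower_fn` on the per-shard (tiled) shapes to build
# the XLA computation that replaces the custom call.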
def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
                                   result_sharding, backend_string):
  info = _sharding_callbacks[backend_string]
  lower_fn, result_sharding, arg_shardings = info.partition(
      *info.static_args,
      [_to_jax_shape(s) for s in arg_shapes],
      [info.to_mesh_pspec_sharding(s.to_proto()) for s in arg_shardings],
      _to_jax_shape(result_shape),
      info.to_mesh_pspec_sharding(result_sharding.to_proto())
  )
  module_context = info.module_context

  def to_hlo_sharding(sharding, shape):
    return xc.HloSharding.from_proto(
        sharding._to_xla_op_sharding(len(shape.dimensions())))

  result_sharding = to_hlo_sharding(result_sharding, result_shape)
  arg_shardings = [
      to_hlo_sharding(sharding, s)
      for sharding, s in zip(arg_shardings, arg_shapes)
  ]
  tiled_args = [
      _to_jax_shape(sharding.tile(s))
      for sharding, s in zip(arg_shardings, arg_shapes)
  ]
  closed_jaxpr = jax.make_jaxpr(
      lower_fn, axis_env=list(info.mesh.shape.items()))(*tiled_args)
  axis_context = mlir.SPMDAxisContext(info.mesh)
  built = mlir.build_xla_computation_helper(
      closed_jaxpr,
      name="tmp_xla_computation",
      platform=module_context.platform,
      backend_or_name=module_context.backend_or_name,
      axis_context=axis_context.extend_manual(frozenset(info.mesh.axis_names)))
  return built, arg_shardings, result_sharding


def _custom_partitioning_infer_sharding_from_operands(arg_shapes, arg_shardings,
                                                      shape, backend_string):
  info = _sharding_callbacks[backend_string]
  result_shape = _to_jax_shape(shape)
  result = info.infer_sharding_from_operands(
      *info.static_args,
      [_to_jax_shape(s) for s in arg_shapes],
      [info.to_mesh_pspec_sharding(s.to_proto()) for s in arg_shardings],
      result_shape
  )
  return xc.HloSharding.from_proto(
      result._to_xla_op_sharding(len(result_shape.shape)))

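# `custom_partitioning` is staged out as its own primitive. Outside of the
# SPMD partitioner it behaves exactly like the wrapped function: the abstract
# eval returns the inner jaxpr's output avals and the impl simply evaluates
# the inner jaxpr.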
custom_partitioning_p = core.Primitive("custom_partitioning")
custom_partitioning_p.multiple_results = True


def _custom_partitioning_abstract_eval(*avals, call, in_tree, out_tree,
                                       propagate_user_sharding, partition,
                                       infer_sharding_from_operands,
                                       static_args):
  del in_tree, out_tree, propagate_user_sharding, partition
  del infer_sharding_from_operands, static_args
  return call.out_avals


def _custom_partitioning_impl(*args, call, in_tree, out_tree, propagate_user_sharding,
                              partition, infer_sharding_from_operands, static_args):
  del in_tree, out_tree, propagate_user_sharding, partition
  del infer_sharding_from_operands, static_args
  return core.jaxpr_as_fun(call)(*args)


custom_partitioning_p.def_abstract_eval(_custom_partitioning_abstract_eval)
custom_partitioning_p.def_impl(_custom_partitioning_impl)

def _check_for_tracers(x):
  for leaf in tree_util.tree_leaves(x):
    if isinstance(leaf, core.Tracer):
      msg = (
          "Found a JAX Tracer object passed as an argument to a "
          "custom_partitioning function in a position indicated as static by "
          "static_argnums. "
      )
      raise UnexpectedTracerError(msg)

@custom_api_util.register_custom_decorator_type
class custom_partitioning:
  """Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules.

  .. code-block:: python

      @custom_partitioning
      def f(*args):
        return ...

      def propagate_user_sharding(sharding, shape):
        '''Update the sharding of the op from a user's sharding.'''

      def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
        def lower_fn(*args):
          ... builds computation on per-device shapes ...
        # result_sharding and arg_shardings may optionally be modified and the
        # partitioner will insert collectives to reshape.
        return lower_fn, result_sharding, arg_shardings

      def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
        '''Compute the result sharding from the sharding of the operands.'''

      f.def_partition(partition,
                      propagate_user_sharding=propagate_user_sharding,
                      infer_sharding_from_operands=infer_sharding_from_operands)

  The args to ``def_partition`` are as follows:

  * ``propagate_user_sharding``: Callable which takes the sharding of a user (in the DAG)
    and returns a suggestion for a new ``NamedSharding``. The default
    implementation is just to return the suggested sharding.
  * ``partition``: Callable which takes the SPMD suggested partition shapes and
    partition specs and returns a per-shard lowering function and the final
    input and output sharding specs (the SPMD partitioner will repartition the
    inputs to match).
  * ``infer_sharding_from_operands``: Callable which computes an output ``NamedSharding``
    from the ``NamedSharding`` chosen for each argument.

  Positional arguments can be specified as static using ``static_argnums``. JAX uses
  :code:`inspect.signature(fun)` to resolve these positional arguments. Arguments
  specified by ``static_argnums`` cannot contain any JAX tracers, because they are
  passed into the XLA compiler, where the lowering information for these tracers
  is already lost.
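  For example, a function whose second positional argument is static can be
  declared as follows (an illustrative sketch; ``f``, ``x`` and ``n`` are
  placeholder names). The static value is forwarded, ahead of the other
  arguments, to the callbacks registered with ``def_partition``.

  .. code-block:: python

      from functools import partial

      @partial(custom_partitioning, static_argnums=(1,))
      def f(x, n):
        # ``n`` is static: it must not be or contain a JAX tracer.
        return x * n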

  Example:

  Assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes
  the discrete Fourier transform of an N-dimensional input along the last dimension, and is batched
  along the first N-1 dimensions.
  By default, however, it ignores the sharding of the input and gathers the input on all devices.
  Since ``jax.numpy.fft.fft`` is batched along the first N-1 dimensions, this is unnecessary.
  We will create a new ``my_fft`` op that, instead, does not alter the sharding
  along the first N-1 dimensions, and only gathers the input along the last dimension if needed.

  .. code-block:: python

      import jax
      from jax._src.sharding import NamedSharding
      from jax.experimental.custom_partitioning import custom_partitioning
      from jax.experimental.pjit import pjit
      from jax.sharding import PartitionSpec as P
      from jax.experimental.maps import Mesh
      from jax.numpy.fft import fft
      import regex as re
      import numpy as np

      # Pattern to detect all-gather or dynamic-slice in the generated HLO
      _PATTERN = '(dynamic-slice|all-gather)'

      # For an N-D input, keeps the sharding along the first N-1 dimensions
      # but replicates along the last dimension.
      def supported_sharding(sharding, shape):
        rank = len(shape.shape)
        max_shared_dims = min(len(sharding.spec), rank - 1)
        names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
        return NamedSharding(sharding.mesh, P(*names))

      def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
        return fft, \
            supported_sharding(arg_shardings[0], arg_shapes[0]), \
            [supported_sharding(arg_shardings[0], arg_shapes[0])]

      def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
        return supported_sharding(arg_shardings[0], arg_shapes[0])

      @custom_partitioning
      def my_fft(x):
        return fft(x)

      my_fft.def_partition(
          infer_sharding_from_operands=infer_sharding_from_operands,
          partition=partition)

  Now create a 2D array sharded along the first axis, pass it through ``my_fft``,
  and notice how it is still sharded as expected, and identical to the output
  of ``fft``. However, inspecting the HLO
  (using ``lower(x).compile().runtime_executable().hlo_modules()``) reveals that
  ``my_fft`` does not create any all-gather or dynamic-slice, while ``fft`` does.

  .. code-block::

      with Mesh(np.array(jax.devices()), ('x',)):
        x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
        y = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P('x'))(x)
        pjit_my_fft = pjit(my_fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
        pjit_fft = pjit(fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
        print(pjit_my_fft(y))
        print(pjit_fft(y))
        # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array
        assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
        # dynamic-slice or all-gather are present in the HLO for fft
        assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)

  .. code-block::

      # my_fft
      [[-38.840824 +0.j -40.649452 +11.845365j
      ...
      -1.6937828 +0.8402481j 15.999859 -4.0156755j]]

      # jax.numpy.fft.fft
      [[-38.840824 +0.j -40.649452 +11.845365j
      ...
      -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
  Because of the logic in ``supported_sharding``, ``my_fft`` also works on 1-dimensional arrays.
  However, in this case, the HLO of ``my_fft`` does show a dynamic-slice, since the last dimension
  is the dimension along which FFTs are calculated and needs to be replicated on all devices before
  the computation can be done.

  .. code-block::

      with Mesh(np.array(jax.devices()), ('x',)):
        x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
        y = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=P('x'))(x)
        pjit_my_fft = pjit(my_fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
        pjit_fft = pjit(fft, in_axis_resources=P('x'), out_axis_resources=P('x'))
        print(pjit_my_fft(y))
        print(pjit_fft(y))
        # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array
        assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
        # dynamic-slice or all-gather are present in the HLO for fft
        assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)

  .. code-block::

      # my_fft
      [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j
      ... 1422.4502 +7271.4297j -405.84033 -3042.983j
      -3012.4963 -4287.6343j]

      # jax.numpy.fft.fft
      [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j
      ... 1422.4502 +7271.4297j -405.84033 -3042.983j
      -3012.4963 -4287.6343j]
  """

  def __init__(self, fun, static_argnums=()):
    self.fun = fun
    self.partition = None
    self.static_argnums = static_argnums
    self.propagate_user_sharding = None
    self.infer_sharding_from_operands = None

  __getattr__ = custom_api_util.forward_attr

  def def_partition(self, partition, infer_sharding_from_operands,
                    propagate_user_sharding=None):
    self.partition = partition
    self.propagate_user_sharding = propagate_user_sharding
    self.infer_sharding_from_operands = infer_sharding_from_operands
    return partition

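  # Calling the decorated function resolves keyword arguments to positions,
  # splits out any static arguments, traces the wrapped function to a jaxpr,
  # and binds the custom_partitioning primitive with that jaxpr plus the
  # registered callbacks as parameters.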
  def __call__(self, *args, **kwargs):
    args = _resolve_kwargs(self.fun, args, kwargs)
    if self.static_argnums:
      static_argnums = set(self.static_argnums)
      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
      f_, dyn_args = argnums_partial(
          lu.wrap_init(self.fun),
          dyn_argnums,
          args,
          require_static_args_hashable=False,
      )
      static_args = [args[i] for i in self.static_argnums]
      _check_for_tracers(static_args)
    else:
      static_args = []
      f_, dyn_args = lu.wrap_init(self.fun), args
    args_flat, in_tree = tree_util.tree_flatten(dyn_args)
    flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
    debug = pe.debug_info(self.fun, in_tree, False, "custom_partitioning")
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
    assert not len(consts)
    closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
    out_flat = custom_partitioning_p.bind(
        *consts,
        *args_flat,
        call=closed_call,
        partition=self.partition,
        propagate_user_sharding=self.propagate_user_sharding,
        infer_sharding_from_operands=self.infer_sharding_from_operands,
        in_tree=in_tree,
        out_tree=out_tree(),
        static_args=static_args
    )
    return tree_util.tree_unflatten(out_tree(), out_flat)

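# Lowering: the callback info is stashed in the global `_sharding_callbacks`
# dict under a string key, and that key is written into the custom call's
# `backend_config`. The SPMD partitioner hands the key back (as
# `backend_string`) to the callbacks registered below, which is how they
# recover the Python-side state.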
def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                       call, in_tree, out_tree,
                                       propagate_user_sharding, partition,
                                       infer_sharding_from_operands,
                                       static_args):
  mesh = pxla.thread_resources.env.physical_mesh
  axis_context = ctx.module_context.axis_context

  if isinstance(axis_context, mlir.ShardingContext):
    devices = axis_context.device_assignment
  elif isinstance(axis_context, mlir.SPMDAxisContext):
    devices = list(axis_context.mesh.devices.flat)
  else:
    devices = None

  if not devices or len(devices) == 1:
    return mlir.lower_fun(
        core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)

  def to_mesh_pspec_sharding(op_sharding: xc.OpSharding):
    if mesh.empty:
      from jax._src.sharding import OpShardingSharding
      return OpShardingSharding(devices, op_sharding)
    pspec = pjit.parse_flatten_op_sharding(op_sharding,
                                           mesh)[0].get_partition_spec()
    return pjit.NamedSharding(mesh, pspec)

  sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition,
                                                 to_mesh_pspec_sharding,
                                                 infer_sharding_from_operands,
                                                 ctx.module_context, mesh, static_args)
  key = str(id(sharding_callback_info))
  _sharding_callbacks[key] = sharding_callback_info
  # We need to make sure `sharding_callback_info` is still alive when the SPMD
  # partitioner runs so we keep it alive by attaching it to the executable.
  ctx.module_context.add_keepalive(sharding_callback_info)

  mlir_shapes = [mlir.aval_to_ir_types(s) for s in call.out_avals]
  if len(mlir_shapes) == 1:
    out_type = mlir_shapes[0]
  else:
    out_type = [ir.TupleType.get_tuple(mlir_shapes)]

  out = hlo.CustomCallOp(
      out_type,
      list(values),
      call_target_name=ir.StringAttr.get(_CUSTOM_PARTITIONING_CALL_NAME),
      has_side_effect=ir.BoolAttr.get(False),
      api_version=mlir.i32_attr(2),
      called_computations=ir.ArrayAttr.get([]),
      backend_config=ir.StringAttr.get(key),
      operand_layouts=None,
      result_layouts=None)
  if len(mlir_shapes) == 1:
    return [out.result]
  else:
    return [
        hlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
        for i in range(len(mlir_shapes))
    ]


mlir.register_lowering(custom_partitioning_p,
                       _custom_partitioning_lowering_rule)

xc.register_custom_call_partitioner(  # pytype: disable=module-attr
    _CUSTOM_PARTITIONING_CALL_NAME,
    _custom_partitioning_propagate_user_sharding,
    _custom_partitioning_partition,
    _custom_partitioning_infer_sharding_from_operands, True)