Add static_argnums to custom_partitioning.

Arguments specified by static_argnums must not contain any JAX tracers,
because they are passed on to the XLA compiler, where the lowering
information for those tracers has already been lost.
Parker Schuh 2023-02-03 11:30:31 -08:00
parent fada1c2035
commit 7526d0ea1f
2 changed files with 82 additions and 25 deletions
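
In brief, the new API looks like this (condensed from the docstring and test
changes below; the callback bodies are elided, so this is a sketch rather than
a runnable partitioning example):

from functools import partial
import jax.numpy as jnp
from jax.experimental.custom_partitioning import custom_partitioning

# Mark `precision` (position 2) as static: it must be a hashable,
# non-tracer Python value.
@partial(custom_partitioning, static_argnums=(2,))
def f(x, y, precision=None):
  return jnp.matmul(x, y, precision=precision)

# Every callback now receives the static arguments first.
def partition(precision, arg_shapes, arg_shardings, result_shape, result_sharding):
  ...

def infer_sharding_from_operands(precision, arg_shapes, arg_shardings, shape):
  ...

f.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition)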

First changed file (the custom_partitioning implementation):

@@ -13,11 +13,12 @@
 # limitations under the License.
 import jax
+import inspect
 from jax import core
 from jax import tree_util
 from jax._src import linear_util as lu
 from jax.experimental import pjit
+from jax.errors import UnexpectedTracerError
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir import ir
 import jax.interpreters.pxla as pxla
@@ -26,20 +27,31 @@ from jax.interpreters import partial_eval as pe
 from jax._src import custom_api_util
 from jax._src.lib import xla_client as xc
 from jax._src.api_util import flatten_fun_nokwargs
+from jax._src.api_util import argnums_partial
 import weakref

+def _resolve_kwargs(fun, args, kwargs):
+  ba = inspect.signature(fun).bind(*args, **kwargs)
+  ba.apply_defaults()
+  if ba.kwargs:
+    raise TypeError("keyword arguments could not be resolved to positions")
+  else:
+    return ba.args
+
 class _ShardingCallbackInfo:
   def __init__(self, propagate_user_sharding, partition, to_mesh_pspec_sharding,
-               infer_sharding_from_operands, module_context, mesh):
+               infer_sharding_from_operands, module_context, mesh, static_args):
     self.propagate_user_sharding = propagate_user_sharding
     self.partition = partition
     self.to_mesh_pspec_sharding = to_mesh_pspec_sharding
     self.infer_sharding_from_operands = infer_sharding_from_operands
     self.module_context = module_context
     self.mesh = mesh
+    self.static_args = static_args

 _sharding_callbacks = weakref.WeakValueDictionary()  # type: ignore
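
An aside on `_resolve_kwargs` above: it binds the call against `fun`'s
signature so that every argument gets a definite position, which is what makes
`static_argnums` indices meaningful even for keyword calls. A standalone
illustration (the example values are mine, not from the commit):

import inspect

def _resolve_kwargs(fun, args, kwargs):
  ba = inspect.signature(fun).bind(*args, **kwargs)
  ba.apply_defaults()
  if ba.kwargs:
    raise TypeError("keyword arguments could not be resolved to positions")
  else:
    return ba.args

def f(x, y, precision=None):
  pass

# `precision` is passed by keyword but lands at position 2, so
# static_argnums=(2,) still identifies it.
print(_resolve_kwargs(f, (1, 2), {"precision": "highest"}))
# (1, 2, 'highest')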
@@ -52,17 +64,22 @@ def _to_jax_shape(s):
 def _custom_partitioning_propagate_user_sharding(sharding, shape, backend_string):
-  return _sharding_callbacks[backend_string].propagate_user_sharding(sharding, shape)
+  info = _sharding_callbacks[backend_string]
+  if info.propagate_user_sharding is None:
+    return sharding
+  return info.propagate_user_sharding(*info.static_args, sharding, shape)

 def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
                                    result_sharding, backend_string):
   info = _sharding_callbacks[backend_string]
   lower_fn, result_sharding, arg_shardings = info.partition(
+      *info.static_args,
       [_to_jax_shape(s) for s in arg_shapes],
       [info.to_mesh_pspec_sharding(s.to_proto()) for s in arg_shardings],
       _to_jax_shape(result_shape),
-      info.to_mesh_pspec_sharding(result_sharding.to_proto()))
+      info.to_mesh_pspec_sharding(result_sharding.to_proto())
+  )
   module_context = info.module_context

   def to_hlo_sharding(sharding, shape):
@@ -95,9 +112,11 @@ def _custom_partitioning_infer_sharding_from_operands(arg_shapes, arg_shardings,
   info = _sharding_callbacks[backend_string]
   result_shape = _to_jax_shape(shape)
   result = info.infer_sharding_from_operands(
+      *info.static_args,
       [_to_jax_shape(s) for s in arg_shapes],
       [info.to_mesh_pspec_sharding(s.to_proto()) for s in arg_shardings],
-      result_shape)
+      result_shape
+  )
   return xc.HloSharding.from_proto(
       result._to_xla_op_sharding(len(result_shape.shape)))
@@ -107,23 +126,34 @@ custom_partitioning_p.multiple_results = True
 def _custom_partitioning_abstract_eval(*avals, call, in_tree, out_tree,
-                                       propagate_user_sharding, partition,
-                                       infer_sharding_from_operands):
-  del in_tree, out_tree, propagate_user_sharding, partition, infer_sharding_from_operands
+                                       propagate_user_sharding, partition,
+                                       infer_sharding_from_operands,
+                                       static_args):
+  del in_tree, out_tree, propagate_user_sharding, partition
+  del infer_sharding_from_operands, static_args
   return call.out_avals

 def _custom_partitioning_impl(*args, call, in_tree, out_tree, propagate_user_sharding,
-                              partition, infer_sharding_from_operands):
-  del in_tree, out_tree, propagate_user_sharding, partition, infer_sharding_from_operands
+                              partition, infer_sharding_from_operands, static_args):
+  del in_tree, out_tree, propagate_user_sharding, partition
+  del infer_sharding_from_operands, static_args
   return core.jaxpr_as_fun(call)(*args)

 custom_partitioning_p.def_abstract_eval(_custom_partitioning_abstract_eval)
 custom_partitioning_p.def_impl(_custom_partitioning_impl)

-def _default_propagate_user_shardings(sharding, shape):
-  return sharding
+def _check_for_tracers(x):
+  for leaf in tree_util.tree_leaves(x):
+    if isinstance(leaf, core.Tracer):
+      msg = (
+          "Found a JAX Tracer object passed as an argument to a"
+          " custom_partitioning function in a position indicated as static by"
+          " static_argnums. "
+      )
+      raise UnexpectedTracerError(msg)

 @custom_api_util.register_custom_decorator_type
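
To make the tracer check concrete, here is a standalone sketch of when it
fires (the helper is copied from the hunk above; the failing call is left
commented out):

import jax
from jax import core, tree_util
from jax.errors import UnexpectedTracerError

def _check_for_tracers(x):
  for leaf in tree_util.tree_leaves(x):
    if isinstance(leaf, core.Tracer):
      raise UnexpectedTracerError(
          "Found a JAX Tracer object passed as an argument to a"
          " custom_partitioning function in a position indicated as static by"
          " static_argnums. ")

_check_for_tracers([1, "hi", (2.0, None)])  # no tracers among the leaves: passes

def f(x):
  _check_for_tracers([x])  # under jit, x is a Tracer
  return x

# jax.jit(f)(1.0)  # would raise UnexpectedTracerError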
@@ -163,6 +193,9 @@ class custom_partitioning:
   * ``infer_sharding_from_operands``: Callable which computes an output ``NamedSharding``
     from the ``NamedSharding`` chosen for each argument.

+  Positional arguments can be specified as static using ``static_argnums``. JAX uses
+  :code:`inspect.signature(fun)` to resolve these positional arguments.
+
   Example:

   As an example, assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes
@@ -277,24 +310,41 @@ class custom_partitioning:
   """

-  def __init__(self, fun):
+  def __init__(self, fun, static_argnums=()):
     self.fun = fun
     self.partition = None
+    self.static_argnums = static_argnums
     self.propagate_user_sharding = None
     self.infer_sharding_from_operands = None

   __getattr__ = custom_api_util.forward_attr

   def def_partition(self, partition, infer_sharding_from_operands,
-                    propagate_user_sharding=_default_propagate_user_shardings):
+                    propagate_user_sharding=None):
     self.partition = partition
     self.propagate_user_sharding = propagate_user_sharding
     self.infer_sharding_from_operands = infer_sharding_from_operands
     return partition

-  def __call__(self, *args):
-    args_flat, in_tree = tree_util.tree_flatten(args)
-    flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
+  def __call__(self, *args, **kwargs):
+    args = _resolve_kwargs(self.fun, args, kwargs)
+    if self.static_argnums:
+      static_argnums = set(self.static_argnums)
+      args = tuple(x if i in static_argnums else x for i, x in enumerate(args))
+      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
+      f_, dyn_args = argnums_partial(
+          lu.wrap_init(self.fun),
+          dyn_argnums,
+          args,
+          require_static_args_hashable=False,
+      )
+      static_args = [args[i] for i in self.static_argnums]
+      _check_for_tracers(static_args)
+    else:
+      static_args = []
+      f_, dyn_args = lu.wrap_init(self.fun), args
+    args_flat, in_tree = tree_util.tree_flatten(dyn_args)
+    flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
     in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
     debug = pe.debug_info(self.fun, in_tree, False, "custom_partitioning")
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
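
`argnums_partial` is an internal helper (`jax._src.api_util`). As a rough
mental model only, not the real implementation (which operates on
`lu.WrappedFun` transformations), it closes over the static arguments and
exposes a function of just the dynamic ones:

def argnums_partial_model(f, dyn_argnums, args):
  """Simplified model of jax._src.api_util.argnums_partial."""
  dyn_argnums = set(dyn_argnums)
  fixed = {i: a for i, a in enumerate(args) if i not in dyn_argnums}
  def g(*dyn_args):
    it = iter(dyn_args)
    full = [fixed[i] if i in fixed else next(it) for i in range(len(args))]
    return f(*full)
  dyn_args = tuple(a for i, a in enumerate(args) if i in dyn_argnums)
  return g, dyn_args

def f(x, y, precision=None):
  return (x, y, precision)

g, dyn = argnums_partial_model(f, [0, 1], (1, 2, "highest"))
print(g(*dyn))  # (1, 2, 'highest'); "highest" is baked in, only x and y vary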
@@ -308,14 +358,17 @@ class custom_partitioning:
         propagate_user_sharding=self.propagate_user_sharding,
         infer_sharding_from_operands=self.infer_sharding_from_operands,
         in_tree=in_tree,
-        out_tree=out_tree())
+        out_tree=out_tree(),
+        static_args=static_args
+    )
     return tree_util.tree_unflatten(out_tree(), out_flat)

 def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                        call, in_tree, out_tree,
                                        propagate_user_sharding, partition,
-                                       infer_sharding_from_operands):
+                                       infer_sharding_from_operands,
+                                       static_args):
   mesh = pxla.thread_resources.env.physical_mesh
   axis_context = ctx.module_context.axis_context

@@ -341,7 +394,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
   sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition,
                                                  to_mesh_pspec_sharding,
                                                  infer_sharding_from_operands,
-                                                 ctx.module_context, mesh)
+                                                 ctx.module_context, mesh, static_args)
   key = str(id(sharding_callback_info))
   _sharding_callbacks[key] = sharding_callback_info
   # We need to make sure `sharding_callback_info` is still alive when the SPMD
   # partitioner runs.
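
On the keep-alive comment above: `_sharding_callbacks` is a
`WeakValueDictionary`, so an entry lives only as long as some strong reference
to the info object does. A standalone illustration of the pattern (it relies
on CPython's immediate refcounting to make the drop observable right away):

import weakref

class Info:
  pass

registry = weakref.WeakValueDictionary()
info = Info()
key = str(id(info))
registry[key] = info

assert key in registry      # alive while a strong reference exists
del info                    # drop the last strong reference
assert key not in registry  # entry removed automatically (CPython)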

Second changed file (the pjit tests):

@@ -1074,7 +1074,9 @@ class PJitTest(jtu.BufferDonationTestCase):
     if jtu.is_cloud_tpu():
       raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")

-    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
+    def partition(
+        precision, arg_shapes, arg_shardings, result_shape, result_sharding
+    ):
       self.assertEqual(arg_shardings[0], result_sharding)
       self.assertEqual(P(('x',)), result_sharding.spec)
       self.assertEqual(P(('y',)), arg_shardings[1].spec)

@@ -1087,7 +1089,9 @@ class PJitTest(jtu.BufferDonationTestCase):
       return lower_fn, result_sharding, arg_shardings

-    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
+    def infer_sharding_from_operands(
+        precision, arg_shapes, arg_shardings, shape
+    ):
       x_shard, y_shard = arg_shardings
       x_shape, y_shape = arg_shapes
       x_names = tuple(x_shard.spec) + tuple(

@@ -1096,9 +1100,9 @@ class PJitTest(jtu.BufferDonationTestCase):
           None for _ in range(len(y_shape.shape) - len(y_shard.spec)))
       return NamedSharding(y_shard.mesh, P(*(x_names[:-1] + y_names[1:])))

-    @custom_partitioning
-    def f(x, y):
-      return x @ y
+    @partial(custom_partitioning, static_argnums=(2,))
+    def f(x, y, precision=None):
+      return jnp.matmul(x, y, precision=precision)

     f.def_partition(
         infer_sharding_from_operands=infer_sharding_from_operands,
         partition=partition)