Add static_argnums to custom_partitioning.
Arguments specified by static_argnums must not contain any JAX tracers: these values are passed into the XLA compiler, where the lowering information for tracers is no longer available.
parent fada1c2035
commit 7526d0ea1f
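To illustrate the new API, here is a minimal sketch mirroring the test added below (the import path and mesh setup are assumptions, not part of this commit's text):

    from functools import partial
    import jax.numpy as jnp
    from jax.experimental.custom_partitioning import custom_partitioning

    # Argument index 2 (`precision`) is static: its concrete value is recorded
    # and later handed to the partitioning callbacks instead of being traced.
    @partial(custom_partitioning, static_argnums=(2,))
    def f(x, y, precision=None):
      return jnp.matmul(x, y, precision=precision)

    # Each registered callback then receives the static values first, e.g.
    #   partition(precision, arg_shapes, arg_shardings, result_shape, result_sharding)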
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import jax
+import inspect
 from jax import core
 from jax import tree_util
 from jax._src import linear_util as lu
 from jax.experimental import pjit
-
+from jax.errors import UnexpectedTracerError
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir import ir
 import jax.interpreters.pxla as pxla
@@ -26,20 +27,31 @@ from jax.interpreters import partial_eval as pe
 from jax._src import custom_api_util
 from jax._src.lib import xla_client as xc
 from jax._src.api_util import flatten_fun_nokwargs
+from jax._src.api_util import argnums_partial
 
 import weakref
 
 
+def _resolve_kwargs(fun, args, kwargs):
+  ba = inspect.signature(fun).bind(*args, **kwargs)
+  ba.apply_defaults()
+  if ba.kwargs:
+    raise TypeError("keyword arguments could not be resolved to positions")
+  else:
+    return ba.args
+
+
 class _ShardingCallbackInfo:
 
   def __init__(self, propagate_user_sharding, partition, to_mesh_pspec_sharding,
-               infer_sharding_from_operands, module_context, mesh):
+               infer_sharding_from_operands, module_context, mesh, static_args):
     self.propagate_user_sharding = propagate_user_sharding
     self.partition = partition
     self.to_mesh_pspec_sharding = to_mesh_pspec_sharding
     self.infer_sharding_from_operands = infer_sharding_from_operands
     self.module_context = module_context
     self.mesh = mesh
+    self.static_args = static_args
 
 
 _sharding_callbacks = weakref.WeakValueDictionary()  # type: ignore
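As an aside, `_resolve_kwargs` above relies only on the standard library; a self-contained sketch of its behavior (the standalone name `resolve_kwargs` and the function `g` are hypothetical, not from this commit):

    import inspect

    def resolve_kwargs(fun, args, kwargs):
      # Bind keyword arguments to positional slots via the function's signature.
      ba = inspect.signature(fun).bind(*args, **kwargs)
      ba.apply_defaults()
      if ba.kwargs:  # non-empty only if `fun` itself declares **kwargs
        raise TypeError("keyword arguments could not be resolved to positions")
      return ba.args

    def g(x, y, precision=None):
      return x, y, precision

    assert resolve_kwargs(g, (1, 2), {"precision": "highest"}) == (1, 2, "highest")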
@@ -52,17 +64,22 @@ def _to_jax_shape(s):
 
 
 def _custom_partitioning_propagate_user_sharding(sharding, shape, backend_string):
-  return _sharding_callbacks[backend_string].propagate_user_sharding(sharding, shape)
+  info = _sharding_callbacks[backend_string]
+  if info.propagate_user_sharding is None:
+    return sharding
+  return info.propagate_user_sharding(*info.static_args, sharding, shape)
 
 
 def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
                                    result_sharding, backend_string):
   info = _sharding_callbacks[backend_string]
   lower_fn, result_sharding, arg_shardings = info.partition(
+      *info.static_args,
       [_to_jax_shape(s) for s in arg_shapes],
       [info.to_mesh_pspec_sharding(s.to_proto()) for s in arg_shardings],
       _to_jax_shape(result_shape),
-      info.to_mesh_pspec_sharding(result_sharding.to_proto()))
+      info.to_mesh_pspec_sharding(result_sharding.to_proto())
+  )
   module_context = info.module_context
 
   def to_hlo_sharding(sharding, shape):
@@ -95,9 +112,11 @@ def _custom_partitioning_infer_sharding_from_operands(arg_shapes, arg_shardings,
   info = _sharding_callbacks[backend_string]
   result_shape = _to_jax_shape(shape)
   result = info.infer_sharding_from_operands(
+      *info.static_args,
       [_to_jax_shape(s) for s in arg_shapes],
       [info.to_mesh_pspec_sharding(s.to_proto()) for s in arg_shardings],
-      result_shape)
+      result_shape
+  )
   return xc.HloSharding.from_proto(
       result._to_xla_op_sharding(len(result_shape.shape)))
 
@@ -107,23 +126,34 @@ custom_partitioning_p.multiple_results = True
 
 
 def _custom_partitioning_abstract_eval(*avals, call, in_tree, out_tree,
-                                       propagate_user_sharding, partition,
-                                       infer_sharding_from_operands):
-  del in_tree, out_tree, propagate_user_sharding, partition, infer_sharding_from_operands
+                                       propagate_user_sharding, partition,
+                                       infer_sharding_from_operands,
+                                       static_args):
+  del in_tree, out_tree, propagate_user_sharding, partition
+  del infer_sharding_from_operands, static_args
   return call.out_avals
 
 
 def _custom_partitioning_impl(*args, call, in_tree, out_tree, propagate_user_sharding,
-                              partition, infer_sharding_from_operands):
-  del in_tree, out_tree, propagate_user_sharding, partition, infer_sharding_from_operands
+                              partition, infer_sharding_from_operands, static_args):
+  del in_tree, out_tree, propagate_user_sharding, partition
+  del infer_sharding_from_operands, static_args
   return core.jaxpr_as_fun(call)(*args)
 
 
 custom_partitioning_p.def_abstract_eval(_custom_partitioning_abstract_eval)
 custom_partitioning_p.def_impl(_custom_partitioning_impl)
 
-def _default_propagate_user_shardings(sharding, shape):
-  return sharding
 
+def _check_for_tracers(x):
+  for leaf in tree_util.tree_leaves(x):
+    if isinstance(leaf, core.Tracer):
+      msg = (
+          "Found a JAX Tracer object passed as an argument to a"
+          " custom_partitioning function in a position indicated as static by"
+          " static_argnums. "
+      )
+      raise UnexpectedTracerError(msg)
 
 
 @custom_api_util.register_custom_decorator_type
@@ -163,6 +193,9 @@ class custom_partitioning:
   * ``infer_sharding_from_operands``: Callable which computes an output ``NamedSharding``
     from the ``NamedSharding`` chosen for each argument.
 
+  Positional arguments can be specified as static using static_argnums. JAX uses
+  :code:`inspect.signature(fun)` to resolve these positional arguments.
+
   Example:
 
   As an example, assume we want to enhance the existing ``jax.numpy.fft.fft``. This function computes
@@ -277,24 +310,41 @@ class custom_partitioning:
 
   """
 
-  def __init__(self, fun):
+  def __init__(self, fun, static_argnums=()):
     self.fun = fun
     self.partition = None
+    self.static_argnums = static_argnums
     self.propagate_user_sharding = None
     self.infer_sharding_from_operands = None
 
   __getattr__ = custom_api_util.forward_attr
 
   def def_partition(self, partition, infer_sharding_from_operands,
-                    propagate_user_sharding=_default_propagate_user_shardings):
+                    propagate_user_sharding=None):
     self.partition = partition
     self.propagate_user_sharding = propagate_user_sharding
     self.infer_sharding_from_operands = infer_sharding_from_operands
     return partition
 
-  def __call__(self, *args):
-    args_flat, in_tree = tree_util.tree_flatten(args)
-    flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
+  def __call__(self, *args, **kwargs):
+    args = _resolve_kwargs(self.fun, args, kwargs)
+    if self.static_argnums:
+      static_argnums = set(self.static_argnums)
+      args = tuple(x if i in static_argnums else x for i, x in enumerate(args))
+      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
+      f_, dyn_args = argnums_partial(
+          lu.wrap_init(self.fun),
+          dyn_argnums,
+          args,
+          require_static_args_hashable=False,
+      )
+      static_args = [args[i] for i in self.static_argnums]
+      _check_for_tracers(static_args)
+    else:
+      static_args = []
+      f_, dyn_args = lu.wrap_init(self.fun), args
+    args_flat, in_tree = tree_util.tree_flatten(dyn_args)
+    flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
     in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
     debug = pe.debug_info(self.fun, in_tree, False, "custom_partitioning")
     jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
@@ -308,14 +358,17 @@ class custom_partitioning:
         propagate_user_sharding=self.propagate_user_sharding,
         infer_sharding_from_operands=self.infer_sharding_from_operands,
         in_tree=in_tree,
-        out_tree=out_tree())
+        out_tree=out_tree(),
+        static_args=static_args
+    )
     return tree_util.tree_unflatten(out_tree(), out_flat)
 
 
 def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                        call, in_tree, out_tree,
                                        propagate_user_sharding, partition,
-                                       infer_sharding_from_operands):
+                                       infer_sharding_from_operands,
+                                       static_args):
   mesh = pxla.thread_resources.env.physical_mesh
   axis_context = ctx.module_context.axis_context
 
@@ -341,7 +394,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
   sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding, partition,
                                                  to_mesh_pspec_sharding,
                                                  infer_sharding_from_operands,
-                                                 ctx.module_context, mesh)
+                                                 ctx.module_context, mesh, static_args)
   key = str(id(sharding_callback_info))
   _sharding_callbacks[key] = sharding_callback_info
   # We need to make sure `sharding_callback_info` is still alive when the SPMD
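The `UnexpectedTracerError` path above guards against calls like the following sketch (hypothetical; it assumes `f` was declared with `static_argnums=(2,)` as in the example earlier):

    import jax

    @jax.jit
    def g(x, y, p):
      # Inside jit, `p` is a tracer. Passing it in a static position of a
      # custom_partitioning function raises UnexpectedTracerError, because
      # static arguments must be concrete when handed to the compiler.
      return f(x, y, p)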
@@ -1074,7 +1074,9 @@ class PJitTest(jtu.BufferDonationTestCase):
     if jtu.is_cloud_tpu():
       raise unittest.SkipTest("Custom partitioning is not supported on libtpu.")
 
-    def partition(arg_shapes, arg_shardings, result_shape, result_sharding):
+    def partition(
+        precision, arg_shapes, arg_shardings, result_shape, result_sharding
+    ):
       self.assertEqual(arg_shardings[0], result_sharding)
       self.assertEqual(P(('x',)), result_sharding.spec)
       self.assertEqual(P(('y',)), arg_shardings[1].spec)
@@ -1087,7 +1089,9 @@ class PJitTest(jtu.BufferDonationTestCase):
 
       return lower_fn, result_sharding, arg_shardings
 
-    def infer_sharding_from_operands(arg_shapes, arg_shardings, shape):
+    def infer_sharding_from_operands(
+        precision, arg_shapes, arg_shardings, shape
+    ):
       x_shard, y_shard = arg_shardings
       x_shape, y_shape = arg_shapes
       x_names = tuple(x_shard.spec) + tuple(
@@ -1096,9 +1100,9 @@ class PJitTest(jtu.BufferDonationTestCase):
           None for _ in range(len(y_shape.shape) - len(y_shard.spec)))
       return NamedSharding(y_shard.mesh, P(*(x_names[:-1] + y_names[1:])))
 
-    @custom_partitioning
-    def f(x, y):
-      return x @ y
+    @partial(custom_partitioning, static_argnums=(2,))
+    def f(x, y, precision=None):
+      return jnp.matmul(x, y, precision=precision)
 
     f.def_partition(
         infer_sharding_from_operands=infer_sharding_from_operands,
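As a closing sketch (one plausible completion of the registration above, following the callback signatures this commit introduces), the pieces fit together as:

    # Sketch only: register the callbacks defined in the test above.
    f.def_partition(
        infer_sharding_from_operands=infer_sharding_from_operands,
        partition=partition)

    # During SPMD partitioning, the runtime prepends the static values:
    #   infer_sharding_from_operands(precision, arg_shapes, arg_shardings, shape)
    #   partition(precision, arg_shapes, arg_shardings, result_shape, result_sharding)
    # per info.partition(*info.static_args, ...) in the lowering code above.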