Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
[shape_poly] Lowering sharding annotations in presence of dynamic shapes
Sharding annotations are lowered to custom calls, and in the presence of dynamic shapes we must use the `indices_of_shape_operands` attribute of hlo.CustomCall. In order to generate the code that computes the result shapes, we must pass the `LoweringRuleContext` and the result abstract value to the lowering helpers that generate the custom calls.

This is straightforward everywhere except for the sharding annotations on the inputs and outputs of a function, because at that point we do not yet have a LoweringRuleContext available.

This code is exercised by tests that are still disabled in sharding_test. They can be enabled once StableHLO improves its support for dynamic shapes in custom calls: https://github.com/openxla/stablehlo/issues/1367
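Illustrative sketch only, not part of this commit: it shows how a "Sharding" custom call can carry an extra shape operand referenced by `indices_of_shape_operands` when its result shape is dynamic. It assumes the MLIR Python bindings (`ir`, `hlo`) and the module-local `i32_attr` helper are in scope, as in the file being patched; the shape operand itself would come from evaluating the symbolic output shape (see `eval_dynamic_shape` / `shape_tensor` in the diff below).

import numpy as np

def sharding_custom_call_with_dynamic_result(result_type, x, shape_operand):
  # `shape_operand` is a 1-D tensor value holding the runtime result shape.
  # Shape operands are appended after the regular operands.
  operands = [x, shape_operand]
  attributes = dict(
      call_target_name=ir.StringAttr.get("Sharding"),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(""),
      api_version=i32_attr(1),
      called_computations=ir.ArrayAttr.get([]),
      # Points at the operand positions that carry result shapes: here the
      # single shape operand sits at index 1, right after `x`.
      indices_of_shape_operands=ir.DenseIntElementsAttr.get(
          np.asarray([1], dtype=np.int64)),
  )
  # The plain CustomCallOp constructor does not accept this attribute yet,
  # so the op is built generically, as the new `custom_call` helper below does.
  return hlo.CustomCallOp.build_generic(
      results=[result_type], operands=operands, attributes=attributes)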
parent 55fbe1c7b5
commit 961b0655fa
@@ -776,7 +776,8 @@ def lower_jaxpr_to_fun(
return aval_to_ir_types(aval)

num_dim_vars = len(ctx.dim_vars)
dim_var_types = map(aval_to_types, [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars)
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
dim_var_types = map(aval_to_types, dim_var_avals)

# Function inputs: *dim_var_values, *tokens, *actual_inputs
input_types = map(aval_to_types, jaxpr.in_avals)
@@ -793,7 +794,10 @@ def lower_jaxpr_to_fun(
output_token_types = []
num_tokens = len(effects)
token_types = [token_type() for _ in effects]
token_avals = [core.AbstractToken] * len(effects)
input_avals = dim_var_avals + token_avals + jaxpr.in_avals
input_types = [*dim_var_types, *token_types, *input_types]
output_avals = [core.AbstractToken] * (len(output_token_types) + len(token_types)) + jaxpr.out_avals
output_types = [*output_token_types, *token_types, *output_types]
if input_output_aliases is not None:
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
@@ -893,15 +897,20 @@ def lower_jaxpr_to_fun(
entry_block = func_op.add_entry_block()
with ir.InsertionPoint(entry_block):
flat_args = entry_block.arguments
if not use_sharding_annotations and ir_arg_shardings is not None:
flat_args = [a if s is None else wrap_with_sharding_op(a, s)
for a, s in zip(flat_args, ir_arg_shardings)]

unflattened_args = util.unflatten(flat_args, map(len, input_types))
# We separate out the dimension variable inputs, the token inputs and
# the usual inputs. The dimension variables and token inputs
# the regular inputs. The dimension variables and token inputs
# will be passed to `jaxpr_subcomp` separately from the `args`.
dim_var_values, token_args, unflattened_args = util.split_list(unflattened_args, [num_dim_vars, num_tokens])
dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
# A lowering context just for function body entry/exit code.
entry_lowering_ctx = LoweringRuleContext(
ctx, None, [], None, TokenSet.create([]), None, None, dim_var_values)
if not use_sharding_annotations and ir_arg_shardings is not None:
flat_args = [
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]

_, token_args, unflattened_args = util.split_list(util.unflatten(flat_args, map(len, input_types)),
[num_dim_vars, num_tokens])
if create_tokens:
tokens_in = TokenSet.create(effects)
else:
@@ -932,8 +941,9 @@ def lower_jaxpr_to_fun(
outs.append(out)
flat_outputs = util.flatten(outs)
if not use_sharding_annotations and ir_result_shardings is not None:
flat_outputs = [o if s is None else wrap_with_sharding_op(o, s)
for o, s in zip(flat_outputs, ir_result_shardings)]
flat_outputs = [
o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]

func_dialect.ReturnOp(flat_outputs)
@@ -1365,8 +1375,9 @@ def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
return hlo.ConvertOp(aval_to_ir_type(aval_out), x).result

def _wrap_with_spmd_op(name: str,
result_type: ir.Type,
ctx: LoweringRuleContext,
x: ir.Value,
aval_out: core.AbstractValue,
sharding_proto: xc.OpSharding,
unspecified_dims: Optional[Set[int]] = None):
# unspecified_dims indicate dimensions whose shardings are not specified and
@@ -1376,23 +1387,23 @@ def _wrap_with_spmd_op(name: str,
[str(i) for i in sorted(unspecified_dims)]) + "]"
else:
backend_config = ""
op = hlo.CustomCallOp([result_type], [x],
call_target_name=ir.StringAttr.get(name),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(backend_config),
api_version=i32_attr(1),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=None,
result_layouts=None)
result_type = aval_to_ir_type(aval_out)
out_shape = aval_out.shape # type: ignore
if core.is_constant_shape(out_shape):
result_shapes = None
else:
result_shapes = [shape_tensor(eval_dynamic_shape(ctx, out_shape))]

op = custom_call(name, [result_type], [x],
backend_config=backend_config,
has_side_effect=False,
api_version=1,
result_shapes=result_shapes)
set_sharding(op, sharding_proto)
return op.result

def wrap_with_sharding_op(x: ir.Value,
sharding_proto: xc.OpSharding,
unspecified_dims: Optional[Set[int]] = None):
return _wrap_with_spmd_op("Sharding", x.type, x, sharding_proto,
unspecified_dims)

wrap_with_sharding_op = partial(_wrap_with_spmd_op, "Sharding")
wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")
@@ -1787,3 +1798,39 @@ def build_xla_computation_helper(
return xc._xla.mlir.mlir_module_to_xla_computation(
module_to_string(lowering_result.module), use_tuple_args=False,
return_tuple=False)

def custom_call(
call_target_name: str,
out_types: Sequence[ir.Type],
operands: Sequence[ir.Value],
backend_config: Optional[str] = None,
has_side_effect: bool = False,
result_shapes: Optional[Sequence[ir.Value]] = None,
api_version: int = 2,
) -> ir.Operation:
"""Wraps a hlo.CustomCall

Args:
result_shapes: tensors that represent the result shapes, to be used when
the results have dynamic shapes. If not-None, its length must match the
number of the results.
"""
attributes = dict(
call_target_name=ir.StringAttr.get(call_target_name),
has_side_effect=ir.BoolAttr.get(has_side_effect),
backend_config=ir.StringAttr.get(
"" if backend_config is None else backend_config),
api_version=i32_attr(api_version),
called_computations=ir.ArrayAttr.get([]),
)

if result_shapes is not None:
# We add the result_shapes at the end of the operands, and must pass
# the indices_of_output_operands attribute. This attribute is not yet
# accepted by the CustomCall constructor, so we use build_generic
attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
np.asarray(list(range(len(operands), len(operands) + len(result_shapes))),
dtype=np.int64))
operands = list(operands) + list(result_shapes)

return hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)
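Illustration only, not part of the commit: a hypothetical call site for the new `custom_call` helper, assuming `ctx` is a LoweringRuleContext, `x` an ir.Value, `aval_out` the result abstract value, and `sharding_proto` an xc.OpSharding; `shape_tensor`, `eval_dynamic_shape`, `aval_to_ir_type`, and `set_sharding` are the module helpers used in the hunks above.

# Hypothetical call site (illustration only), mirroring _wrap_with_spmd_op above.
out_shape = aval_out.shape
if core.is_constant_shape(out_shape):
  # Static result shape: no shape operands are needed.
  result_shapes = None
else:
  # Dynamic result shape: evaluate the symbolic dimensions to ir.Values and
  # pack them into a 1-D shape tensor; it becomes a trailing operand of the
  # custom call, referenced by the indices_of_shape_operands attribute.
  result_shapes = [shape_tensor(eval_dynamic_shape(ctx, out_shape))]

op = custom_call("Sharding", [aval_to_ir_type(aval_out)], [x],
                 backend_config="",
                 has_side_effect=False,
                 api_version=1,
                 result_shapes=result_shapes)
set_sharding(op, sharding_proto)
sharded = op.result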
@@ -1642,10 +1642,9 @@ def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
aval_out, = ctx.avals_out
sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto()
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=unspecified_dims)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto, unspecified_dims=unspecified_dims)
proto = manual_proto(aval_in, manual_axes, mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
return mlir.wrap_with_full_to_shard_op(result_type, sx, proto, unspecified_dims=unspecified_dims),
return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto, unspecified_dims=unspecified_dims),

shard_to_full_p = core.Primitive('shard_to_full')
@@ -1655,16 +1654,15 @@ def _shard_to_full_abstract_eval(x, axes, mesh, **_):
return untile_aval_nd(mesh.shape, axes, x)

@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapping, mesh: Mesh,
manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
proto = manual_proto(aval_in, manual_axes, mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
sx = mlir.wrap_with_sharding_op(x, proto, unspecified_dims=unspecified_dims)
proto = manual_proto(aval_in, manual_axes, mesh) # type: ignore
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values()) # type: ignore
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto, unspecified_dims=unspecified_dims)
sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_out, axes).sharding_proto()
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims),
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto, unspecified_dims),

@lu.transformation
def vtile_manual(manual_axes: FrozenSet[sharding_impls.MeshAxisName],
@@ -1405,7 +1405,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,

global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
sharded_global_in_nodes = [
[mlir.wrap_with_sharding_op(node, global_sharding_spec(aval, aval_axes).sharding_proto())]
[mlir.wrap_with_sharding_op(ctx, node, aval, global_sharding_spec(aval, aval_axes).sharding_proto())]
if aval_axes else [node]
for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
]
@@ -1423,7 +1423,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
dim_var_values=ctx.dim_var_values)

sharded_global_out_nodes = [
mlir.wrap_with_sharding_op(node, global_sharding_spec(aval, aval_axes).sharding_proto())
mlir.wrap_with_sharding_op(ctx, node, aval, global_sharding_spec(aval, aval_axes).sharding_proto())
if aval_axes else node
for (node,), aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
]
@@ -1795,6 +1795,7 @@ ad.deflinear2(sharding_constraint_p,
def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
resource_env, unconstrained_dims):
aval, = ctx.avals_in
out_aval, = ctx.avals_out
axis_ctx = ctx.module_context.axis_context
# axis_ctx and manual_axes is *only used with xmap* and xmap only works with
# NamedSharding. So convert the GSPMDSharding to NamedSharding
@@ -1806,8 +1807,8 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
sharding = GSPMDSharding(
mps._device_assignment, mps._to_xla_op_sharding(aval.ndim, axis_ctx=axis_ctx))
return [
mlir.wrap_with_sharding_op(
x_node,
mlir.wrap_with_sharding_op(ctx,
x_node, out_aval,
sharding._to_xla_op_sharding(aval.ndim),
unspecified_dims=unconstrained_dims)
]
@@ -187,7 +187,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
for out_shardings in ("missing", None, "P")
)
@jtu.with_mesh([("x", 2)])
def test_pjit_basic(self, in_shardings=None, out_shardings="missing"):
def test_pjit_basic(self, in_shardings="P", out_shardings="P"):
# Ensure that we can distinguish the inputs and outputs by shape
def f_jax(x): # f32[10,20] -> f32[20,10]
return jnp.sin(x.T)
@@ -317,17 +317,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(res_tf, res_jax)

@parameterized.named_parameters(
dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint=}_poly={poly}",
dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint}_poly={poly}",
nested_pjit=nested_pjit, constraint=constraint, poly=poly)
# We add a constraint either with a nested pjit or with a sharding_constraint
for nested_pjit in (True, False)
for constraint in (None, "P")
for poly in (None, "b1,_", "_,b2", "b1,b2")
for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
)
@jtu.with_mesh([("x", 2)])
def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly=None):
if poly is not None:
raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement")
def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly="2*b1,b2"):
constraint_sharding = P("x", None) if constraint == "P" else None
@partial(pjit.pjit, in_shardings=None,
out_shardings=None)
@@ -697,13 +695,11 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):

@parameterized.named_parameters(
dict(testcase_name=f"_poly={poly}", poly=poly)
for poly in (None, "b1,_", "_,b2", "b1,b2")
for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
)
def test_shmap_collective_permute(self, poly=None):
if jtu.device_under_test() == "cpu":
raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")
if poly is not None:
raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement")
mesh = Mesh(self.devices, axis_names=('x'))
a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))
@@ -479,7 +479,7 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
check_rep):
del check_rep
sharded_avals = [v.aval for v in jaxpr.invars]
in_nodes_ = map(partial(_xla_shard, mesh), in_names, ctx.avals_in,
in_nodes_ = map(partial(_xla_shard, ctx, mesh), in_names, ctx.avals_in,
sharded_avals, in_nodes)
new_axis_context = sharding_impls.SPMDAxisContext(
mesh, frozenset(mesh.axis_names)
@@ -490,34 +490,34 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
(), *in_nodes_,
dim_var_values=ctx.dim_var_values)
sharded_avals = [v.aval for v in jaxpr.outvars]
return map(partial(_xla_unshard, mesh), out_names, sharded_avals,
return map(partial(_xla_unshard, ctx, mesh), out_names, sharded_avals,
ctx.avals_out, out_nodes_)
mlir.register_lowering(shard_map_p, _shard_map_lowering)

def _xla_shard(mesh, names, aval_in, aval_out, x):
def _xla_shard(ctx: mlir.LoweringRuleContext,
mesh, names, aval_in, aval_out, x):
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
axes = {name: i for i, ns in names.items() for name in ns}
shard_proto = NamedSharding(
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_op_sharding(aval_in.ndim)
if core.is_opaque_dtype(aval_in.dtype):
shard_proto = aval_in.dtype._rules.physical_op_sharding(aval_in, shard_proto)
sx = mlir.wrap_with_sharding_op(x, shard_proto, unspecified_dims=set())
return [mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, set())]
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, unspecified_dims=set())
return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, set())]

def _xla_unshard(mesh, names, aval_in, aval_out, xs):
def _xla_unshard(ctx: mlir.LoweringRuleContext,
mesh, names, aval_in, aval_out, xs):
x, = xs
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set())
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=set())
axes = {name: i for i, ns in names.items() for name in ns}
shard_proto = NamedSharding(
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_op_sharding(aval_out.ndim)
if core.is_opaque_dtype(aval_out.dtype):
shard_proto = aval_out.dtype._rules.physical_op_sharding(aval_out, shard_proto)
return mlir.wrap_with_shard_to_full_op(result_type, sx, shard_proto, set())
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, set())

# Eager evaluation