[shape_poly] Lowering sharding annotations in the presence of dynamic shapes

Sharding annotations are lowered to custom calls, and in the presence of dynamic
shapes we must use the `indices_of_shape_operands` attribute of hlo.CustomCall.
In order to generate the code that computes the result shapes,
we must pass the `LoweringRuleContext` and the result abstract value
to the lowering helpers that generate the custom calls.

This is straightforward everywhere except for the sharding annotations on
the inputs and outputs of a function, because at that point we do not yet
have a LoweringRuleContext available.

This code is exercised by tests that are still disabled in sharding_test.
They can be enabled once StableHLO improves its support for
dynamic shapes in custom calls: https://github.com/openxla/stablehlo/issues/1367
George Necula 2023-04-05 09:38:37 +02:00
parent 55fbe1c7b5
commit 961b0655fa
6 changed files with 98 additions and 56 deletions
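Illustrative sketch (not part of the commit): a lowering rule written against the
new helper signature, which threads the LoweringRuleContext and the result abstract
value through to wrap_with_sharding_op. The primitive and rule names are hypothetical,
and the import path may differ depending on the JAX version.

# sketch_new_signature.py -- assumes JAX internals as of this commit
from jax.interpreters import mlir  # or jax._src.interpreters.mlir

def _my_prim_lowering(ctx: mlir.LoweringRuleContext, x, *, sharding_proto):
  aval_out, = ctx.avals_out
  # Before this change the call was wrap_with_sharding_op(x, sharding_proto).
  # Passing ctx and aval_out lets the helper emit a shape operand (and the
  # indices_of_shape_operands attribute) when aval_out has dimension variables.
  return [mlir.wrap_with_sharding_op(ctx, x, aval_out, sharding_proto)]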

@@ -776,7 +776,8 @@ def lower_jaxpr_to_fun(
return aval_to_ir_types(aval)
num_dim_vars = len(ctx.dim_vars)
dim_var_types = map(aval_to_types, [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars)
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
dim_var_types = map(aval_to_types, dim_var_avals)
# Function inputs: *dim_var_values, *tokens, *actual_inputs
input_types = map(aval_to_types, jaxpr.in_avals)
@@ -793,7 +794,10 @@ def lower_jaxpr_to_fun(
output_token_types = []
num_tokens = len(effects)
token_types = [token_type() for _ in effects]
token_avals = [core.AbstractToken] * len(effects)
input_avals = dim_var_avals + token_avals + jaxpr.in_avals
input_types = [*dim_var_types, *token_types, *input_types]
output_avals = [core.AbstractToken] * (len(output_token_types) + len(token_types)) + jaxpr.out_avals
output_types = [*output_token_types, *token_types, *output_types]
if input_output_aliases is not None:
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
@@ -893,15 +897,20 @@ def lower_jaxpr_to_fun(
entry_block = func_op.add_entry_block()
with ir.InsertionPoint(entry_block):
flat_args = entry_block.arguments
if not use_sharding_annotations and ir_arg_shardings is not None:
flat_args = [a if s is None else wrap_with_sharding_op(a, s)
for a, s in zip(flat_args, ir_arg_shardings)]
unflattened_args = util.unflatten(flat_args, map(len, input_types))
# We separate out the dimension variable inputs, the token inputs and
# the usual inputs. The dimension variables and token inputs
# the regular inputs. The dimension variables and token inputs
# will be passed to `jaxpr_subcomp` separately from the `args`.
dim_var_values, token_args, unflattened_args = util.split_list(unflattened_args, [num_dim_vars, num_tokens])
dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
# A lowering context just for function body entry/exit code.
entry_lowering_ctx = LoweringRuleContext(
ctx, None, [], None, TokenSet.create([]), None, None, dim_var_values)
if not use_sharding_annotations and ir_arg_shardings is not None:
flat_args = [
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
_, token_args, unflattened_args = util.split_list(util.unflatten(flat_args, map(len, input_types)),
[num_dim_vars, num_tokens])
if create_tokens:
tokens_in = TokenSet.create(effects)
else:
@@ -932,8 +941,9 @@ def lower_jaxpr_to_fun(
outs.append(out)
flat_outputs = util.flatten(outs)
if not use_sharding_annotations and ir_result_shardings is not None:
flat_outputs = [o if s is None else wrap_with_sharding_op(o, s)
for o, s in zip(flat_outputs, ir_result_shardings)]
flat_outputs = [
o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]
func_dialect.ReturnOp(flat_outputs)
@@ -1365,8 +1375,9 @@ def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
return hlo.ConvertOp(aval_to_ir_type(aval_out), x).result
def _wrap_with_spmd_op(name: str,
result_type: ir.Type,
ctx: LoweringRuleContext,
x: ir.Value,
aval_out: core.AbstractValue,
sharding_proto: xc.OpSharding,
unspecified_dims: Optional[Set[int]] = None):
# unspecified_dims indicate dimensions whose shardings are not specified and
@@ -1376,23 +1387,23 @@ def _wrap_with_spmd_op(name: str,
[str(i) for i in sorted(unspecified_dims)]) + "]"
else:
backend_config = ""
op = hlo.CustomCallOp([result_type], [x],
call_target_name=ir.StringAttr.get(name),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(backend_config),
api_version=i32_attr(1),
called_computations=ir.ArrayAttr.get([]),
operand_layouts=None,
result_layouts=None)
result_type = aval_to_ir_type(aval_out)
out_shape = aval_out.shape # type: ignore
if core.is_constant_shape(out_shape):
result_shapes = None
else:
result_shapes = [shape_tensor(eval_dynamic_shape(ctx, out_shape))]
op = custom_call(name, [result_type], [x],
backend_config=backend_config,
has_side_effect=False,
api_version=1,
result_shapes=result_shapes)
set_sharding(op, sharding_proto)
return op.result
def wrap_with_sharding_op(x: ir.Value,
sharding_proto: xc.OpSharding,
unspecified_dims: Optional[Set[int]] = None):
return _wrap_with_spmd_op("Sharding", x.type, x, sharding_proto,
unspecified_dims)
wrap_with_sharding_op = partial(_wrap_with_spmd_op, "Sharding")
wrap_with_full_to_shard_op = partial(_wrap_with_spmd_op, "SPMDFullToShardShape")
wrap_with_shard_to_full_op = partial(_wrap_with_spmd_op, "SPMDShardToFullShape")
@@ -1787,3 +1798,39 @@ def build_xla_computation_helper(
return xc._xla.mlir.mlir_module_to_xla_computation(
module_to_string(lowering_result.module), use_tuple_args=False,
return_tuple=False)
def custom_call(
call_target_name: str,
out_types: Sequence[ir.Type],
operands: Sequence[ir.Value],
backend_config: Optional[str] = None,
has_side_effect: bool = False,
result_shapes: Optional[Sequence[ir.Value]] = None,
api_version: int = 2,
) -> ir.Operation:
"""Wraps a hlo.CustomCall
Args:
result_shapes: tensors that represent the result shapes, to be used when
the results have dynamic shapes. If not None, its length must match the
number of results.
"""
attributes = dict(
call_target_name=ir.StringAttr.get(call_target_name),
has_side_effect=ir.BoolAttr.get(has_side_effect),
backend_config=ir.StringAttr.get(
"" if backend_config is None else backend_config),
api_version=i32_attr(api_version),
called_computations=ir.ArrayAttr.get([]),
)
if result_shapes is not None:
# We add the result_shapes at the end of the operands, and must pass
# the indices_of_shape_operands attribute. This attribute is not yet
# accepted by the CustomCall constructor, so we use build_generic.
attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
np.asarray(list(range(len(operands), len(operands) + len(result_shapes))),
dtype=np.int64))
operands = list(operands) + list(result_shapes)
return hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)
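For orientation, the following consolidated sketch shows how a helper in this module
is expected to call custom_call once result shapes may be dynamic; it mirrors the body
of _wrap_with_spmd_op above and assumes ctx, x, and aval_out are in scope inside a
lowering helper.

# Sketch: static shapes pass result_shapes=None; dynamic shapes pass a shape
# tensor operand, which custom_call records via indices_of_shape_operands.
result_type = aval_to_ir_type(aval_out)
if core.is_constant_shape(aval_out.shape):
  result_shapes = None  # fully static result: no shape operand needed
else:
  # 1-D tensor holding the (possibly symbolic) output dimensions.
  result_shapes = [shape_tensor(eval_dynamic_shape(ctx, aval_out.shape))]
op = custom_call("Sharding", [result_type], [x],
                 backend_config="",
                 api_version=1,
                 result_shapes=result_shapes)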

@@ -1642,10 +1642,9 @@ def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
aval_out, = ctx.avals_out
sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto()
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
sx = mlir.wrap_with_sharding_op(x, sharding_proto, unspecified_dims=unspecified_dims)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto, unspecified_dims=unspecified_dims)
proto = manual_proto(aval_in, manual_axes, mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
return mlir.wrap_with_full_to_shard_op(result_type, sx, proto, unspecified_dims=unspecified_dims),
return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto, unspecified_dims=unspecified_dims),
shard_to_full_p = core.Primitive('shard_to_full')
@@ -1655,16 +1654,15 @@ def _shard_to_full_abstract_eval(x, axes, mesh, **_):
return untile_aval_nd(mesh.shape, axes, x)
@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapping, mesh: Mesh,
manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
proto = manual_proto(aval_in, manual_axes, mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
sx = mlir.wrap_with_sharding_op(x, proto, unspecified_dims=unspecified_dims)
proto = manual_proto(aval_in, manual_axes, mesh) # type: ignore
unspecified_dims = set(range(aval_in.ndim)) - set(axes.values()) # type: ignore
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto, unspecified_dims=unspecified_dims)
sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_out, axes).sharding_proto()
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims),
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto, unspecified_dims),
@lu.transformation
def vtile_manual(manual_axes: FrozenSet[sharding_impls.MeshAxisName],

@@ -1405,7 +1405,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
sharded_global_in_nodes = [
[mlir.wrap_with_sharding_op(node, global_sharding_spec(aval, aval_axes).sharding_proto())]
[mlir.wrap_with_sharding_op(ctx, node, aval, global_sharding_spec(aval, aval_axes).sharding_proto())]
if aval_axes else [node]
for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
]
@@ -1423,7 +1423,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
dim_var_values=ctx.dim_var_values)
sharded_global_out_nodes = [
mlir.wrap_with_sharding_op(node, global_sharding_spec(aval, aval_axes).sharding_proto())
mlir.wrap_with_sharding_op(ctx, node, aval, global_sharding_spec(aval, aval_axes).sharding_proto())
if aval_axes else node
for (node,), aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
]

@@ -1795,6 +1795,7 @@ ad.deflinear2(sharding_constraint_p,
def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
resource_env, unconstrained_dims):
aval, = ctx.avals_in
out_aval, = ctx.avals_out
axis_ctx = ctx.module_context.axis_context
# axis_ctx and manual_axes are *only used with xmap*, and xmap only works with
# NamedSharding. So convert the GSPMDSharding to NamedSharding
@@ -1806,8 +1807,8 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
sharding = GSPMDSharding(
mps._device_assignment, mps._to_xla_op_sharding(aval.ndim, axis_ctx=axis_ctx))
return [
mlir.wrap_with_sharding_op(
x_node,
mlir.wrap_with_sharding_op(ctx,
x_node, out_aval,
sharding._to_xla_op_sharding(aval.ndim),
unspecified_dims=unconstrained_dims)
]

@@ -187,7 +187,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
for out_shardings in ("missing", None, "P")
)
@jtu.with_mesh([("x", 2)])
def test_pjit_basic(self, in_shardings=None, out_shardings="missing"):
def test_pjit_basic(self, in_shardings="P", out_shardings="P"):
# Ensure that we can distinguish the inputs and outputs by shape
def f_jax(x): # f32[10,20] -> f32[20,10]
return jnp.sin(x.T)
@@ -317,17 +317,15 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(res_tf, res_jax)
@parameterized.named_parameters(
dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint=}_poly={poly}",
dict(testcase_name=f"_nested_pjit={nested_pjit}_constraint={constraint}_poly={poly}",
nested_pjit=nested_pjit, constraint=constraint, poly=poly)
# We add a constraint either with a nested pjit or with a sharding_constraint
for nested_pjit in (True, False)
for constraint in (None, "P")
for poly in (None, "b1,_", "_,b2", "b1,b2")
for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
)
@jtu.with_mesh([("x", 2)])
def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly=None):
if poly is not None:
raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement")
def test_pjit_sharding_constraint(self, nested_pjit=True, constraint="P", poly="2*b1,b2"):
constraint_sharding = P("x", None) if constraint == "P" else None
@partial(pjit.pjit, in_shardings=None,
out_shardings=None)
@@ -697,13 +695,11 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
@parameterized.named_parameters(
dict(testcase_name=f"_poly={poly}", poly=poly)
for poly in (None, "b1,_", "_,b2", "b1,b2")
for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
)
def test_shmap_collective_permute(self, poly=None):
if jtu.device_under_test() == "cpu":
raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")
if poly is not None:
raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement")
mesh = Mesh(self.devices, axis_names=('x'))
a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))

@@ -479,7 +479,7 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
check_rep):
del check_rep
sharded_avals = [v.aval for v in jaxpr.invars]
in_nodes_ = map(partial(_xla_shard, mesh), in_names, ctx.avals_in,
in_nodes_ = map(partial(_xla_shard, ctx, mesh), in_names, ctx.avals_in,
sharded_avals, in_nodes)
new_axis_context = sharding_impls.SPMDAxisContext(
mesh, frozenset(mesh.axis_names)
@@ -490,34 +490,34 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
(), *in_nodes_,
dim_var_values=ctx.dim_var_values)
sharded_avals = [v.aval for v in jaxpr.outvars]
return map(partial(_xla_unshard, mesh), out_names, sharded_avals,
return map(partial(_xla_unshard, ctx, mesh), out_names, sharded_avals,
ctx.avals_out, out_nodes_)
mlir.register_lowering(shard_map_p, _shard_map_lowering)
def _xla_shard(mesh, names, aval_in, aval_out, x):
def _xla_shard(ctx: mlir.LoweringRuleContext,
mesh, names, aval_in, aval_out, x):
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
axes = {name: i for i, ns in names.items() for name in ns}
shard_proto = NamedSharding(
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_op_sharding(aval_in.ndim)
if core.is_opaque_dtype(aval_in.dtype):
shard_proto = aval_in.dtype._rules.physical_op_sharding(aval_in, shard_proto)
sx = mlir.wrap_with_sharding_op(x, shard_proto, unspecified_dims=set())
return [mlir.wrap_with_full_to_shard_op(result_type, sx, manual_proto, set())]
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, unspecified_dims=set())
return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, set())]
def _xla_unshard(mesh, names, aval_in, aval_out, xs):
def _xla_unshard(ctx: mlir.LoweringRuleContext,
mesh, names, aval_in, aval_out, xs):
x, = xs
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set())
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=set())
axes = {name: i for i, ns in names.items() for name in ns}
shard_proto = NamedSharding(
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_op_sharding(aval_out.ndim)
if core.is_opaque_dtype(aval_out.dtype):
shard_proto = aval_out.dtype._rules.physical_op_sharding(aval_out, shard_proto)
return mlir.wrap_with_shard_to_full_op(result_type, sx, shard_proto, set())
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto, set())
# Eager evaluation