Mirror of https://github.com/ROCm/jax.git

Delete mesh_sharding_specs from JAX

PiperOrigin-RevId: 595164505

commit 5192fca09b
parent 697f17adf1
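The change swaps every use of the mesh_sharding_specs helper for a NamedSharding lowered straight to an HloSharding proto. A minimal sketch of the new pattern, assuming a single-device (1, 1) mesh so it runs on any backend; _to_xla_hlo_sharding is the same private helper the diff itself calls, not public API:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Trivial 1x1 mesh built from the first available device.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))

# Before: mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, axes)
#             .sharding_proto().to_proto()
# After: construct a NamedSharding and lower it for a rank-2 value.
proto = NamedSharding(mesh, P('x', 'y'))._to_xla_hlo_sharding(2).to_proto()
print(proto)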
@@ -69,7 +69,7 @@ from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified,
     AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
-    is_unspecified, is_unspecified_or_auto
+    is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources
 )
 from jax._src.util import (safe_map, safe_zip, partition_list,
                            wrap_name, tuple_delete, distributed_debug_log,
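The one new import, array_mapping_to_axis_resources, converts an ArrayMapping (mesh axis name -> array dimension) back into the axis resources a NamedSharding accepts. A hedged round-trip sketch using private jax._src modules, which may shift between JAX versions:

from jax.sharding import PartitionSpec as P
from jax._src.interpreters import pxla
from jax._src.sharding_impls import array_mapping_to_axis_resources

# P('x', 'y') shards dim 0 over mesh axis 'x' and dim 1 over 'y'.
mapping = pxla.get_array_mapping(P('x', 'y'))  # e.g. {'x': 0, 'y': 1}
print(array_mapping_to_axis_resources(mapping))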
@@ -1510,12 +1510,15 @@ def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
   # TODO: Can we short-circuit for replicated values? Probably not.
   aval_in, = ctx.avals_in
   aval_out, = ctx.avals_out
-  sharding_proto = mesh_sharding_specs(
-      mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto().to_proto()
+  sharding_proto = (
+      sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
+      ._to_xla_hlo_sharding(aval_in.ndim).to_proto())
   unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto, unspecified_dims=unspecified_dims)
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, sharding_proto,
+                                  unspecified_dims=unspecified_dims)
   proto = manual_proto(aval_in, manual_axes, mesh)
-  return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto, unspecified_dims=unspecified_dims),
+  return (mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, proto,
+                                          unspecified_dims=unspecified_dims),)

 shard_to_full_p = core.Primitive('shard_to_full')

@@ -1531,10 +1534,13 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp
   aval_out, = ctx.avals_out
   proto = manual_proto(aval_in, manual_axes, mesh)  # type: ignore
   unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())  # type: ignore
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto, unspecified_dims=unspecified_dims)
-  sharding_proto = mesh_sharding_specs(
-      mesh.shape, mesh.axis_names)(aval_out, axes).sharding_proto().to_proto()
-  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto, unspecified_dims),
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, proto,
+                                  unspecified_dims=unspecified_dims)
+  sharding_proto = (
+      sharding_impls.NamedSharding(mesh, array_mapping_to_axis_resources(axes))
+      ._to_xla_hlo_sharding(aval_out.ndim).to_proto())
+  return (mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, sharding_proto,
+                                          unspecified_dims),)

 @lu.transformation
 def vtile_manual(manual_axes: frozenset[sharding_impls.MeshAxisName],
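Both lowering rules above leave any dimension not named in the ArrayMapping unspecified, so the partitioner stays free to choose a layout for it. An illustrative plain-Python restatement of that computation (the names here are local, not JAX API):

axes = {'x': 0}          # ArrayMapping: mesh axis 'x' shards array dim 0
ndim = 3                 # rank of the value being lowered
unspecified_dims = set(range(ndim)) - set(axes.values())
print(unspecified_dims)  # {1, 2}: left for the partitioner to decide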
@@ -3068,30 +3074,6 @@ def resource_typecheck(jaxpr, resource_env, axis_resources, what_jaxpr_thunk):
       _check_aval(v.aval, what_thunk)


-def mesh_sharding_specs(axis_sizes, axis_names, allow_uneven_axes=False):
-  mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
-  # NOTE: This takes in the non-sharded avals!
-  def mk_sharding_spec(aval, aval_axes):
-    if aval is core.abstract_token:
-      assert not aval_axes
-      return ShardingSpec([], [Replicated(axis_size) for axis_size in axis_sizes.values()])
-    aval_shape = list(aval.shape)
-    # NOTE: sorted is stable, which is important when multiple resources
-    # map to the same axis.
-    for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
-      if not allow_uneven_axes:
-        if aval_shape[axis] % axis_sizes[name] != 0:
-          raise ValueError(
-              f'The aval shape on dimension {axis} is {aval_shape[axis]} and '
-              f'the size of axis {name} is {axis_sizes[name]}. The aval shape % '
-              'axis size should be zero but got '
-              f'{aval_shape[axis] % axis_sizes[name]}')
-      aval_shape[axis] //= axis_sizes[name]
-    return sharding_specs.make_sharding_spec(
-        axis_sizes, mesh_axis_pos, len(aval.shape), aval_axes)
-  return mk_sharding_spec
-
-
 @contextmanager
 def maybe_extend_axis_env(*args, **kwargs):
   with core.extend_axis_env(*args, **kwargs):
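With allow_uneven_axes=False, the deleted helper rejected shapes that a mesh axis does not divide evenly. A standalone restatement of just that check, with illustrative names that are not JAX API:

def check_divisible(aval_shape, axis_sizes, aval_axes):
  # Same per-dimension validation the deleted mk_sharding_spec performed.
  for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
    if aval_shape[axis] % axis_sizes[name] != 0:
      raise ValueError(
          f'The aval shape on dimension {axis} is {aval_shape[axis]} and '
          f'the size of axis {name} is {axis_sizes[name]}. The aval shape '
          '% axis size should be zero but got '
          f'{aval_shape[axis] % axis_sizes[name]}')

check_divisible([8, 4], {'x': 4, 'y': 2}, {'x': 0, 'y': 1})    # passes
# check_divisible([1, 1], {'x': 4, 'y': 2}, {'x': 0, 'y': 1})  # raises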
@@ -1417,11 +1417,11 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
   global_in_avals = ctx.avals_in
   vectorized_jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(f, global_in_avals)

-  global_sharding_spec = pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)
   sharded_global_in_nodes = [
     [mlir.wrap_with_sharding_op(
         ctx, node, aval,
-        global_sharding_spec(aval, aval_axes).sharding_proto().to_proto())]
+        NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes)
+                      )._to_xla_hlo_sharding(aval.ndim).to_proto())]
     if aval_axes else [node]
     for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
   ]
@@ -1441,7 +1441,8 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
   sharded_global_out_nodes = [
     mlir.wrap_with_sharding_op(
         ctx, node, aval,
-        global_sharding_spec(aval, aval_axes).sharding_proto().to_proto())
+        NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes)
+                      )._to_xla_hlo_sharding(aval.ndim).to_proto())
     if aval_axes else node
     for (node,), aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
   ]
@@ -251,24 +251,6 @@ def spec_to_indices(shape: Sequence[int],
   return tuple(spec.indices(shape).flat)  # type: ignore


-def make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
-  mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
-  sharding = [_UNSHARDED_INSTANCE] * num_dimensions
-  next_sharded_axis = 0
-  # NOTE: sorted is stable, which is important when multiple resources
-  # map to the same axis.
-  for name, axis in sorted(aval_axes.items(), key=lambda x: x[1]):
-    chunked = sharding[axis]
-    if isinstance(chunked, NoSharding):
-      chunked = Chunked([])
-    sharding[axis] = Chunked(list(chunked.chunks) + [axis_sizes[name]])
-    assert isinstance(mesh_mapping[mesh_axis_pos[name]], Replicated), \
-        "Value mapped to the same mesh axis twice"
-    mesh_mapping[mesh_axis_pos[name]] = ShardedAxis(next_sharded_axis)
-    next_sharded_axis += 1
-  return ShardingSpec(sharding, mesh_mapping)
-
-
 def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
                        map_axis: int | None) -> ShardingSpec:
   """Sharding spec for arguments or results of a pmap.
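For reference, on a 4x2 mesh with a rank-2 value fully mapped onto ('x', 'y'), the deleted function returned a spec equivalent to the one built below. ShardingSpec and friends live in the private module jax._src.sharding_specs, so this is a sketch, not stable API:

from jax._src.sharding_specs import Chunked, ShardedAxis, ShardingSpec

# make_sharding_spec({'x': 4, 'y': 2}, {'x': 0, 'y': 1}, 2, {'x': 0, 'y': 1})
# produced: dim 0 split into 4 chunks, dim 1 into 2, with each mesh axis
# mapped to a sharded array axis rather than replicated.
spec = ShardingSpec((Chunked([4]), Chunked([2])),
                    (ShardedAxis(0), ShardedAxis(1)))
print(spec)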
@@ -22,7 +22,6 @@ from jax._src.interpreters.pxla import (
   global_aval_to_result_handler as global_aval_to_result_handler,
   global_avals_to_results_handler as global_avals_to_results_handler,
   global_result_handlers as global_result_handlers,
-  mesh_sharding_specs as mesh_sharding_specs,
   parallel_callable as parallel_callable,
   shard_arg as shard_arg,
   shard_args as shard_args,
@@ -4281,17 +4281,6 @@ class UtilTest(jtu.JaxTestCase):
     self.assertTrue(all(i == (slice(None),) * aval.ndim
                         for out, aval in safe_zip(out_indices, in_avals) for i in out))

-  def test_mesh_sharding_spec(self):
-    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
-    array_mapping = pxla.get_array_mapping(P('x', 'y'))
-    aval = core.ShapedArray((1, 1), jnp.int32)
-    with self.assertRaisesRegex(
-        ValueError,
-        'The aval shape on dimension 0 is 1 and the size of axis x is 4. The '
-        'aval shape % axis size should be zero but got 1'
-    ):
-      pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, array_mapping)
-
   @parameterized.named_parameters(
     ("all_unspecified", (UNSPECIFIED, UNSPECIFIED), AssertionError),
     ("only_unspecified", UNSPECIFIED),
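The deleted test exercised the eager divisibility error restated earlier. NamedSharding has no direct equivalent: it is constructed without reference to any array shape, so a mismatch can only surface later, when the sharding is applied to a concrete value. A hedged sketch of that shape-agnostic construction:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))  # constructs fine for any array shape
print(s)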