Mirror of https://github.com/ROCm/jax.git
Don't depend on mesh for UNSPECIFIED. Use OpShardingSharding for that, since it's now available and pjit accepts it.

PiperOrigin-RevId: 465641117
Commit 4b6d4a4ef7 (parent cc4bd0f283)
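In effect, a pjit whose out_axis_resources is left unspecified now returns outputs carrying an OpShardingSharding rebuilt from the compiled executable, with no mesh involved. Below is a rough sketch of that behaviour, modelled on the test added at the bottom of this diff; `create_array` is assumed to be the test file's array-construction helper (not shown in this diff), and the experimental jax.Array mode of this era is assumed to be enabled.

from jax._src import test_util as jtu
from jax.experimental import PartitionSpec as P
from jax.experimental.pjit import pjit
from jax.experimental.sharding import OpShardingSharding

global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
# create_array (assumption) builds a jax.Array sharded as P('x', 'y').
input_array, input_data = create_array((8, 2), global_mesh, P('x', 'y'))

f = pjit(lambda x: x)  # out_axis_resources left UNSPECIFIED
out = f(input_array)

# The output sharding no longer depends on a mesh:
assert isinstance(out.sharding, OpShardingSharding)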
@@ -849,12 +849,9 @@ def _pjit_lower(
       for o in out_shardings
   )
-  # TODO(yashkatariya): UNSPECIFIED should go through lower_sharding_computation.
-  # Also the `jaxpr_has_primitive` for xmap is temporary until xmap supports
-  # sharding instances.
-  if (pxla._check_if_any_auto_or_unspecified(in_shardings + out_shardings) or
+  # For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
+  # because `xmap` only supports SPMDAxisContext right now.
+  if (pxla._check_if_any_auto(it.chain(in_shardings, out_shardings)) or
       dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap')):
     return pxla.lower_mesh_computation(
         fun, 'pjit', name, resource_env.physical_mesh,
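The dispatch effect of this hunk: only AUTO shardings (or an xmap body) still force the mesh-based lower_mesh_computation path; UNSPECIFIED shardings now fall through to lower_sharding_computation. A self-contained sketch of that decision, using stand-in sentinels rather than the real pxla internals:

import itertools as it

class _Auto: pass          # stand-in for pxla's AUTO sentinel
class _Unspecified: pass   # stand-in for pxla's UNSPECIFIED sentinel

def pick_lowering_path(in_shardings, out_shardings, has_xmap):
  # After this commit, only AUTO (or xmap) selects the mesh path.
  if (any(isinstance(s, _Auto) for s in it.chain(in_shardings, out_shardings))
      or has_xmap):
    return 'lower_mesh_computation'
  return 'lower_sharding_computation'

# UNSPECIFIED outputs no longer force the mesh path:
assert (pick_lowering_path(['s0'], [_Unspecified()], has_xmap=False)
        == 'lower_sharding_computation')
assert (pick_lowering_path([_Auto()], ['s0'], has_xmap=False)
        == 'lower_mesh_computation')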
@@ -1636,7 +1633,7 @@ def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartit
   return in_ppspec, out_ppspec
 
 
-def _get_sharding_from_executable(
+def _get_pspec_from_executable(
     executable, mesh: pxla.Mesh
 ) -> Tuple[Tuple[PartitionSpec, ...], Tuple[PartitionSpec, ...]]:
   in_ppspec, out_ppspec = _get_ppspec_from_executable(executable, mesh)
@@ -2280,20 +2280,12 @@ class TileManual:
 TilingMethod = Union[TileVectorize, TileManual]
 
 
-def _check_if_any_auto(shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource]]) -> bool:
+def _check_if_any_auto(shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource]]) -> bool:
   for s in shardings:
     if _is_auto(s):
       return True
   return False
-
-# TODO(yashkatariya): Remove this once UNSPECIFIED can be used without mesh.
-def _check_if_any_auto_or_unspecified(
-    shardings: Sequence[Union[XLACompatibleSharding, _AUTOAxisResource, _UnspecifiedValue]]) -> bool:
-  for s in shardings:
-    if _is_auto(s) or _is_unspecified(s):
-      return True
-  return False
 
 
 class _UnconstrainedPartitionSingleton:
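Widening the parameter from Sequence to Iterable is what lets call sites pass it.chain(in_shardings, out_shardings) lazily instead of materialising a concatenated tuple; the loop also exits on the first AUTO it sees. A minimal illustration with a stand-in sentinel:

import itertools as it

class _Auto: pass  # stand-in for _AUTOAxisResource

def check_if_any_auto(shardings):  # any iterable; consumed at most once
  for s in shardings:
    if isinstance(s, _Auto):
      return True
  return False

ins, outs = ('a', 'b'), (_Auto(),)
assert check_if_any_auto(it.chain(ins, outs))  # no ins + outs tuple is built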
@@ -2329,10 +2321,17 @@ class PartitionSpec(tuple):
 
 
 def _get_backend_from_shardings(
-    shardings: Sequence[XLACompatibleSharding]) -> Tuple[xb.XlaBackend, XLACompatibleSharding]:
-  device_set = shardings[0]._device_assignment
-  assert len(device_set) > 0
-  return xb.get_device_backend(device_set[0]), shardings[0]
+    shardings: Iterable[XLACompatibleSharding]) -> Tuple[xb.XlaBackend, XLACompatibleSharding]:
+  da = None
+  first_sharding = None
+  for s in shardings:
+    if _is_unspecified(s):
+      continue
+    da = s._device_assignment
+    first_sharding = s
+    break
+  assert len(da) > 0  # type: ignore
+  return xb.get_device_backend(da[0]), first_sharding  # type: ignore
 
 
 @profiler.annotate_function
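Because unspecified outputs now reach this code, the backend lookup can no longer blindly read shardings[0]; it scans for the first concrete sharding and takes the device assignment from it. A small sketch of that scan (types here are stand-ins, not the real jax classes):

class _Unspecified: pass
UNSPECIFIED = _Unspecified()

class Sharding:  # stand-in for XLACompatibleSharding
  def __init__(self, device_assignment):
    self._device_assignment = device_assignment

def first_concrete_sharding(shardings):
  # Skip UNSPECIFIED entries; the first concrete sharding supplies the
  # device assignment, and device_assignment[0] then picks the backend.
  for s in shardings:
    if isinstance(s, _Unspecified):
      continue
    assert len(s._device_assignment) > 0
    return s
  raise AssertionError('expected at least one concrete sharding')

s = first_concrete_sharding([UNSPECIFIED, Sharding(('tpu:0', 'tpu:1'))])
assert s._device_assignment == ('tpu:0', 'tpu:1')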
@@ -2348,7 +2347,7 @@ def lower_sharding_computation(
   # Device assignment across all inputs and outputs should be the same. This
   # is checked in pjit.
   backend, first_sharding = _get_backend_from_shardings(
-      in_shardings + out_shardings)  # type: ignore
+      it.chain(in_shardings, out_shardings))  # type: ignore
   name_stack = new_name_stack(wrap_name(fun_name, api_name))
 
   log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
@@ -2383,7 +2382,7 @@ def lower_sharding_computation(
                      for aval, i in safe_zip(global_in_avals, in_shardings)]
   # TODO(yashkatariya): Fix the HLO produced if out_partitions is
   # [None, OpShardingProto] has the sharding annotations.
-  out_op_shardings = [o._to_xla_op_sharding(aval.ndim)
+  out_op_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
                       for aval, o in safe_zip(global_out_avals, out_shardings)]
   replicated_args = [False] * len(in_jaxpr_avals)
   axis_ctx = mlir.ShardingContext(first_sharding)
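On this path an unspecified output simply gets no OpSharding annotation (None), leaving the compiler free to choose its layout, while specified outputs are converted exactly as before. A stand-in sketch of the comprehension's behaviour:

class _Unspecified: pass

def build_out_op_shardings(out_shardings, out_ndims, to_op_sharding):
  # None means "no sharding annotation for this output" in the lowered module.
  return [None if isinstance(o, _Unspecified) else to_op_sharding(o, ndim)
          for ndim, o in zip(out_ndims, out_shardings)]

res = build_out_op_shardings(
    [_Unspecified(), 'row_sharded'], [2, 2],
    to_op_sharding=lambda o, ndim: ('proto', o, ndim))
assert res == [None, ('proto', 'row_sharded', 2)]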
@@ -2631,11 +2630,22 @@ def _get_input_metadata(
   return shardings, input_indices, input_avals
 
 
-def _get_shardings_from_executable(xla_executable, mesh):
+def _get_op_sharding_shardings_from_executable(xla_executable, device_assignment):
+  from jax.experimental import pjit
+  from jax.experimental.sharding import OpShardingSharding
+
+  in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
+  return ([OpShardingSharding(device_assignment, i) for i in in_op_shardings],
+          [OpShardingSharding(device_assignment, o) for o in out_op_shardings])
+
+
+# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
+# without mesh.
+def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
   from jax.experimental import pjit
   from jax.experimental.sharding import MeshPspecSharding
 
-  in_pspec, out_pspec = pjit._get_sharding_from_executable(xla_executable, mesh)
+  in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
   return ([MeshPspecSharding(mesh, i) for i in in_pspec],
           [MeshPspecSharding(mesh, o) for o in out_pspec])
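The split into two helpers reflects what each reconstruction needs: an OpShardingSharding is rebuilt from the executable's OpSharding protos plus a flat device assignment alone (obtainable from any concrete input sharding, per _get_backend_from_shardings above), whereas the AUTO path still turns PartitionSpecs into MeshPspecShardings and therefore keeps its mesh argument. Schematic stand-ins for the two constructions:

class OpShardingSharding:   # stand-in: needs only devices + a proto
  def __init__(self, devices, op_sharding):
    self._devices, self._op_sharding = devices, op_sharding

class MeshPspecSharding:    # stand-in: needs a full mesh
  def __init__(self, mesh, pspec):
    self._mesh, self._pspec = mesh, pspec

def from_op_shardings(op_shardings, device_assignment):
  return [OpShardingSharding(device_assignment, p) for p in op_shardings]

def from_pspecs(pspecs, mesh):
  return [MeshPspecSharding(mesh, p) for p in pspecs]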
@@ -2658,10 +2668,8 @@ class MeshExecutable(stages.XlaExecutable):
   @staticmethod
   def from_hlo(name: str,
                computation: Union[ir.Module, xc.XlaComputation],
-               # mesh only needs to be set if in_shardings and out_shardings
-               # contain AUTO or UNSPECIFIED (unspecified is temporary here).
-               # TODO(yashkatariya): Remove `mesh` from here once AUTO and
-               # UNSPECIFIED work without mesh.
+               # TODO(yashkatariya): Remove `mesh` from here once AUTO can work
+               # without mesh.
                mesh: Optional[Mesh],
                global_in_avals: Sequence[ShapedArray],
                global_out_avals: Sequence[ShapedArray],
@@ -2677,20 +2685,16 @@ class MeshExecutable(stages.XlaExecutable):
                unordered_effects: List[core.Effect],
                host_callbacks: List[Any],
                keepalive: Any) -> MeshExecutable:
-    auto_or_unspecified = (
-        auto_spmd_lowering or
-        (out_shardings and all(_is_unspecified(o) for o in out_shardings)))
-
-    if auto_or_unspecified:
+    if auto_spmd_lowering:
       assert mesh is not None
       assert not mesh.empty
       backend = xb.get_device_backend(mesh.devices.flat[0])
     else:
       backend, first_sharding = _get_backend_from_shardings(
-          in_shardings + out_shardings)  # type: ignore
+          it.chain(in_shardings, out_shardings))  # type: ignore
 
     dev: np.ndarray
-    if auto_or_unspecified:
+    if auto_spmd_lowering:
       assert mesh is not None and spmd_lowering
       dev = mesh.devices
       num_replicas, num_partitions = 1, mesh.size
@@ -2735,12 +2739,14 @@ class MeshExecutable(stages.XlaExecutable):
     xla_executable = dispatch.compile_or_get_cached(
         backend, computation, compile_options, host_callbacks)
 
-    if auto_or_unspecified:
-      # TODO(yashkatariya): Make this work for UNSPECIFIED without mesh by
-      # returning `OpShardingSharding`.
-      in_shardings, out_shardings = _get_shardings_from_executable(
+    if auto_spmd_lowering:
+      assert mesh is not None
+      in_shardings, out_shardings = _get_mesh_pspec_shardings_from_executable(
           xla_executable, mesh)
+    elif out_shardings and all(_is_unspecified(o) for o in out_shardings):
+      assert mesh is None
+      in_shardings, out_shardings = _get_op_sharding_shardings_from_executable(
+          xla_executable, first_sharding._device_assignment)
 
     in_shardings, input_indices, input_avals = _get_input_metadata(
         global_in_avals, in_shardings, in_is_global)  # type: ignore
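After compilation, from_hlo now distinguishes the two reconstruction cases explicitly: AUTO lowering (mesh required) recovers MeshPspecShardings, while the all-UNSPECIFIED-outputs case (no mesh) recovers OpShardingShardings via the device assignment captured from the first concrete input sharding earlier in the function. A condensed sketch of the branch, with the two helpers passed in as callables:

def reconstruct_shardings(auto_spmd_lowering, outs_unspecified, mesh,
                          from_mesh, from_device_assignment):
  # Condensed from the hunk above; from_mesh / from_device_assignment stand
  # in for _get_mesh_pspec_shardings_from_executable and
  # _get_op_sharding_shardings_from_executable respectively.
  if auto_spmd_lowering:
    assert mesh is not None
    return from_mesh(mesh)
  elif outs_unspecified:
    assert mesh is None
    return from_device_assignment()
  return None  # fully specified: nothing to reconstruct

out = reconstruct_shardings(
    False, True, None,
    from_mesh=lambda m: 'mesh_pspec_shardings',
    from_device_assignment=lambda: 'op_sharding_shardings')
assert out == 'op_sharding_shardings'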
@@ -1499,6 +1499,17 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertArraysEqual(out._value, input_data)
 
   def test_unspecified_out_axis_resources(self):
+
+    def _checks(out, input_data):
+      self.assertIsInstance(out, array.Array)
+      self.assertIsInstance(out.sharding, OpShardingSharding)
+      self.assertEqual(out.shape, (8, 2))
+      self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
+      for s in out.addressable_shards:
+        self.assertLen(s.data._arrays, 1)
+        self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+      self.assertArraysEqual(out._value, input_data)
+
     global_input_shape = (8, 2)
     global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
     mesh_axes = P('x', 'y')
@@ -1510,13 +1521,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
     f = pjit(lambda x: x)
 
     out = f(input_array)
-    self.assertIsInstance(out, array.Array)
-    self.assertEqual(out.shape, (8, 2))
-    self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
-    for s in out.addressable_shards:
-      self.assertLen(s.data._arrays, 1)
-      self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
-    self.assertArraysEqual(out._value, input_data)
+    _checks(out, input_data)
+
+    out2 = f(out)
+    _checks(out2, input_data)
 
   @parameterized.named_parameters(
       ('mesh1', (4, 2), (2, 1), (2, 2), (1, 2), (8, 2)),