diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index a691f0cdc..5da12f101 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -295,8 +295,13 @@ def pjit(fun: Callable,
     resource_env = pxla.thread_resources.env
     pjit_mesh = resource_env.physical_mesh
     if pjit_mesh.empty:
-      raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
-                         "it's defined at the call site?")
+      if config.jax_array:
+        # Don't enforce requiring a mesh when `jax_array` flag is enabled. But
+        # if mesh is not empty then pjit will respect it.
+        pass
+      else:
+        raise RuntimeError("pjit requires a non-empty mesh! Are you sure that "
+                           "it's defined at the call site?")

     f = lu.wrap_init(fun)
     f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False)
@@ -955,6 +960,7 @@ def _pjit_batcher_for_sharding(
     return OpShardingSharding(s._device_assignment, new_op)
   else:
     assert isinstance(s, OpShardingSharding)
+    assert not mesh.empty
     parsed_pspec = parse_flatten_op_sharding(s._op_sharding, mesh)[0]
     parsed_pspec = parsed_pspec.insert_axis_partitions(dim, val)
     mps = MeshPspecSharding._from_parsed_pspec(mesh, parsed_pspec)
@@ -1613,14 +1619,18 @@ def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[Parti

 def _get_op_sharding_from_executable(
     executable) -> Tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
-  input_op_shardings: List[xc.OpSharding] = []
-  for s in executable.hlo_modules()[0].spmd_parameters_shardings:
-    input_op_shardings.extend(_get_op_sharding(s))
+  in_op_shardings: List[xc.OpSharding] = []
+  parameter_shardings_from_xla = executable.hlo_modules()[0].spmd_parameters_shardings
+  if parameter_shardings_from_xla is not None:
+    for s in parameter_shardings_from_xla:
+      in_op_shardings.extend(_get_op_sharding(s))

-  output_op_shardings: Sequence[xc.OpSharding] = _get_op_sharding(
-      executable.hlo_modules()[0].spmd_output_sharding)
+  out_op_shardings: List[xc.OpSharding] = []
+  output_shardings_from_xla = executable.hlo_modules()[0].spmd_output_sharding
+  if output_shardings_from_xla is not None:
+    out_op_shardings = _get_op_sharding(output_shardings_from_xla)  # type: ignore

-  return input_op_shardings, output_op_shardings
+  return in_op_shardings, out_op_shardings


 def _get_ppspec_from_executable(executable, mesh) -> Tuple[Sequence[ParsedPartitionSpec], Sequence[ParsedPartitionSpec]]:
diff --git a/jax/experimental/sharding.py b/jax/experimental/sharding.py
index 4219b1d07..1ef6f551d 100644
--- a/jax/experimental/sharding.py
+++ b/jax/experimental/sharding.py
@@ -205,6 +205,13 @@ class MeshPspecSharding(XLACompatibleSharding):
     return sharding_spec.sharding_proto(special_axes=special_axes)


+@functools.lru_cache()
+def _get_replicated_op_sharding():
+  proto = xc.OpSharding()
+  proto.type = xc.OpSharding.Type.REPLICATED
+  return proto
+
+
 class SingleDeviceSharding(XLACompatibleSharding):

   def __init__(self, device: Device):
@@ -237,9 +244,7 @@ class SingleDeviceSharding(XLACompatibleSharding):
     return [self._device]

   def _to_xla_op_sharding(self, num_dimensions: int) -> Optional[xc.OpSharding]:
-    proto = xc.OpSharding()
-    proto.type = xc.OpSharding.Type.REPLICATED
-    return proto
+    return _get_replicated_op_sharding()


 class PmapSharding(XLACompatibleSharding):
@@ -338,6 +343,5 @@ class OpShardingSharding(XLACompatibleSharding):

   @classmethod
   def get_replicated(cls, device_assignment):
-    proto = xc.OpSharding()
-    proto.type = xc.OpSharding.Type.REPLICATED
+    proto = _get_replicated_op_sharding()
     return cls(device_assignment, proto)
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 0f6afb53f..656a4f6cc 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -2630,11 +2630,21 @@ def _get_input_metadata(
   return shardings, input_indices, input_avals


-def _get_op_sharding_shardings_from_executable(xla_executable, device_assignment):
+def _get_op_sharding_shardings_from_executable(
+    xla_executable, device_assignment, num_in_avals, num_out_avals):
   from jax.experimental import pjit
-  from jax.experimental.sharding import OpShardingSharding
+  from jax.experimental.sharding import OpShardingSharding, SingleDeviceSharding

   in_op_shardings, out_op_shardings = pjit._get_op_sharding_from_executable(xla_executable)
+
+  # When the device assignment only has 1 device, SPMD partitioner will not run.
+  # Hence the op shardings will not be set on the `hlo_module`. In that case,
+  # just return SingleDeviceShardings since we know the computation is running
+  # only on 1 device.
+  if not in_op_shardings and not out_op_shardings and len(device_assignment) == 1:
+    return ([SingleDeviceSharding(device_assignment[0]) for _ in range(num_in_avals)],
+            [SingleDeviceSharding(device_assignment[0]) for _ in range(num_out_avals)])
+
   return ([OpShardingSharding(device_assignment, i) for i in in_op_shardings],
           [OpShardingSharding(device_assignment, o) for o in out_op_shardings])

@@ -2746,7 +2756,8 @@ class MeshExecutable(stages.XlaExecutable):
     elif out_shardings and all(_is_unspecified(o) for o in out_shardings):
       assert mesh is None
       in_shardings, out_shardings = _get_op_sharding_shardings_from_executable(
-          xla_executable, first_sharding._device_assignment)
+          xla_executable, first_sharding._device_assignment,
+          len(global_in_avals), len(global_out_avals))

     in_shardings, input_indices, input_avals = _get_input_metadata(
         global_in_avals, in_shardings, in_is_global)  # type: ignore
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 96d3a8d8a..4cc58b0bc 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1443,6 +1443,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
       ('fully_sharded_output', P('x', 'y'), (2, 4)),
       ('fully_replicated_output', P(None), (8, 8)),
   )
+  @jax._src.config.jax_array(True)
   def test_pjit_array_single_output(self, out_axis_resources, shard_shape):
     global_input_shape = (8, 2)
     global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
@@ -1450,20 +1451,45 @@ class ArrayPjitTest(jtu.JaxTestCase):

     input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)

-    with jax._src.config.jax_array(True):
-      with global_mesh:
-        f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
-            global_mesh, out_axis_resources))
-        expected_matrix_mul = input_data @ input_data.T
+    f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
+        global_mesh, out_axis_resources))
+    expected_matrix_mul = input_data @ input_data.T

-        out = f(input_array)
-        self.assertIsInstance(out, array.Array)
-        self.assertEqual(out.shape, (8, 8))
-        self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
-        for s in out.addressable_shards:
-          self.assertLen(s.data._arrays, 1)
-          self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
-        self.assertArraysEqual(out._value, expected_matrix_mul)
+    out = f(input_array)
+    self.assertIsInstance(out, array.Array)
+    self.assertEqual(out.shape, (8, 8))
+    self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
+    for s in out.addressable_shards:
+      self.assertLen(s.data._arrays, 1)
+      self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
+    self.assertArraysEqual(out._value, expected_matrix_mul)
+
+  @parameterized.named_parameters(
+      ('fully_sharded_output', P('x', 'y'), (2, 4)),
+      ('fully_replicated_output', P(None), (8, 8)),
+  )
+  @jax._src.config.jax_array(True)
+  def test_pjit_array_single_output_with_mesh_context_manager(
+      self, out_axis_resources, shard_shape):
+    global_input_shape = (8, 2)
+    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    mesh_axes = P('x', 'y')
+
+    input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)
+
+    with global_mesh:
+      f = pjit(lambda x: x @ x.T, out_axis_resources=MeshPspecSharding(
+          global_mesh, out_axis_resources))
+      expected_matrix_mul = input_data @ input_data.T
+
+      out = f(input_array)
+      self.assertIsInstance(out, array.Array)
+      self.assertEqual(out.shape, (8, 8))
+      self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
+      for s in out.addressable_shards:
+        self.assertLen(s.data._arrays, 1)
+        self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
+      self.assertArraysEqual(out._value, expected_matrix_mul)

   def test_non_array_input_error(self):
     input_shape = (8, 2)
@@ -1498,6 +1524,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
       self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
     self.assertArraysEqual(out._value, input_data)

+  @jax._src.config.jax_array(True)
   def test_unspecified_out_axis_resources(self):

     def _checks(out, input_data):
@@ -1516,21 +1543,20 @@ class ArrayPjitTest(jtu.JaxTestCase):

     input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)

-    with jax._src.config.jax_array(True):
-      with global_mesh:
-        f = pjit(lambda x: x)
+    f = pjit(lambda x: x)

-        out = f(input_array)
-        _checks(out, input_data)
+    out = f(input_array)
+    _checks(out, input_data)

-        out2 = f(out)
-        _checks(out2, input_data)
+    out2 = f(out)
+    _checks(out2, input_data)

   @parameterized.named_parameters(
       ('mesh1', (4, 2), (2, 1), (2, 2), (1, 2), (8, 2)),
       ('mesh2', (2, 2), (4, 1), (4, 2), (2, 2), (8, 2)),
       ('mesh3', (2, 1), (4, 2), (4, 2), (4, 2), (8, 2)),
   )
+  @jax._src.config.jax_array(True)
   def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
                                                s2_shape, s3_shape, s4_shape):
     # Disable on SE runtime type because XLA sharding propagation is not
@@ -1549,37 +1575,35 @@ class ArrayPjitTest(jtu.JaxTestCase):
     spec4 = P(None)
     a4, _ = create_array(global_input_shape, global_mesh, spec4)

-    with jax._src.config.jax_array(True):
-      with global_mesh:
-        @pjit
-        def f(tree):
-          return tree
-        out_tree = f((a1, (a2, (a3, a4))))
-        (out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)
+    @pjit
+    def f(tree):
+      return tree
+    out_tree = f((a1, (a2, (a3, a4))))
+    (out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)

-        self.assertIsInstance(out1, array.Array)
-        self.assertEqual(out1.shape, (8, 2))
-        self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
-        for s in out1.addressable_shards:
-          self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+    self.assertIsInstance(out1, array.Array)
+    self.assertEqual(out1.shape, (8, 2))
+    self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
+    for s in out1.addressable_shards:
+      self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

-        self.assertIsInstance(out2, array.Array)
-        self.assertEqual(out2.shape, (8, 2))
-        self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
-        for s in out2.addressable_shards:
-          self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+    self.assertIsInstance(out2, array.Array)
+    self.assertEqual(out2.shape, (8, 2))
+    self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
+    for s in out2.addressable_shards:
+      self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

-        self.assertIsInstance(out3, array.Array)
-        self.assertEqual(out3.shape, (8, 2))
-        self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
-        for s in out3.addressable_shards:
-          self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+    self.assertIsInstance(out3, array.Array)
+    self.assertEqual(out3.shape, (8, 2))
+    self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
+    for s in out3.addressable_shards:
+      self.assertArraysEqual(s.data._arrays[0], input_data[s.index])

-        self.assertIsInstance(out4, array.Array)
-        self.assertEqual(out4.shape, (8, 2))
-        self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
-        for s in out4.addressable_shards:
-          self.assertArraysEqual(s.data._arrays[0], input_data)
+    self.assertIsInstance(out4, array.Array)
+    self.assertEqual(out4.shape, (8, 2))
+    self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
+    for s in out4.addressable_shards:
+      self.assertArraysEqual(s.data._arrays[0], input_data)

   def test_in_axis_resources_mismatch_error(self):
     global_input_shape = (8, 2)
@@ -1736,6 +1760,34 @@ class ArrayPjitTest(jtu.JaxTestCase):
     compiled = f.lower(jax.ShapedArray(input_shape, jnp.float32)).compile()
     compiled(a1)  # no error

+  @jax._src.config.jax_array(True)
+  def test_pjit_single_device_sharding_add(self):
+    a = jnp.array([1, 2, 3], dtype=jnp.float32)
+    b = jnp.array([4, 5, 6], dtype=jnp.float32)
+
+    @pjit
+    def add(x, y):
+      return x + y
+    out = add(a, b)
+    self.assertIsInstance(out, array.Array)
+    self.assertArraysEqual(out, a + b)
+
+    out2 = add(out, out)
+    self.assertIsInstance(out2, array.Array)
+    self.assertArraysEqual(out2, 2 * (a + b))
+
+  @jax._src.config.jax_array(True)
+  def test_pjit_single_device_sharding_mul(self):
+    a = jnp.arange(16).reshape((8, 2))
+
+    @pjit
+    def mul(x):
+      return x @ x.T
+
+    out = mul(a)
+    self.assertIsInstance(out, array.Array)
+    self.assertArraysEqual(out, a @ a.T)
+

 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
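
A minimal usage sketch (not part of the patch) of what this change enables, mirroring the new test_pjit_single_device_sharding_* tests above: with the `jax_array` flag on, a bare `pjit` call no longer needs an enclosing `with mesh:` block, and on a single device the compiled computation falls back to `SingleDeviceSharding`. The snippet below only uses names that appear in the patch (`pjit`, `jax._src.config.jax_array`); the example values are illustrative.

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit

with jax._src.config.jax_array(True):
  a = jnp.array([1., 2., 3.])
  b = jnp.array([4., 5., 6.])
  # No mesh context manager is required here; pjit only respects a mesh if
  # one happens to be active at the call site.
  out = pjit(lambda x, y: x + y)(a, b)
  print(out)  # [5. 7. 9.]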