From a1797170af81772b338c43192e789aab155c6763 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 10 Apr 2023 08:42:18 -0700 Subject: [PATCH] Create a proper NamedSharding without `None` as the pspec. This happens when users pass `None` as the out_shardings/in_shardings and pjit should convert it to a proper PartitionSpec. PiperOrigin-RevId: 523125287 --- jax/_src/interpreters/pxla.py | 4 +++- jax/experimental/pjit.py | 1 - tests/pjit_test.py | 15 +++++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 1aba390b8..7edb69d3e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3124,8 +3124,10 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr( @lru_cache() def create_mesh_pspec_sharding( - mesh: Mesh, pspec: PartitionSpec, parsed_pspec=None + mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None ) -> sharding_impls.NamedSharding: + if pspec is None: + pspec = PartitionSpec() return sharding_impls.NamedSharding(mesh, pspec, parsed_pspec) diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 6d59c3472..f379da1d5 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -29,7 +29,6 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources, _get_op_sharding_from_executable, _get_pspec_from_executable, _pjit_lower_cached, _pjit_lower, _pjit_jaxpr, - _create_mesh_pspec_sharding_from_parsed_pspec, _process_in_axis_resources) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 3e1fcc04e..6a7aa43fb 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3130,6 +3130,21 @@ class ArrayPjitTest(jtu.JaxTestCase): self.assertIsInstance(out4.sharding, SingleDeviceSharding) self.assertEqual(out4.device(), jax.devices()[1]) + def test_none_out_sharding(self): + mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + x = jnp.arange(8) + with mesh: + out = pjit(lambda x: x * 2, out_shardings=None)(x) + self.assertEqual(out.sharding.mesh, mesh) + self.assertIsInstance(out.sharding, NamedSharding) + self.assertEqual(out.sharding.spec, P()) + + x2 = jax.device_put(x, NamedSharding(mesh, P())) + out2 = pjit(lambda x: x * 2)(x2) + self.assertIsInstance(out2.sharding, NamedSharding) + self.assertEqual(out2.sharding.mesh, mesh) + self.assertEqual(out2.sharding.spec, P()) + def test_sharding_preserved_apply_primitive(self): mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x'))