Create a proper NamedSharding without None as the pspec. This happens when users pass None as the out_shardings/in_shardings and pjit should convert it to a proper PartitionSpec.

PiperOrigin-RevId: 523125287
This commit is contained in:
Yash Katariya 2023-04-10 08:42:18 -07:00 committed by jax authors
parent efc8300d02
commit a1797170af
3 changed files with 18 additions and 2 deletions

View File

@ -3124,8 +3124,10 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
@lru_cache()
def create_mesh_pspec_sharding(
mesh: Mesh, pspec: PartitionSpec, parsed_pspec=None
mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None
) -> sharding_impls.NamedSharding:
if pspec is None:
pspec = PartitionSpec()
return sharding_impls.NamedSharding(mesh, pspec, parsed_pspec)

View File

@ -29,7 +29,6 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
_get_op_sharding_from_executable,
_get_pspec_from_executable, _pjit_lower_cached,
_pjit_lower, _pjit_jaxpr,
_create_mesh_pspec_sharding_from_parsed_pspec,
_process_in_axis_resources)

View File

@ -3130,6 +3130,21 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out4.sharding, SingleDeviceSharding)
self.assertEqual(out4.device(), jax.devices()[1])
def test_none_out_sharding(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
x = jnp.arange(8)
with mesh:
out = pjit(lambda x: x * 2, out_shardings=None)(x)
self.assertEqual(out.sharding.mesh, mesh)
self.assertIsInstance(out.sharding, NamedSharding)
self.assertEqual(out.sharding.spec, P())
x2 = jax.device_put(x, NamedSharding(mesh, P()))
out2 = pjit(lambda x: x * 2)(x2)
self.assertIsInstance(out2.sharding, NamedSharding)
self.assertEqual(out2.sharding.mesh, mesh)
self.assertEqual(out2.sharding.spec, P())
def test_sharding_preserved_apply_primitive(self):
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
ns = NamedSharding(mesh, P('x'))