mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Create a proper NamedSharding without None
as the pspec. This happens when users pass None
as the out_shardings/in_shardings and pjit should convert it to a proper PartitionSpec.
PiperOrigin-RevId: 523125287
This commit is contained in:
parent
efc8300d02
commit
a1797170af
@ -3124,8 +3124,10 @@ def _compile_replicated_mesh_executable_from_trivial_jaxpr(
|
||||
|
||||
@lru_cache()
|
||||
def create_mesh_pspec_sharding(
|
||||
mesh: Mesh, pspec: PartitionSpec, parsed_pspec=None
|
||||
mesh: Mesh, pspec: Optional[PartitionSpec], parsed_pspec=None
|
||||
) -> sharding_impls.NamedSharding:
|
||||
if pspec is None:
|
||||
pspec = PartitionSpec()
|
||||
return sharding_impls.NamedSharding(mesh, pspec, parsed_pspec)
|
||||
|
||||
|
||||
|
@ -29,7 +29,6 @@ from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
|
||||
_get_op_sharding_from_executable,
|
||||
_get_pspec_from_executable, _pjit_lower_cached,
|
||||
_pjit_lower, _pjit_jaxpr,
|
||||
_create_mesh_pspec_sharding_from_parsed_pspec,
|
||||
_process_in_axis_resources)
|
||||
|
||||
|
||||
|
@ -3130,6 +3130,21 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(out4.sharding, SingleDeviceSharding)
|
||||
self.assertEqual(out4.device(), jax.devices()[1])
|
||||
|
||||
def test_none_out_sharding(self):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
x = jnp.arange(8)
|
||||
with mesh:
|
||||
out = pjit(lambda x: x * 2, out_shardings=None)(x)
|
||||
self.assertEqual(out.sharding.mesh, mesh)
|
||||
self.assertIsInstance(out.sharding, NamedSharding)
|
||||
self.assertEqual(out.sharding.spec, P())
|
||||
|
||||
x2 = jax.device_put(x, NamedSharding(mesh, P()))
|
||||
out2 = pjit(lambda x: x * 2)(x2)
|
||||
self.assertIsInstance(out2.sharding, NamedSharding)
|
||||
self.assertEqual(out2.sharding.mesh, mesh)
|
||||
self.assertEqual(out2.sharding.spec, P())
|
||||
|
||||
def test_sharding_preserved_apply_primitive(self):
|
||||
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
|
||||
ns = NamedSharding(mesh, P('x'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user