Merge pull request #19017 from gnecula:export_shard_map

PiperOrigin-RevId: 591614088
This commit is contained in:
jax authors 2023-12-16 22:26:37 -08:00
commit e0cc9879d5

View File

@ -28,6 +28,7 @@ from jax import numpy as jnp
from jax import tree_util
from jax.experimental.export import export
from jax.experimental.export import serialization
from jax.experimental.shard_map import shard_map
from jax.experimental import pjit
from jax.sharding import NamedSharding
from jax.sharding import Mesh
@ -804,6 +805,63 @@ class JaxExportTest(jtu.JaxTestCase):
in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
)(a)
@jtu.parameterized_filterable(
  kwargs=[
    {"testcase_name": f"_poly={poly}", "poly": poly}
    for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
  ])
def test_shard_map_collective_permute(self, poly=None):
  # Exports a pjit-wrapped shard_map that rotates shards along mesh axis "x"
  # with ppermute, then checks that both the native function and the
  # re-imported exported function yield the same values and shard placement.
  if len(jax.devices()) < 2:
    self.skipTest("Test requires at least 2 local devices")
  mesh_devices = np.array(jax.devices()[:2])  # restrict the mesh to 2 devices
  mesh = Mesh(mesh_devices, axis_names=("x",))
  operand = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))

  # Axis 0 is partitioned over "x"; axis 1 is left unpartitioned.
  row_spec = P("x", None)
  row_sharding = NamedSharding(mesh, row_spec)

  def f_jax(b):  # b: f32[2, 4]
    # psum of the constant 1 over "x" gives the number of shards on the axis.
    n_shards = lax.psum(1, "x")
    # Every shard j forwards its block to shard (j + 1) mod n_shards.
    rotation = [(src, (src + 1) % n_shards) for src in range(n_shards)]
    return lax.ppermute(b, "x", perm=rotation)

  # Same wrapping order as stacked decorators: shard_map innermost, pjit on top.
  f_jax = shard_map(f_jax, mesh=mesh, in_specs=(row_spec,), out_specs=row_spec)
  f_jax = pjit.pjit(f_jax, in_shardings=row_sharding,
                    out_shardings=row_sharding)

  specs = export.args_specs((operand,), polymorphic_shapes=poly)
  exported = get_exported(f_jax)(*specs)

  # Native JAX execution.
  native_out = f_jax(operand)
  # shard_map splits on axis 0; with two shards the rotation swaps the halves,
  # and out_specs concatenates them back along axis 0.
  first_half, second_half = np.split(operand, 2, axis=0)
  expected = np.concatenate([second_half, first_half], axis=0)
  self.assertAllClose(native_out, expected)
  self.assertLen(native_out.addressable_shards, len(mesh_devices))

  # Execution through the re-imported exported function.
  reloaded_fn = export.call_exported(exported)
  with self.assertRaisesRegex(
      Exception,
      "Exported module .* was lowered for 2 devices and is "
      "called in a context with 1 devices"):
    _ = reloaded_fn(operand)  # `operand` lives entirely on the default device

  # Replicate the input over the mesh so that the call is made in a
  # multi-device context.
  replicated_operand = jax.device_put(operand, NamedSharding(mesh, None))
  reloaded_out = reloaded_fn(replicated_operand)
  self.assertAllClose(reloaded_out, expected)
  self.assertLen(reloaded_out.addressable_shards, len(mesh_devices))

  # Both results were asserted above to hold len(mesh_devices) shards, so
  # pairing them positionally visits every shard index.
  shard_pairs = zip(native_out.addressable_shards,
                    reloaded_out.addressable_shards)
  for native_shard, reloaded_shard in shard_pairs:
    self.assertEqual(native_shard.device, reloaded_shard.device)
    self.assertEqual(native_shard.index, reloaded_shard.index)
    self.assertAllClose(native_shard.data, reloaded_shard.data)
@jtu.parameterized_filterable(
one_containing="in_shardings_None_out_shardings_P_with_mesh_False",
kwargs=[