mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19017 from gnecula:export_shard_map
PiperOrigin-RevId: 591614088
This commit is contained in:
commit
e0cc9879d5
@ -28,6 +28,7 @@ from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax.experimental.export import export
|
||||
from jax.experimental.export import serialization
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from jax.experimental import pjit
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import Mesh
|
||||
@ -804,6 +805,63 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
in_shardings=(jax.sharding.NamedSharding(mesh1, P("x", None)),)
|
||||
)(a)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
kwargs=[
|
||||
dict(testcase_name=f"_poly={poly}", poly=poly)
|
||||
for poly in (None, "2*b1,_", "_,b2", "2*b1,b2")
|
||||
])
|
||||
def test_shard_map_collective_permute(self, poly=None):
|
||||
if len(jax.devices()) < 2:
|
||||
self.skipTest("Test requires at least 2 local devices")
|
||||
devices = np.array(jax.devices()[:2]) # use 2 devices
|
||||
mesh = Mesh(devices, axis_names=("x",))
|
||||
a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))
|
||||
|
||||
@functools.partial(
|
||||
pjit.pjit,
|
||||
in_shardings=NamedSharding(mesh, P("x", None),),
|
||||
out_shardings=NamedSharding(mesh, P("x", None)))
|
||||
@functools.partial(
|
||||
shard_map, mesh=mesh,
|
||||
in_specs=(P("x", None),), out_specs=P("x", None))
|
||||
def f_jax(b): # b: f32[2, 4]
|
||||
axis_size = lax.psum(1, "x")
|
||||
perm = [(j, (j + 1) % axis_size) for j in range(axis_size)]
|
||||
return lax.ppermute(b, "x", perm=perm)
|
||||
|
||||
args_specs = export.args_specs((a,), polymorphic_shapes=poly)
|
||||
exp = get_exported(f_jax)(*args_specs)
|
||||
|
||||
# Test JAX native execution
|
||||
res_jax = f_jax(a)
|
||||
b0, b1 = np.split(a, 2, axis=0) # The shard_map splits on axis 0
|
||||
b0, b1 = b1, b0
|
||||
expected = np.concatenate([b0, b1], axis=0) # out_specs concatenates on axis 0
|
||||
self.assertAllClose(res_jax, expected)
|
||||
self.assertLen(res_jax.addressable_shards, len(devices))
|
||||
|
||||
# Test reloaded execution.
|
||||
f_r = export.call_exported(exp)
|
||||
with self.assertRaisesRegex(
|
||||
Exception,
|
||||
"Exported module .* was lowered for 2 devices and is "
|
||||
"called in a context with 1 devices"):
|
||||
_ = f_r(a) # A is all on the default device
|
||||
|
||||
# Replicate the input so that the execution knows
|
||||
# that we are using multiple devices
|
||||
a_replicated = jax.device_put(a, NamedSharding(mesh, None))
|
||||
res_r = f_r(a_replicated)
|
||||
self.assertAllClose(res_r, expected)
|
||||
self.assertLen(res_r.addressable_shards, len(devices))
|
||||
for i in range(len(devices)):
|
||||
self.assertEqual(res_jax.addressable_shards[i].device,
|
||||
res_r.addressable_shards[i].device)
|
||||
self.assertEqual(res_jax.addressable_shards[i].index,
|
||||
res_r.addressable_shards[i].index)
|
||||
self.assertAllClose(res_jax.addressable_shards[i].data,
|
||||
res_r.addressable_shards[i].data)
|
||||
|
||||
@jtu.parameterized_filterable(
|
||||
one_containing="in_shardings_None_out_shardings_P_with_mesh_False",
|
||||
kwargs=[
|
||||
|
Loading…
x
Reference in New Issue
Block a user