From 1acf9567aae0742ced26475e2fe4ec3b551a16bd Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Fri, 20 Sep 2024 11:24:36 -0700 Subject: [PATCH] Add get_replication to shard_map.py for verifying if an array is replicated. PiperOrigin-RevId: 676910872 --- jax/experimental/shard_map.py | 10 ++++++++++ tests/shard_map_test.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index f19401525..35d665943 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -2011,3 +2011,13 @@ def _match_replication(src, dst, x): if src - dst: x = pbroadcast(x, tuple(n for n in src if n not in dst)) return x + +# TODO(parkers,mattjj): change implementation when we have sharding-in-types. +def get_replication(x: jax.Array) -> set[AxisName]: + """For a jax.Array, return what axes it is known to be replicated along.""" + + if isinstance(x, RewriteTracer): + return x.rep + if isinstance(x, batching.BatchTracer): + return get_replication(x.val) + raise ValueError("get_replication not defined on %s" % repr(type(x))) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 20bc33475..fbe974651 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -2151,6 +2151,34 @@ class ShardMapTest(jtu.JaxTestCase): f(A()) # don't crash + def test_get_check_rep(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + + def f(x, reduce_along, use_jit): + out_spec = P(*(n for n in ('x', 'y') if n not in reduce_along)) + + @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=out_spec) + def g(x): + result = lax.psum(x, axis_name=reduce_along) + def check_rep(result): + self.assertEqual( + jax.experimental.shard_map.get_replication(result), + set(reduce_along)) + return result + result = check_rep(result) + result = jax.vmap(check_rep)(result) + return result + if use_jit: + return jax.jit(g)(x) + else: + return g(x) + + for use_jit in [True, False]: + x = np.zeros((8, 8), dtype=np.float32) + f(x, reduce_along=('y',), use_jit=use_jit) + f(x, reduce_along=('x',), use_jit=use_jit) + f(x, reduce_along=('x', 'y'), use_jit=use_jit) + class FunSpec(NamedTuple): name: str