mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add get_replication to shard_map.py for verifying if an array is replicated.
PiperOrigin-RevId: 676910872
This commit is contained in:
parent
82b0e0e0fb
commit
1acf9567aa
@ -2011,3 +2011,13 @@ def _match_replication(src, dst, x):
|
||||
if src - dst:
|
||||
x = pbroadcast(x, tuple(n for n in src if n not in dst))
|
||||
return x
|
||||
|
||||
# TODO(parkers,mattjj): change implementation when we have sharding-in-types.
|
||||
def get_replication(x: jax.Array) -> set[AxisName]:
|
||||
"""For a jax.Array, return what axes it is known to be replicated along."""
|
||||
|
||||
if isinstance(x, RewriteTracer):
|
||||
return x.rep
|
||||
if isinstance(x, batching.BatchTracer):
|
||||
return get_replication(x.val)
|
||||
raise ValueError("get_replication not defined on %s" % repr(type(x)))
|
||||
|
@ -2151,6 +2151,34 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
f(A()) # don't crash
|
||||
|
||||
def test_get_check_rep(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
||||
def f(x, reduce_along, use_jit):
|
||||
out_spec = P(*(n for n in ('x', 'y') if n not in reduce_along))
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=out_spec)
|
||||
def g(x):
|
||||
result = lax.psum(x, axis_name=reduce_along)
|
||||
def check_rep(result):
|
||||
self.assertEqual(
|
||||
jax.experimental.shard_map.get_replication(result),
|
||||
set(reduce_along))
|
||||
return result
|
||||
result = check_rep(result)
|
||||
result = jax.vmap(check_rep)(result)
|
||||
return result
|
||||
if use_jit:
|
||||
return jax.jit(g)(x)
|
||||
else:
|
||||
return g(x)
|
||||
|
||||
for use_jit in [True, False]:
|
||||
x = np.zeros((8, 8), dtype=np.float32)
|
||||
f(x, reduce_along=('y',), use_jit=use_jit)
|
||||
f(x, reduce_along=('x',), use_jit=use_jit)
|
||||
f(x, reduce_along=('x', 'y'), use_jit=use_jit)
|
||||
|
||||
|
||||
class FunSpec(NamedTuple):
|
||||
name: str
|
||||
|
Loading…
x
Reference in New Issue
Block a user