mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix pdot translation rule.
This concerns the direct pdot translation rule, which is not used during spmd lowering.
This commit is contained in:
parent
2f3f37a770
commit
2ca247f43e
@ -828,13 +828,15 @@ batching.primitive_batchers[pdot_p] = _pdot_vmap_batching_rule
|
||||
|
||||
def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
|
||||
axis_env, platform):
|
||||
assert axis_name
|
||||
local_out = lax._dot_general_translation_rule(
|
||||
c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=None)
|
||||
out_tup = xla.parallel_translations[psum_p](
|
||||
c, local_out, axis_name=axis_name, axis_index_groups=None,
|
||||
axis_env=axis_env, platform=platform)
|
||||
out, = xla.xla_destructure(c, out_tup)
|
||||
if axis_name:
|
||||
out_tup = xla.parallel_translations[psum_p](
|
||||
c, local_out, axis_name=axis_name, axis_index_groups=None,
|
||||
axis_env=axis_env, platform=platform)
|
||||
out, = xla.xla_destructure(c, out_tup)
|
||||
else:
|
||||
out = local_out
|
||||
return out
|
||||
xla.parallel_translations[pdot_p] = _pdot_translation_rule
|
||||
|
||||
|
@ -283,6 +283,26 @@ class XMapTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
|
||||
|
||||
@ignore_xmap_warning()
|
||||
@with_mesh([('r1', 2)])
|
||||
def testPdotBatchingShardUncontractedDim(self):
|
||||
def f(x, y):
|
||||
return lax.pdot(x, y, 'i')
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(2, 3, 8)
|
||||
y = rng.randn(2, 8, 5)
|
||||
|
||||
f_mapped = xmap(f,
|
||||
in_axes=[{0: 'j', 2: 'i'}, {0: 'j', 1: 'i'}],
|
||||
out_axes=['j', ...],
|
||||
axis_resources={'j': 'r1'})
|
||||
|
||||
z = f_mapped(x, y)
|
||||
|
||||
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
|
||||
|
||||
|
||||
class XMapErrorTest(jtu.JaxTestCase):
|
||||
|
||||
@ignore_xmap_warning()
|
||||
|
Loading…
x
Reference in New Issue
Block a user