Fix pdot translation rule.

This concerns the direct pdot translation rule, which is not used
during spmd lowering.
This commit is contained in:
Anselm Levskaya 2021-01-13 12:52:28 -08:00
parent 2f3f37a770
commit 2ca247f43e
2 changed files with 27 additions and 5 deletions

View File

@ -828,13 +828,15 @@ batching.primitive_batchers[pdot_p] = _pdot_vmap_batching_rule
def _pdot_translation_rule(c, x, y, *, axis_name, pos_contract, pos_batch,
axis_env, platform):
assert axis_name
local_out = lax._dot_general_translation_rule(
c, x, y, dimension_numbers=[pos_contract, pos_batch], precision=None)
out_tup = xla.parallel_translations[psum_p](
c, local_out, axis_name=axis_name, axis_index_groups=None,
axis_env=axis_env, platform=platform)
out, = xla.xla_destructure(c, out_tup)
if axis_name:
out_tup = xla.parallel_translations[psum_p](
c, local_out, axis_name=axis_name, axis_index_groups=None,
axis_env=axis_env, platform=platform)
out, = xla.xla_destructure(c, out_tup)
else:
out = local_out
return out
xla.parallel_translations[pdot_p] = _pdot_translation_rule

View File

@ -283,6 +283,26 @@ class XMapTest(jtu.JaxTestCase):
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
@ignore_xmap_warning()
@with_mesh([('r1', 2)])
def testPdotBatchingShardUncontractedDim(self):
def f(x, y):
return lax.pdot(x, y, 'i')
rng = np.random.RandomState(0)
x = rng.randn(2, 3, 8)
y = rng.randn(2, 8, 5)
f_mapped = xmap(f,
in_axes=[{0: 'j', 2: 'i'}, {0: 'j', 1: 'i'}],
out_axes=['j', ...],
axis_resources={'j': 'r1'})
z = f_mapped(x, y)
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
class XMapErrorTest(jtu.JaxTestCase):
@ignore_xmap_warning()