[x64] make xmap_test compatible with strict dtype promotion

This commit is contained in:
Jake VanderPlas 2022-06-17 16:35:42 -07:00
parent 5236140b9f
commit c467b15f06

View File

@ -300,7 +300,7 @@ class XMapTest(XMapTestCase):
self.assertAllClose(result, perm)
def testCollectiveAllGather(self):
x = jnp.arange(4)
x = jnp.arange(4, dtype='int32')
result = xmap(lambda x: lax.all_gather(x, 'i') + lax.axis_index('i'),
in_axes=['i', ...], out_axes=['i', ...])(x)
self.assertAllClose(result, x[jnp.newaxis] + x[jnp.newaxis].T)
@ -381,13 +381,13 @@ class XMapTest(XMapTestCase):
axis_resources=dict([axis_resources[1]]))
def h(y):
# Multiply by a constant array to better exercise the partial_eval rule
return jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b'))
return jnp.sin(y) * np.arange(y.size, dtype=float), lax.psum(y, ('a', 'b'))
return h(y)
xshape = (4, 2, 5)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
x = jnp.arange(np.prod(xshape), dtype=float).reshape(xshape)
y = f(x)
self.assertAllClose(y, ((jnp.sin(x * 2) * np.arange(xshape[-1])[None, None]).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
self.assertAllClose(y, ((jnp.sin(x * 2) * np.arange(xshape[-1], dtype=float)[None, None]).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
self.assertEqual(y[0].sharding_spec.sharding,
(pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
self.assertEqual(y[0].sharding_spec.mesh_mapping,
@ -697,14 +697,15 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
@jtu.with_mesh([('x', 2), ('y', 2), ('z', 2)])
def testNestedMeshSPMD(self):
h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))),
h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size, dtype=float),
lax.psum(y, ('a', 'b', 'c'))),
in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
axis_resources={'c': 'z'})
f = xmap(lambda x: h(x * 2),
in_axes=[None, 'a', 'b', ...], out_axes=(['a', 'b', ...], {}),
axis_resources={'a': 'x', 'b': 'y'})
xshape = (8, 2, 4, 5)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
x = jnp.arange(np.prod(xshape), dtype=float).reshape(xshape)
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo)
self.assertIsNot(match, None)
@ -838,23 +839,23 @@ class NamedRandomTest(XMapTestCase):
class NamedNNTest(XMapTestCase):
def testOneHot(self):
f = xmap(lambda x: jax.nn.one_hot([1, 2, 0], 3, axis='i'),
f = xmap(lambda x: jax.nn.one_hot(jnp.array([1, 2, 0], dtype='int32'), 3, axis='i'),
in_axes=['i', ...], out_axes=['i', ...])
expected = jnp.array([[0., 1., 0.],
[0., 0., 1.],
[1., 0., 0.]]).T
self.assertAllClose(f(jnp.ones((3,))), expected)
[0., 0., 1.],
[1., 0., 0.]]).T
self.assertAllClose(f(jnp.ones(3, dtype='int32')), expected)
def testOneHotOutOfBound(self):
f = xmap(lambda x: jax.nn.one_hot([-1, 3], 3, axis='i'),
f = xmap(lambda x: jax.nn.one_hot(jnp.array([-1, 3], dtype='int32'), 3, axis='i'),
in_axes=['i', ...], out_axes=['i', ...])
self.assertAllClose(f(jnp.ones((3,))), jnp.zeros((3, 2)))
self.assertAllClose(f(jnp.ones(3, dtype='int32')), jnp.zeros((3, 2)))
def testOneHotAxisSizeMismatch(self):
f = xmap(lambda x: jax.nn.one_hot([-1, 3], 3, axis='i'),
f = xmap(lambda x: jax.nn.one_hot(jnp.array([-1, 3], dtype='int32'), 3, axis='i'),
in_axes=['i', ...], out_axes=['i', ...])
with self.assertRaisesRegex(ValueError, "to match the size of axis i, but 3 != 5"):
f(jnp.ones((5,)))
f(jnp.ones(5, dtype='int32'))
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_map_in={map_in}_map_out={map_out}_fan={fan}_distr={distr}",
@ -1712,7 +1713,7 @@ class XMapErrorTest(jtu.JaxTestCase):
f = xmap(lambda x: lax.axis_index('i') + x,
in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
x = np.arange(4)
x = np.arange(4, dtype='int32')
error = (r"Axes `i` and `j` are both mapped to the resource `x`, but they "
r"coincide in the named_shape of a value returned from a primitive "
r"add created at .*")