mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[x64] make xmap_test compatible with strict dtype promotion
This commit is contained in:
parent
5236140b9f
commit
c467b15f06
@ -300,7 +300,7 @@ class XMapTest(XMapTestCase):
|
||||
self.assertAllClose(result, perm)
|
||||
|
||||
def testCollectiveAllGather(self):
|
||||
x = jnp.arange(4)
|
||||
x = jnp.arange(4, dtype='int32')
|
||||
result = xmap(lambda x: lax.all_gather(x, 'i') + lax.axis_index('i'),
|
||||
in_axes=['i', ...], out_axes=['i', ...])(x)
|
||||
self.assertAllClose(result, x[jnp.newaxis] + x[jnp.newaxis].T)
|
||||
@ -381,13 +381,13 @@ class XMapTest(XMapTestCase):
|
||||
axis_resources=dict([axis_resources[1]]))
|
||||
def h(y):
|
||||
# Multiply by a constant array to better exercise the partial_eval rule
|
||||
return jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b'))
|
||||
return jnp.sin(y) * np.arange(y.size, dtype=float), lax.psum(y, ('a', 'b'))
|
||||
return h(y)
|
||||
|
||||
xshape = (4, 2, 5)
|
||||
x = jnp.arange(np.prod(xshape)).reshape(xshape)
|
||||
x = jnp.arange(np.prod(xshape), dtype=float).reshape(xshape)
|
||||
y = f(x)
|
||||
self.assertAllClose(y, ((jnp.sin(x * 2) * np.arange(xshape[-1])[None, None]).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
|
||||
self.assertAllClose(y, ((jnp.sin(x * 2) * np.arange(xshape[-1], dtype=float)[None, None]).transpose((1, 2, 0)), (x * 2).sum((0, 1))))
|
||||
self.assertEqual(y[0].sharding_spec.sharding,
|
||||
(pxla.Chunked([2]), pxla.NoSharding(), pxla.NoSharding()))
|
||||
self.assertEqual(y[0].sharding_spec.mesh_mapping,
|
||||
@ -697,14 +697,15 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2), ('z', 2)])
|
||||
def testNestedMeshSPMD(self):
|
||||
h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size), lax.psum(y, ('a', 'b', 'c'))),
|
||||
h = xmap(lambda y: (jnp.sin(y) * np.arange(y.size, dtype=float),
|
||||
lax.psum(y, ('a', 'b', 'c'))),
|
||||
in_axes={0: 'c'}, out_axes=({1: 'c'}, {}),
|
||||
axis_resources={'c': 'z'})
|
||||
f = xmap(lambda x: h(x * 2),
|
||||
in_axes=[None, 'a', 'b', ...], out_axes=(['a', 'b', ...], {}),
|
||||
axis_resources={'a': 'x', 'b': 'y'})
|
||||
xshape = (8, 2, 4, 5)
|
||||
x = jnp.arange(np.prod(xshape)).reshape(xshape)
|
||||
x = jnp.arange(np.prod(xshape), dtype=float).reshape(xshape)
|
||||
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
|
||||
match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo)
|
||||
self.assertIsNot(match, None)
|
||||
@ -838,23 +839,23 @@ class NamedRandomTest(XMapTestCase):
|
||||
class NamedNNTest(XMapTestCase):
|
||||
|
||||
def testOneHot(self):
|
||||
f = xmap(lambda x: jax.nn.one_hot([1, 2, 0], 3, axis='i'),
|
||||
f = xmap(lambda x: jax.nn.one_hot(jnp.array([1, 2, 0], dtype='int32'), 3, axis='i'),
|
||||
in_axes=['i', ...], out_axes=['i', ...])
|
||||
expected = jnp.array([[0., 1., 0.],
|
||||
[0., 0., 1.],
|
||||
[1., 0., 0.]]).T
|
||||
self.assertAllClose(f(jnp.ones((3,))), expected)
|
||||
[0., 0., 1.],
|
||||
[1., 0., 0.]]).T
|
||||
self.assertAllClose(f(jnp.ones(3, dtype='int32')), expected)
|
||||
|
||||
def testOneHotOutOfBound(self):
|
||||
f = xmap(lambda x: jax.nn.one_hot([-1, 3], 3, axis='i'),
|
||||
f = xmap(lambda x: jax.nn.one_hot(jnp.array([-1, 3], dtype='int32'), 3, axis='i'),
|
||||
in_axes=['i', ...], out_axes=['i', ...])
|
||||
self.assertAllClose(f(jnp.ones((3,))), jnp.zeros((3, 2)))
|
||||
self.assertAllClose(f(jnp.ones(3, dtype='int32')), jnp.zeros((3, 2)))
|
||||
|
||||
def testOneHotAxisSizeMismatch(self):
|
||||
f = xmap(lambda x: jax.nn.one_hot([-1, 3], 3, axis='i'),
|
||||
f = xmap(lambda x: jax.nn.one_hot(jnp.array([-1, 3], dtype='int32'), 3, axis='i'),
|
||||
in_axes=['i', ...], out_axes=['i', ...])
|
||||
with self.assertRaisesRegex(ValueError, "to match the size of axis i, but 3 != 5"):
|
||||
f(jnp.ones((5,)))
|
||||
f(jnp.ones(5, dtype='int32'))
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"_map_in={map_in}_map_out={map_out}_fan={fan}_distr={distr}",
|
||||
@ -1712,7 +1713,7 @@ class XMapErrorTest(jtu.JaxTestCase):
|
||||
f = xmap(lambda x: lax.axis_index('i') + x,
|
||||
in_axes=[], out_axes=['i'], axis_sizes={'i': 4}, axis_resources={'i': 'x'})
|
||||
h = xmap(f, in_axes=['j', ...], out_axes=['j', ...], axis_resources={'j': 'x'})
|
||||
x = np.arange(4)
|
||||
x = np.arange(4, dtype='int32')
|
||||
error = (r"Axes `i` and `j` are both mapped to the resource `x`, but they "
|
||||
r"coincide in the named_shape of a value returned from a primitive "
|
||||
r"add created at .*")
|
||||
|
Loading…
x
Reference in New Issue
Block a user