mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Fix some test failures under H100.
It seems that under H100 matmul precisions are a little lower by default than they historically were on A100. Opt out of tensorcore matmuls for tests that fail due to precision issues if they are enabled. Happily, this also allows us to remove a number of TPU special cases for the same reason. PiperOrigin-RevId: 539101155
This commit is contained in:
parent
00f2a8c28c
commit
803c729b57
@ -49,6 +49,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
expected = 3 * np.ones(4)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testNestedBatchingMatMat(self):
|
||||
matvec = vmap(jnp.vdot, in_axes=(0, None))
|
||||
matmat = vmap(matvec, in_axes=(None, 1), out_axes=1)
|
||||
@ -59,9 +60,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
ans = matmat(A, B)
|
||||
expected = np.dot(A, B)
|
||||
self.assertAllClose(
|
||||
ans, expected, check_dtypes=False,
|
||||
rtol={np.float32:5e-2} if jtu.device_under_test() == "tpu" else None)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
jaxpr = make_jaxpr(matmat)(A, B)
|
||||
self.assertLen(jaxpr.jaxpr.eqns, 1)
|
||||
@ -98,6 +97,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
self.assertEqual(dW.shape, (batch_size,) + W.shape)
|
||||
self.assertEqual(db.shape, (batch_size,) + b.shape)
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testJacobians(self):
|
||||
def jacbwd(f, x):
|
||||
y, pullback = vjp(f, x)
|
||||
@ -118,8 +118,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
f = lambda x: jnp.tanh(jnp.dot(A, x) + b)
|
||||
|
||||
x = R(3)
|
||||
self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False,
|
||||
rtol={np.float32:1e-2} if jtu.device_under_test() == "tpu" else None)
|
||||
self.assertAllClose(jacfwd(f, x), jacbwd(f, x), check_dtypes=False)
|
||||
|
||||
def testBatchOfCompile(self):
|
||||
side = []
|
||||
@ -201,6 +200,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
expected_ans = x > 1.0
|
||||
self.assertAllClose(ans, expected_ans)
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testNpMaximumPerExampleGrad(self):
|
||||
R = self.rng().randn
|
||||
x = R(10, 5)
|
||||
@ -218,9 +218,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0), x_ex)
|
||||
expected_ans = jnp.transpose(expected_ans)
|
||||
|
||||
self.assertAllClose(
|
||||
ans[i], expected_ans, check_dtypes=False,
|
||||
rtol={np.float32:5e-2} if jtu.device_under_test() == "tpu" else None)
|
||||
self.assertAllClose(ans[i], expected_ans, check_dtypes=False)
|
||||
|
||||
def testDotGeneral(self):
|
||||
R = self.rng().randn
|
||||
|
@ -291,6 +291,7 @@ class EinsumTest(jtu.JaxTestCase):
|
||||
C = self.rng().rand(10, 10)
|
||||
np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C, optimize='greedy')
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_einsum_kpmurphy_example(self):
|
||||
# code from an email with @murphyk
|
||||
N, C, D, K, T = 2, 3, 4, 5, 6
|
||||
@ -309,9 +310,8 @@ class EinsumTest(jtu.JaxTestCase):
|
||||
L[n,c] = s
|
||||
|
||||
path = jnp.einsum_path('ntk,kd,dc->nc', S, W, V, optimize='optimal')[0]
|
||||
rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(L, jnp.einsum('ntk,kd,dc->nc', S, W, V, optimize=path),
|
||||
check_dtypes=False, rtol=rtol)
|
||||
check_dtypes=False)
|
||||
|
||||
def test_contraction_broadcasting(self):
|
||||
r = self.rng()
|
||||
|
@ -420,7 +420,7 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self.assertAlmostEqual(expected["a"], actual["a"], places=5)
|
||||
self.assertAlmostEqual(expected["b"], actual["b"], places=5)
|
||||
|
||||
@jtu.skip_on_devices('tpu')
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_gmres_matmul(self):
|
||||
A = CustomOperator(2 * jnp.eye(3))
|
||||
b = jnp.arange(9.0).reshape(3, 3)
|
||||
|
@ -1724,6 +1724,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(f([np.array([i] * ndevices) for i in range(500)]),
|
||||
jnp.array([sum(vals)] * ndevices))
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testPostProcessMap2(self):
|
||||
# code from https://github.com/google/jax/issues/2787
|
||||
def vv(x, y):
|
||||
@ -1743,8 +1744,8 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
y = random.normal(key, (10, 50, 1))
|
||||
result = batched_mvm(y)
|
||||
expected = jnp.einsum('ij,njk->nik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else 1e-3
|
||||
self.assertAllClose(result, expected, check_dtypes=False, atol=tol, rtol=tol)
|
||||
self.assertAllClose(result, expected, check_dtypes=False, atol=1e-3,
|
||||
rtol=1e-3)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
|
@ -1348,6 +1348,7 @@ class PDotTests(XMapTestCase):
|
||||
self.assertAllClose(z, jnp.einsum('nij,njk->nik', x, y))
|
||||
|
||||
@jtu.with_mesh([('r1', 2)])
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testPdotBatchingShardUncontractedDim(self):
|
||||
def f(x, y):
|
||||
return lax.pdot(x, y, 'i')
|
||||
@ -1375,6 +1376,7 @@ class PDotTests(XMapTestCase):
|
||||
for axis_resources, mesh_data in s(schedules_from_pdot_spec(
|
||||
pdot_spec, lhs_shape, rhs_shape))
|
||||
)))
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testPdotSystematic(self, lhs_shape, rhs_shape, pdot_spec, axis_resources,
|
||||
mesh_data):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -1398,9 +1400,7 @@ class PDotTests(XMapTestCase):
|
||||
result = fun(lhs, rhs)
|
||||
|
||||
expected = lax.dot_general(lhs, rhs, pdot_spec.dot_general_dim_nums)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(result, expected, check_dtypes=False,
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(result, expected, check_dtypes=False)
|
||||
|
||||
@parameterized.named_parameters(jtu.named_cases_from_sampler(lambda s: ({
|
||||
"testcase_name": f"_{next(test_counter)}",
|
||||
@ -1412,6 +1412,7 @@ class PDotTests(XMapTestCase):
|
||||
for axis_resources, mesh_data in s(schedules_from_pdot_spec(
|
||||
pdot_spec, lhs_shape, rhs_shape))
|
||||
)))
|
||||
@jax.default_matmul_precision("float32")
|
||||
def testPdotVJPSystematic(self, lhs_shape, rhs_shape, pdot_spec,
|
||||
axis_resources, mesh_data):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
@ -1441,11 +1442,8 @@ class PDotTests(XMapTestCase):
|
||||
with jtu.with_mesh(mesh_data):
|
||||
lhs_bar, rhs_bar = fun(lhs, rhs, out_bar)
|
||||
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False,
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False,
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(lhs_bar, expected_lhs, check_dtypes=False)
|
||||
self.assertAllClose(rhs_bar, expected_rhs, check_dtypes=False)
|
||||
|
||||
def test_xeinsum_vector_dot(self):
|
||||
rng = self.rng()
|
||||
@ -1465,6 +1463,7 @@ class PDotTests(XMapTestCase):
|
||||
expected = np.einsum('i,j->ij', x, y)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_xeinsum_matmul(self):
|
||||
rng = self.rng()
|
||||
x = rng.randn(3, 4)
|
||||
@ -1475,9 +1474,7 @@ class PDotTests(XMapTestCase):
|
||||
in_axes=(['i', 'j'], ['j', 'k']),
|
||||
out_axes=['i', 'k'])(x, y)
|
||||
expected = np.einsum('ij,jk->ik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
check('{i,j},{j,k}->{i,k}')
|
||||
check('{i,j},{k,j}->{k,i}') # order of named axes in the spec doesn't matter!
|
||||
check('{j},{k,j}->{k}')
|
||||
@ -1500,14 +1497,14 @@ class PDotTests(XMapTestCase):
|
||||
expected = np.einsum('ij,ij->i', x, y)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_xeinsum_no_named_axes_batch_matmul(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(3, 5, 4)
|
||||
y = rng.randn(3, 4, 2)
|
||||
out = jnp.einsum('bij,bjk->bik', x, y, _use_xeinsum=True)
|
||||
expected = np.einsum('bij,bjk->bik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
def test_xeinsum_no_named_axes_reduce_sum(self):
|
||||
rng = self.rng()
|
||||
@ -1518,15 +1515,16 @@ class PDotTests(XMapTestCase):
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_xeinsum_no_named_axes_reduce_and_contract(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(3, 5, 4)
|
||||
y = rng.randn(2, 4, 2)
|
||||
out = jnp.einsum('bij,cjk->ik', x, y, _use_xeinsum=True)
|
||||
expected = np.einsum('bij,cjk->ik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True, atol=tol, rtol=tol)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_xeinsum_named_axes_reduce(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(3, 4)
|
||||
@ -1537,12 +1535,11 @@ class PDotTests(XMapTestCase):
|
||||
in_axes=(['i', 'j'], ['k']),
|
||||
out_axes=['i', 'k'])(x, y)
|
||||
expected = np.einsum('ij,k->ik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
check('{i,j},{k}->{i,k}')
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_xeinsum_named_axes_reduce_with_mesh(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(6, 4)
|
||||
@ -1554,9 +1551,7 @@ class PDotTests(XMapTestCase):
|
||||
out_axes=['i', 'k'],
|
||||
axis_resources={'i': 'x', 'k': 'y'})(x, y)
|
||||
expected = np.einsum('ij,k->ik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
check('{i,j},{k}->{i,k}')
|
||||
check('{i,j},{k}->{k,i}') # order of named axes in the spec doesn't matter!
|
||||
@ -1564,6 +1559,7 @@ class PDotTests(XMapTestCase):
|
||||
check('{j,i},{k}->{k,i}')
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_xeinsum_named_axes_batch_matmul_with_mesh(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(8, 3, 4)
|
||||
@ -1575,14 +1571,13 @@ class PDotTests(XMapTestCase):
|
||||
out_axes=['b', 'i', 'k'],
|
||||
axis_resources={'b': 'x', 'j': 'y'})(x, y)
|
||||
expected = np.einsum('bij,bjk->bik', x, y)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
check('{b,i,j},{b,j,k}->{b,i,k}')
|
||||
check('{j,i,b},{j,b,k}->{i,b,k}') # order of named axes in the spec doesn't matter!
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_xeinsum_named_axes_unary_reduce_with_mesh(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(8, 6, 4)
|
||||
@ -1593,15 +1588,14 @@ class PDotTests(XMapTestCase):
|
||||
out_axes=['b'],
|
||||
axis_resources={'b': 'x', 'i': 'y'})(x)
|
||||
expected = np.einsum('bij->b', x)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
check('{b,i,j}->{b}')
|
||||
check('{b,j,i}->{b}') # order of named axes in the spec doesn't matter!
|
||||
check('{i,j,b}->{b}')
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
@jax.default_matmul_precision("float32")
|
||||
def test_xeinsum_mixed_axes_unary_reduce_with_mesh(self):
|
||||
rng = np.random.RandomState(0)
|
||||
x = rng.randn(8, 6, 4, 5)
|
||||
@ -1612,9 +1606,7 @@ class PDotTests(XMapTestCase):
|
||||
out_axes=['b', ...],
|
||||
axis_resources={'b': 'x', 'i': 'y'})(x)
|
||||
expected = np.einsum('bijk->bk', x)
|
||||
tol = 1e-1 if jtu.device_under_test() == "tpu" else None
|
||||
self.assertAllClose(out, expected, check_dtypes=True,
|
||||
atol=tol, rtol=tol)
|
||||
self.assertAllClose(out, expected, check_dtypes=True)
|
||||
|
||||
check('jk{i,b}->k{b}')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user