rocm_jax/tests/debug_nans_test.py
Sergei Lebedev 0ff234049b Removed trivial docstrings from JAX tests
These docstrings do not make the tests any more clear and typically just duplicate the test module name.

PiperOrigin-RevId: 737611977
2025-03-17 07:49:37 -07:00

296 lines
8.6 KiB
Python

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
import numpy as np
from unittest import SkipTest
from jax._src import api
from jax._src import test_util as jtu
from jax import numpy as jnp
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P
jax.config.parse_flags_with_absl()
@jtu.with_config(jax_debug_nans=True)
class DebugNaNsTest(jtu.JaxTestCase):
def testSinc(self):
# Regression test for #6936
self.assertEqual(jnp.sinc(0.0), 1.0)
def testSingleResultPrimitiveNoNaN(self):
A = jnp.array([[1., 2.], [2., 3.]])
ans = jnp.tanh(A)
ans.block_until_ready()
def testMultipleResultPrimitiveNoNaN(self):
A = jnp.array([[1., 2.], [2., 3.]])
ans, _ = jnp.linalg.eigh(A)
ans.block_until_ready()
def testJitComputationNoNaN(self):
A = jnp.array([[1., 2.], [2., 3.]])
ans = jax.jit(jnp.tanh)(A)
ans.block_until_ready()
def testJitComputationNaN(self):
A = jnp.array(0.)
with self.assertRaises(FloatingPointError):
ans = jax.jit(lambda x: 0. / x)(A)
ans.block_until_ready()
@jax.debug_nans(False)
def testJitComputationNaNContextManager(self):
A = jnp.array(0.)
f = jax.jit(lambda x: 0. / x)
ans = f(A)
ans = f(A)
with self.assertRaises(FloatingPointError):
with jax.debug_nans(True):
ans = f(A)
ans.block_until_ready()
def testSingleResultPrimitiveNaN(self):
A = jnp.array(0.)
with self.assertRaises(FloatingPointError):
ans = 0. / A
ans.block_until_ready()
@jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
def testCallDeoptimized(self, jit):
@jit
def f(x):
return jax.lax.cond(
x == 1, lambda _: np.nan, lambda _: 2., operand=None)
# This makes sure, when using the C++ jit, that the Python code has been
# run to compile, and the next call won't go through `cache_miss`.
f(2)
# 'cond' not 'xla_call'
msg = r"invalid value \(nan\) encountered in .*cond.*"
with self.assertRaisesRegex(FloatingPointError, msg):
f(1)
def testShardMap(self):
mesh = jax.make_mesh((1,), ('x',))
f = shard_map(lambda x: 0. / x, mesh=mesh, in_specs=(P('x')), out_specs=P('x'))
# For the Cpp pmap, the first execution always goes through Python.
f(jnp.array([1.]))
with self.assertRaisesRegex(
FloatingPointError,
r"Invalid value \(nan\) encountered in sharded computation"):
ans = f(jnp.array([0.]))
ans.block_until_ready()
if jax.device_count() >= 2:
with self.assertRaisesRegex(
FloatingPointError,
r"Invalid value \(nan\) encountered in sharded computation"):
ans = f(jnp.array([1., 0.]))
ans.block_until_ready()
def testPmap(self):
pmap_funcs = [api._cpp_pmap]
for pmap in pmap_funcs:
f = pmap(lambda x: 0. / x)
# For the Cpp pmap, the first execution always goes through Python.
f(jnp.array([1.]))
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in div"):
ans = f(jnp.array([0.]))
ans.block_until_ready()
if jax.device_count() >= 2:
with self.assertRaisesRegex(
FloatingPointError,
r"Invalid value \(nan\) encountered in parallel computation"):
ans = f(jnp.array([1., 0.]))
ans.block_until_ready()
def testGradPmap(self):
@jax.jit
def f(x):
y = x**2
return jnp.log(y)
_, f_vjp = jax.vjp(jax.pmap(f), jnp.zeros([1]))
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in mul\nWhen differentiating"):
ans, = f_vjp(jnp.ones([1]))
ans.block_until_ready()
def testGradShardMap(self):
@jax.jit
def f(x):
y = x**2
return jnp.log(y)
mesh = jax.make_mesh((1,), ('x',))
shmap_f = shard_map(f, mesh=mesh, in_specs=(P('x')), out_specs=P('x'))
_, f_vjp = jax.vjp(shmap_f, jnp.zeros([1]))
with self.assertRaisesRegex(
FloatingPointError, r"Invalid value \(nan\) encountered"):
ans, = f_vjp(jnp.ones([1]))
ans.block_until_ready()
def testPmapNoNaN(self):
ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
ans.block_until_ready()
@jtu.ignore_warning(message=".*is an experimental.*")
def testPjit(self):
if jax.device_count() < 2:
raise SkipTest("test requires >=2 devices")
p = jax.sharding.PartitionSpec('x')
f = pjit.pjit(lambda x: 0. / x, in_shardings=p, out_shardings=p)
with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
with self.assertRaises(FloatingPointError):
ans = f(jnp.array([0., 1.]))
ans.block_until_ready()
def testDebugNansJitWithDonation(self):
# https://github.com/jax-ml/jax/issues/12514
a = jnp.array(0.)
with self.assertRaises(FloatingPointError):
ans = jax.jit(lambda x: 0. / x, donate_argnums=(0,))(a)
ans.block_until_ready()
def testDebugNansPmapWithDonation(self):
a = jnp.zeros((1,))
with self.assertRaises(FloatingPointError):
ans = jax.pmap(lambda x: 0. / x, donate_argnums=(0,))(a)
ans.block_until_ready()
@jtu.ignore_warning(message=".*is an experimental.*")
def testDebugNansPjitWithDonation(self):
if jax.device_count() < 2:
raise SkipTest("test requires >=2 devices")
p = jax.sharding.PartitionSpec('x')
f = pjit.pjit(lambda x: 0. / x,
in_shardings=p,
out_shardings=p,
donate_argnums=(0,))
with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
with self.assertRaises(FloatingPointError):
ans = f(jnp.array([0., 1.]))
ans.block_until_ready()
def testDebugNansZeroDiv(self):
inp = jnp.zeros(())
def f(x, y):
return x / y
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in div"):
f(inp, inp)
with self.assertRaisesRegex(
FloatingPointError,
r"invalid value \(nan\) encountered in div"):
jax.jit(f)(inp, inp)
def testDebugNansInput(self):
@jax.jit
def f(x):
return x * 3.
with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*input"):
f(np.nan)
@jtu.with_config(jax_debug_infs=True)
class DebugInfsTest(jtu.JaxTestCase):
def testSingleResultPrimitiveNoInf(self):
A = jnp.array([[1., 2.], [2., 3.]])
ans = jnp.tanh(A)
ans.block_until_ready()
def testMultipleResultPrimitiveNoInf(self):
A = jnp.array([[1., 2.], [2., 3.]])
ans, _ = jnp.linalg.eigh(A)
ans.block_until_ready()
def testJitComputationNoInf(self):
A = jnp.array([[1., 2.], [2., 3.]])
ans = jax.jit(jnp.tanh)(A)
ans.block_until_ready()
def testSingleResultPrimitiveInf(self):
A = jnp.array(0.)
with self.assertRaises(FloatingPointError):
ans = 1. / A
ans.block_until_ready()
@jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
def testCallDeoptimized(self, jit):
@jit
def f(x):
return jax.lax.cond(
x == 1, lambda _: np.inf, lambda _: 2., operand=None)
# This makes sure, when using the C++ jit, that the Python code has been
# run to compile, and the next call won't go through `cache_miss`.
f(2)
# 'cond' not 'xla_call'
msg = r"invalid value \(inf\) encountered in .*cond.*"
with self.assertRaisesRegex(FloatingPointError, msg):
f(1)
def testDebugNansDoesntCorruptCaches(self):
# https://github.com/jax-ml/jax/issues/6614
@jax.jit
def f(x):
return jnp.divide(x, x)
for _ in range(2):
try:
with jax.debug_nans(True):
jax.grad(f)(0.)
except FloatingPointError:
pass
def testDebugNansDoesntReturnDeoptimizedResult(self):
@jax.jit
def f(x):
y = x + 2 # avoid trivial dispatch path by adding some eqn
return jnp.nan, y
with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*literal"):
with jax.debug_nans(True):
f(3)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())