rocm_jax/tests/debug_nans_test.py
Yash Katariya e624610e72 Replace apply_primitive internals with jax.jit.
This allows deleting a lot of code and yields a ~40% speedup in eager-mode performance.

Benchmarks:

```
name                                                      old time/op          new time/op          delta
eager_unary_dispatch                                      31.3µs ± 1%          19.4µs ± 6%  -37.91%    (p=0.016 n=4+5)
eager_unary                                               32.1µs ± 0%          19.8µs ± 4%  -38.26%    (p=0.016 n=4+5)
eager_binary_dispatch                                     35.9µs ± 1%          20.5µs ± 4%  -42.93%    (p=0.016 n=4+5)
eager_binary                                              36.6µs ± 1%          21.1µs ± 4%  -42.29%    (p=0.016 n=4+5)
jit_trivial_dispatch                                      3.87µs ± 2%          4.12µs ±25%     ~       (p=1.000 n=5+5)
jit_trivial                                               4.75µs ± 2%          4.82µs ±11%     ~       (p=0.690 n=5+5)
jit_simple_dispatch                                       2.95µs ± 2%          2.97µs ± 7%     ~       (p=1.000 n=5+5)
jit_simple                                                3.52µs ± 6%          3.51µs ± 5%     ~       (p=0.841 n=5+5)
jit_simple_dispatch_array                                 2.95µs ± 2%          2.96µs ± 6%     ~       (p=1.000 n=5+5)
jit_simple_array                                          3.46µs ± 2%          3.51µs ± 5%     ~       (p=0.690 n=5+5)
jit_small_matmul                                          3.01µs ± 1%          3.00µs ± 4%     ~       (p=0.548 n=5+5)
jit_big_matmul                                            34.0µs ±18%          35.5µs ±17%     ~       (p=0.310 n=5+5)
jit_simple_many_args_dispatch/num_args:10                 6.93µs ± 6%          6.80µs ± 6%     ~     (p=0.481 n=10+10)
jit_simple_many_args_dispatch/num_args:100                47.7µs ± 7%          45.4µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args_dispatch/num_args:1000                545µs ± 8%           516µs ± 2%     ~      (p=0.101 n=10+8)
jit_simple_many_args_dispatch/num_args:2000               1.12ms ± 7%          1.07ms ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:10                          7.42µs ± 5%          7.23µs ± 2%     ~      (p=0.173 n=10+8)
jit_simple_many_args/num_args:100                         48.4µs ± 7%          45.6µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:1000                         542µs ± 6%           524µs ± 8%     ~     (p=0.089 n=10+10)
jit_simple_many_args/num_args:2000                        1.12ms ± 7%          1.08ms ± 1%     ~      (p=0.068 n=10+8)
jit_simple_pruned_args_dispatch_10                        4.79µs ± 8%          4.98µs ±10%     ~       (p=0.421 n=5+5)
jit_simple_pruned_args_10                                 5.32µs ± 6%          5.30µs ± 4%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_100                       24.7µs ± 6%          23.8µs ± 8%     ~       (p=0.548 n=5+5)
jit_simple_pruned_args_100                                25.2µs ± 6%          24.4µs ± 8%     ~       (p=0.690 n=5+5)
jit_simple_pruned_args_dispatch_1000                       238µs ± 7%           232µs ± 8%     ~       (p=0.841 n=5+5)
jit_simple_pruned_args_1000                                240µs ± 7%           234µs ± 8%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_2000                       516µs ± 6%           497µs ± 1%     ~       (p=0.413 n=5+4)
jit_simple_pruned_args_2000                                517µs ± 6%           505µs ± 7%     ~       (p=0.690 n=5+5)
jit_dispatch_without_transfer                              719µs ± 9%           751µs ± 8%     ~       (p=0.222 n=5+5)
jit_dispatch_with_transfer                                 799µs ±14%           793µs ± 9%     ~       (p=1.000 n=5+5)
pmap_trivial_2_devices                                    49.9µs ±40%          48.2µs ±42%     ~       (p=0.841 n=5+5)
pmap_trivial_dispatch_8_devices                           74.5µs ±24%          78.9µs ±29%     ~       (p=0.421 n=5+5)
pmap_trivial_8_devices                                    79.3µs ± 6%          82.7µs ±20%     ~       (p=0.841 n=5+5)
pmap_simple_2_devices                                     47.1µs ±17%          49.1µs ±20%     ~       (p=0.548 n=5+5)
pmap_simple_dispatch_8_devices                            73.4µs ±16%          76.8µs ±21%     ~       (p=0.690 n=5+5)
pmap_simple_8_devices                                     76.0µs ±10%          80.6µs ±29%     ~       (p=1.000 n=5+5)
pmap_simple_dispatch_8_devices_100_args                   1.12ms ±22%          1.08ms ±42%     ~       (p=0.841 n=5+5)
pmap_simple_8_devices_100_args                            12.5ms ± 8%          12.8ms ±10%     ~       (p=1.000 n=5+5)
sda_index_1                                                413µs ± 1%           686µs ± 4%  +66.08%    (p=0.008 n=5+5)
sda_index_2                                                850µs ± 1%          1378µs ± 4%  +62.02%    (p=0.008 n=5+5)
sda_index_8                                               3.60ms ± 1%          5.69ms ± 4%  +58.00%    (p=0.008 n=5+5)
bench_shaped_abstractify                                   300µs ± 1%           305µs ± 3%     ~       (p=0.056 n=5+5)
bench_xla_abstractify_scalar_int                          6.45µs ± 1%          6.50µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_scalar_float                        3.73µs ± 1%          3.73µs ± 3%     ~       (p=0.690 n=5+5)
bench_xla_abstractify_scalar_numpy_int32                  4.97µs ± 1%          4.83µs ± 3%     ~       (p=0.095 n=5+5)
bench_xla_abstractify_scalar_numpy_uint32                 4.91µs ± 1%          4.75µs ± 0%   -3.30%    (p=0.016 n=5+4)
bench_xla_abstractify_numpy_random                        4.34µs ± 2%          4.31µs ± 3%     ~       (p=0.310 n=5+5)
bench_xla_abstractify_numpy_arange_100_float32            3.94µs ± 1%          3.93µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_enum                                6.85µs ± 1%          7.06µs ± 7%   +3.07%    (p=0.032 n=5+5)
bench_are_op_shardings_equal                              26.9µs ± 2%          27.0µs ± 3%     ~       (p=0.841 n=5+5)
bench_pjit_check_aval_sharding                             691µs ± 2%           711µs ±13%     ~       (p=0.841 n=5+5)
bench_addressable_shards_index                             656ns ± 4%           688ns ± 9%     ~       (p=0.095 n=5+5)
bench_remat_eager_retracing_overheads                     12.7ms ± 4%          10.7ms ± 1%  -15.48%    (p=0.016 n=5+4)
bench_remat_eager_retracing_overheads_static_argnums      13.0ms ± 2%          11.3ms ± 6%  -13.71%    (p=0.008 n=5+5)
bench_slicing_compilation                                 12.1ms ± 1%          12.3ms ± 4%     ~       (p=0.690 n=5+5)
bench_slicing_compilation2                                11.3ms ± 0%          11.5ms ± 6%     ~       (p=0.690 n=5+5)
bench_repeated_static_indexing                            62.5ms ± 2%          40.8ms ± 8%  -34.77%    (p=0.008 n=5+5)
bench_repeated_static_slicing                             46.7ms ± 1%          31.4ms ± 2%  -32.76%    (p=0.008 n=5+5)
pjit_simple_1_device/num_args:1                           2.72µs ± 2%          2.68µs ± 5%     ~       (p=0.151 n=5+5)
pjit_simple_1_device/num_args:10                          12.6µs ± 7%          12.3µs ± 3%     ~       (p=0.310 n=5+5)
pjit_simple_1_device/num_args:100                          109µs ± 3%           108µs ± 4%     ~       (p=0.548 n=5+5)
pjit_simple_4_device/num_args:1                           38.0µs ±26%          36.8µs ±19%     ~       (p=0.690 n=5+5)
pjit_simple_4_device/num_args:10                          93.3µs ±19%          96.6µs ±23%     ~       (p=0.841 n=5+5)
pjit_simple_4_device/num_args:100                          730µs ±16%           698µs ±48%     ~       (p=0.841 n=5+5)
pjit_aot_1_device/num_args:1                              3.29µs ± 2%          3.12µs ± 4%   -5.24%    (p=0.016 n=4+5)
pjit_aot_1_device/num_args:10                             13.0µs ± 1%          12.7µs ± 2%     ~       (p=0.063 n=4+5)
pjit_aot_1_device/num_args:100                             111µs ± 5%           110µs ±11%     ~       (p=0.421 n=5+5)
pjit_aot_4_device/num_args:1                              38.4µs ±19%          38.9µs ±24%     ~       (p=1.000 n=5+5)
pjit_aot_4_device/num_args:10                             91.3µs ±15%          96.9µs ±29%     ~       (p=0.548 n=5+5)
pjit_aot_4_device/num_args:100                             676µs ±20%           689µs ±41%     ~       (p=0.841 n=5+5)
host_local_array_to_global_array                           196µs ± 6%           194µs ± 4%     ~       (p=0.548 n=5+5)
device_put                                                50.8µs ± 1%          50.7µs ± 4%     ~       (p=0.413 n=4+5)
device_put_sharded                                         176µs ± 0%           177µs ± 4%     ~       (p=0.190 n=4+5)
device_get_8_devices                                      3.96ms ± 4%          4.03ms ± 7%     ~       (p=0.413 n=4+5)
np_asarray_8_devices                                      3.34ms ±18%          3.30ms ±10%     ~       (p=0.548 n=5+5)
jax_array_arrays_8_devices                                5.01ms ±10%          5.09ms ±21%     ~       (p=0.421 n=5+5)
batch_inplace_while_scatter                                440µs ± 1%           439µs ± 1%     ~       (p=0.421 n=5+5)
batch_inplace_while_dynamic_update_slice                   454µs ± 0%           457µs ± 1%     ~       (p=0.905 n=4+5)
serial_dot_products                                       4.51µs ± 3%          4.41µs ± 2%     ~       (p=0.151 n=5+5)
bench_make_array_from_callback_fully_replicated_sharding  26.6µs ± 1%          27.0µs ± 2%     ~       (p=0.056 n=5+5)
```

PiperOrigin-RevId: 586505950
2023-11-29 18:07:13 -08:00

281 lines
8.2 KiB
Python

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for --debug_nans."""
from absl.testing import absltest
import jax
import numpy as np
from unittest import SkipTest
from jax._src import api
from jax._src import test_util as jtu
from jax import numpy as jnp
from jax.experimental import pjit, maps
from jax import config
# Route JAX's config options through absl flag parsing before the tests are
# defined (presumably so options such as --jax_debug_nans can be passed on the
# command line; confirm against jax.config docs).
config.parse_flags_with_absl()
class DebugNaNsTest(jtu.JaxTestCase):
  """Checks that enabling `jax_debug_nans` turns NaN results into errors.

  setUp switches the flag on for every test and remembers the previous value;
  tearDown puts that value back. A detected NaN is expected to surface as a
  FloatingPointError.
  """

  def setUp(self):
    super().setUp()
    self.cfg = config._read("jax_debug_nans")
    config.update("jax_debug_nans", True)

  def tearDown(self):
    config.update("jax_debug_nans", self.cfg)
    super().tearDown()

  def testSinc(self):
    # Regression test for #6936: sinc(0) is 1, and computing it with NaN
    # checking enabled must not raise.
    self.assertEqual(jnp.sinc(0.0), 1.0)

  def testSingleResultPrimitiveNoNaN(self):
    x = jnp.array([[1., 2.], [2., 3.]])
    jnp.tanh(x).block_until_ready()

  def testMultipleResultPrimitiveNoNaN(self):
    x = jnp.array([[1., 2.], [2., 3.]])
    # eigh has several outputs; only the first one is synchronized on.
    w, _ = jnp.linalg.eigh(x)
    w.block_until_ready()

  def testJitComputationNoNaN(self):
    x = jnp.array([[1., 2.], [2., 3.]])
    jitted_tanh = jax.jit(jnp.tanh)
    jitted_tanh(x).block_until_ready()

  def testJitComputationNaN(self):
    zero = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      jax.jit(lambda x: 0. / x)(zero).block_until_ready()

  def testJitComputationNaNContextManager(self):
    # Start with checking off so only the context manager can enable it.
    config.update("jax_debug_nans", False)
    zero = jnp.array(0.)
    f = jax.jit(lambda x: 0. / x)
    # Two unchecked calls first (presumably to get onto the cached dispatch
    # path before checking is switched back on).
    for _ in range(2):
      f(zero)
    with self.assertRaises(FloatingPointError):
      with jax.debug_nans(True):
        f(zero).block_until_ready()

  def testSingleResultPrimitiveNaN(self):
    zero = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      (0. / zero).block_until_ready()

  @jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
  def testCallDeoptimized(self, jit):
    @jit
    def f(x):
      return jax.lax.cond(
          x == 1, lambda _: np.nan, lambda _: 2., operand=None)

    # With the C++ jit, this NaN-free call makes the Python code compile the
    # function, so the checked call below does not go through `cache_miss`.
    f(2)
    # The error should point at the `cond`, not at `xla_call`.
    with self.assertRaisesRegex(
        FloatingPointError, r"invalid value \(nan\) encountered in .*cond.*"):
      f(1)

  def testPmap(self):
    nan_msg = r"invalid value \(nan\) encountered in parallel computation"
    f = api._cpp_pmap(lambda x: 0. / x)
    # For the Cpp pmap, the first execution always goes through Python, so
    # run it once on a NaN-free input before the checked calls.
    f(jnp.array([1.]))

    with self.assertRaisesRegex(FloatingPointError, nan_msg):
      f(jnp.array([0.])).block_until_ready()

    if jax.device_count() >= 2:
      with self.assertRaisesRegex(FloatingPointError, nan_msg):
        f(jnp.array([1., 0.])).block_until_ready()

  def testPmapNoNaN(self):
    out = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
    out.block_until_ready()

  @jtu.ignore_warning(message=".*is an experimental.*")
  def testXmap(self):
    safe_div = maps.xmap(
        lambda x: 0. / x,
        in_axes=["i"],
        out_axes=["i"],
        axis_resources={"i": "x"})

    one_device = np.array(jax.local_devices()[:1])
    with jax.sharding.Mesh(one_device, ('x',)):
      with self.assertRaisesRegex(
          FloatingPointError, r"invalid value \(nan\) encountered in xmap"):
        safe_div(jnp.array([0.])).block_until_ready()

    if jax.device_count() >= 2:
      two_devices = np.array(jax.local_devices()[:2])
      with jax.sharding.Mesh(two_devices, ('x',)):
        with self.assertRaises(FloatingPointError):
          safe_div(jnp.array([1., 0.])).block_until_ready()

  @jtu.ignore_warning(message=".*is an experimental.*")
  def testPjit(self):
    if jax.device_count() < 2:
      raise SkipTest("test requires >=2 devices")

    spec = jax.sharding.PartitionSpec('x')
    sharded_div = pjit.pjit(
        lambda x: 0. / x, in_shardings=spec, out_shardings=spec)
    with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
      with self.assertRaises(FloatingPointError):
        sharded_div(jnp.array([0., 1.])).block_until_ready()

  def testDebugNansJitWithDonation(self):
    # https://github.com/google/jax/issues/12514
    zero = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      donating_div = jax.jit(lambda x: 0. / x, donate_argnums=(0,))
      donating_div(zero).block_until_ready()

  def testDebugNansPmapWithDonation(self):
    zeros = jnp.zeros((1,))
    with self.assertRaises(FloatingPointError):
      donating_div = jax.pmap(lambda x: 0. / x, donate_argnums=(0,))
      donating_div(zeros).block_until_ready()

  @jtu.ignore_warning(message=".*is an experimental.*")
  def testDebugNansPjitWithDonation(self):
    if jax.device_count() < 2:
      raise SkipTest("test requires >=2 devices")

    spec = jax.sharding.PartitionSpec('x')
    donating_div = pjit.pjit(lambda x: 0. / x,
                             in_shardings=spec,
                             out_shardings=spec,
                             donate_argnums=(0,))
    with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
      with self.assertRaises(FloatingPointError):
        donating_div(jnp.array([0., 1.])).block_until_ready()

  def testDebugNansZeroDiv(self):
    zeros = jnp.zeros(())

    # NOTE: the name `f` is asserted on below; do not rename it.
    def f(x, y):
      return x / y

    # Called eagerly, the error names the primitive.
    with self.assertRaisesRegex(
        FloatingPointError,
        r"invalid value \(nan\) encountered in jit\(true_divide\)"):
      f(zeros, zeros)

    # TODO(yashkatariya): Fix this and make true_divide appear in the name
    # again. Under jit the error currently names the enclosing function `f`
    # rather than the primitive (true_divide).
    with self.assertRaisesRegex(
        FloatingPointError,
        r"invalid value \(nan\) encountered in jit\(f\)"):
      jax.jit(f)(zeros, zeros)
class DebugInfsTest(jtu.JaxTestCase):
  """Checks that enabling `jax_debug_infs` turns infinite results into errors.

  setUp switches the flag on for every test and remembers the previous value;
  tearDown puts that value back. The final two tests additionally turn on
  NaN checking through the `jax.debug_nans` context manager.
  """

  def setUp(self):
    super().setUp()
    self.cfg = config._read("jax_debug_infs")
    config.update("jax_debug_infs", True)

  def tearDown(self):
    config.update("jax_debug_infs", self.cfg)
    super().tearDown()

  def testSingleResultPrimitiveNoInf(self):
    x = jnp.array([[1., 2.], [2., 3.]])
    jnp.tanh(x).block_until_ready()

  def testMultipleResultPrimitiveNoInf(self):
    x = jnp.array([[1., 2.], [2., 3.]])
    # eigh has several outputs; only the first one is synchronized on.
    w, _ = jnp.linalg.eigh(x)
    w.block_until_ready()

  def testJitComputationNoInf(self):
    x = jnp.array([[1., 2.], [2., 3.]])
    jitted_tanh = jax.jit(jnp.tanh)
    jitted_tanh(x).block_until_ready()

  def testSingleResultPrimitiveInf(self):
    zero = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      (1. / zero).block_until_ready()

  @jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
  def testCallDeoptimized(self, jit):
    @jit
    def f(x):
      return jax.lax.cond(
          x == 1, lambda _: np.inf, lambda _: 2., operand=None)

    # With the C++ jit, this inf-free call makes the Python code compile the
    # function, so the checked call below does not go through `cache_miss`.
    f(2)
    # The error should point at the `cond`, not at `xla_call`.
    with self.assertRaisesRegex(
        FloatingPointError, r"invalid value \(inf\) encountered in .*cond.*"):
      f(1)

  def testDebugNansDoesntCorruptCaches(self):
    # https://github.com/google/jax/issues/6614
    @jax.jit
    def f(x):
      return jnp.divide(x, x)

    # Run the same checked computation twice. A FloatingPointError is
    # tolerated on each pass; any other exception propagates and fails the test.
    for _attempt in range(2):
      try:
        with jax.debug_nans(True):
          jax.grad(f)(0.)
      except FloatingPointError:
        pass

  def testDebugNansDoesntReturnDeoptimizedResult(self):
    @jax.jit
    def f(x):
      shifted = x + 2  # avoid trivial dispatch path by adding some eqn
      return jnp.nan, shifted

    with self.assertRaisesRegex(FloatingPointError, "de-optimized"), \
        jax.debug_nans(True):
      f(3)
if __name__ == '__main__':
  # Run through absltest, with JAX's own test loader collecting the cases.
  jax_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_loader)