# Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
# File: rocm_jax/tests/debug_nans_test.py (296 lines, 8.6 KiB, Python)

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
import numpy as np
from unittest import SkipTest
from jax._src import api
from jax._src import test_util as jtu
from jax import numpy as jnp
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P
# Hook JAX's config options into absl flag parsing before any test runs.
# NOTE(review): presumably this lets options such as `jax_debug_nans` /
# `jax_debug_infs` be set from the test command line; confirm against
# `jax.config` documentation.
jax.config.parse_flags_with_absl()


@jtu.with_config(jax_debug_nans=True)
class DebugNaNsTest(jtu.JaxTestCase):
  """Tests for JAX's NaN checker (`jax_debug_nans`).

  The class decorator enables `jax_debug_nans` for every test; individual
  tests may toggle it locally with `jax.debug_nans(...)`. Each test checks
  either that a NaN-free computation runs normally, or that producing a NaN
  raises `FloatingPointError` (optionally with a specific message).
  """

  def testSinc(self):
    # Regression test for #6936: with NaN checking on, sinc(0.0) must return
    # 1.0 rather than raise.
    self.assertEqual(jnp.sinc(0.0), 1.0)

  def testSingleResultPrimitiveNoNaN(self):
    A = jnp.array([[1., 2.], [2., 3.]])
    ans = jnp.tanh(A)
    ans.block_until_ready()

  def testMultipleResultPrimitiveNoNaN(self):
    A = jnp.array([[1., 2.], [2., 3.]])
    # eigh has two outputs; only the first is kept and synced.
    ans, _ = jnp.linalg.eigh(A)
    ans.block_until_ready()

  def testJitComputationNoNaN(self):
    A = jnp.array([[1., 2.], [2., 3.]])
    ans = jax.jit(jnp.tanh)(A)
    ans.block_until_ready()

  def testJitComputationNaN(self):
    A = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      ans = jax.jit(lambda x: 0. / x)(A)  # 0/0 -> NaN
      ans.block_until_ready()

  @jax.debug_nans(False)
  def testJitComputationNaNContextManager(self):
    # NaN checking is off for this test, so the first two calls return NaN
    # without raising. Re-enabling it locally must make the same (already
    # called) jitted function raise.
    A = jnp.array(0.)
    f = jax.jit(lambda x: 0. / x)
    ans = f(A)
    ans = f(A)
    with self.assertRaises(FloatingPointError):
      with jax.debug_nans(True):
        ans = f(A)
      ans.block_until_ready()

  def testSingleResultPrimitiveNaN(self):
    A = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      ans = 0. / A
      ans.block_until_ready()

  @jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
  def testCallDeoptimized(self, jit):
    @jit
    def f(x):
      return jax.lax.cond(
          x == 1, lambda _: np.nan, lambda _: 2., operand=None)

    # This makes sure, when using the C++ jit, that the Python code has been
    # run to compile, and the next call won't go through `cache_miss`.
    f(2)
    # The error must name the 'cond' primitive, not 'xla_call'.
    msg = r"invalid value \(nan\) encountered in .*cond.*"
    with self.assertRaisesRegex(FloatingPointError, msg):
      f(1)

  def testShardMap(self):
    mesh = jax.make_mesh((1,), ('x',))
    # `in_specs` gives one spec per positional argument, hence the 1-tuple.
    f = shard_map(lambda x: 0. / x, mesh=mesh, in_specs=(P('x'),),
                  out_specs=P('x'))

    # Warm-up call on NaN-free input (mirrors testPmap), so the NaN check
    # below is exercised on a second call rather than the first.
    f(jnp.array([1.]))

    with self.assertRaisesRegex(
        FloatingPointError,
        r"Invalid value \(nan\) encountered in sharded computation"):
      ans = f(jnp.array([0.]))
      ans.block_until_ready()

    # NOTE(review): the mesh above always has a single device, so this branch
    # only varies the input length; confirm whether a 2-device mesh was
    # intended here.
    if jax.device_count() >= 2:
      with self.assertRaisesRegex(
          FloatingPointError,
          r"Invalid value \(nan\) encountered in sharded computation"):
        ans = f(jnp.array([1., 0.]))
        ans.block_until_ready()

  def testPmap(self):
    # NOTE(review): only the C++ pmap is listed; the loop looks like a
    # leftover from when more than one pmap implementation was tested.
    pmap_funcs = [api._cpp_pmap]

    for pmap in pmap_funcs:
      f = pmap(lambda x: 0. / x)
      # For the Cpp pmap, the first execution always goes through Python.
      f(jnp.array([1.]))

      with self.assertRaisesRegex(
          FloatingPointError,
          r"invalid value \(nan\) encountered in div"):
        ans = f(jnp.array([0.]))
        ans.block_until_ready()

      if jax.device_count() >= 2:
        with self.assertRaisesRegex(
            FloatingPointError,
            r"Invalid value \(nan\) encountered in parallel computation"):
          ans = f(jnp.array([1., 0.]))
          ans.block_until_ready()

  def testGradPmap(self):
    @jax.jit
    def f(x):
      y = x**2
      return jnp.log(y)

    # The forward pass at x=0 gives log(0) = -inf, which the NaN checker
    # ignores. The NaN is expected in the backward pass, in a `mul`.
    _, f_vjp = jax.vjp(jax.pmap(f), jnp.zeros([1]))

    with self.assertRaisesRegex(
        FloatingPointError,
        r"invalid value \(nan\) encountered in mul\nWhen differentiating"):
      ans, = f_vjp(jnp.ones([1]))
      ans.block_until_ready()

  def testGradShardMap(self):
    @jax.jit
    def f(x):
      y = x**2
      return jnp.log(y)

    mesh = jax.make_mesh((1,), ('x',))
    # One spec per positional argument of `f`, hence the 1-tuple.
    shmap_f = shard_map(f, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'))
    _, f_vjp = jax.vjp(shmap_f, jnp.zeros([1]))

    with self.assertRaisesRegex(
        FloatingPointError, r"Invalid value \(nan\) encountered"):
      ans, = f_vjp(jnp.ones([1]))
      ans.block_until_ready()

  def testPmapNoNaN(self):
    ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
    ans.block_until_ready()

  @jtu.ignore_warning(message=".*is an experimental.*")
  def testPjit(self):
    if jax.device_count() < 2:
      raise SkipTest("test requires >=2 devices")

    p = jax.sharding.PartitionSpec('x')
    f = pjit.pjit(lambda x: 0. / x, in_shardings=p, out_shardings=p)

    with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
      with self.assertRaises(FloatingPointError):
        ans = f(jnp.array([0., 1.]))  # the 0. element yields 0/0 -> NaN
        ans.block_until_ready()

  def testDebugNansJitWithDonation(self):
    # https://github.com/jax-ml/jax/issues/12514
    # NaN checking must still raise when the input buffer is donated.
    a = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      ans = jax.jit(lambda x: 0. / x, donate_argnums=(0,))(a)
      ans.block_until_ready()

  def testDebugNansPmapWithDonation(self):
    a = jnp.zeros((1,))
    with self.assertRaises(FloatingPointError):
      ans = jax.pmap(lambda x: 0. / x, donate_argnums=(0,))(a)
      ans.block_until_ready()

  @jtu.ignore_warning(message=".*is an experimental.*")
  def testDebugNansPjitWithDonation(self):
    if jax.device_count() < 2:
      raise SkipTest("test requires >=2 devices")

    p = jax.sharding.PartitionSpec('x')
    f = pjit.pjit(lambda x: 0. / x,
                  in_shardings=p,
                  out_shardings=p,
                  donate_argnums=(0,))

    with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
      with self.assertRaises(FloatingPointError):
        ans = f(jnp.array([0., 1.]))
        ans.block_until_ready()

  def testDebugNansZeroDiv(self):
    # 0/0 must be reported as coming from `div` both eagerly and under jit.
    inp = jnp.zeros(())
    def f(x, y):
      return x / y

    with self.assertRaisesRegex(
        FloatingPointError,
        r"invalid value \(nan\) encountered in div"):
      f(inp, inp)

    with self.assertRaisesRegex(
        FloatingPointError,
        r"invalid value \(nan\) encountered in div"):
      jax.jit(f)(inp, inp)

  def testDebugNansInput(self):
    @jax.jit
    def f(x):
      return x * 3.

    # The NaN is passed in as the argument; per the regex, the error message
    # is expected to point at the input as its source.
    with self.assertRaisesRegex(FloatingPointError, "the de-optimized function did not .*input"):
      f(np.nan)


@jtu.with_config(jax_debug_infs=True)
class DebugInfsTest(jtu.JaxTestCase):
  """Tests for JAX's Inf checker (`jax_debug_infs`).

  The class decorator enables `jax_debug_infs` for every test. The last two
  tests also switch on `jax.debug_nans` locally to exercise the NaN checker's
  caching and de-optimization paths.
  """

  def testSingleResultPrimitiveNoInf(self):
    finite = jnp.array([[1., 2.], [2., 3.]])
    out = jnp.tanh(finite)
    out.block_until_ready()

  def testMultipleResultPrimitiveNoInf(self):
    finite = jnp.array([[1., 2.], [2., 3.]])
    eigvals, _eigvecs = jnp.linalg.eigh(finite)
    eigvals.block_until_ready()

  def testJitComputationNoInf(self):
    finite = jnp.array([[1., 2.], [2., 3.]])
    jitted_tanh = jax.jit(jnp.tanh)
    out = jitted_tanh(finite)
    out.block_until_ready()

  def testSingleResultPrimitiveInf(self):
    zero = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      out = 1. / zero  # 1/0 -> inf
      out.block_until_ready()

  @jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
  def testCallDeoptimized(self, jit):
    @jit
    def pick(x):
      return jax.lax.cond(
          x == 1, lambda _: np.inf, lambda _: 2., operand=None)

    # Warm-up on the finite branch: with the C++ jit this guarantees the
    # Python path has already compiled `pick`, so the next call does not go
    # through `cache_miss`.
    pick(2)
    # The error must name the 'cond' primitive, not 'xla_call'.
    expected = r"invalid value \(inf\) encountered in .*cond.*"
    with self.assertRaisesRegex(FloatingPointError, expected):
      pick(1)

  def testDebugNansDoesntCorruptCaches(self):
    # https://github.com/jax-ml/jax/issues/6614
    @jax.jit
    def f(x):
      return jnp.divide(x, x)

    # The same failing gradient runs twice. FloatingPointError is expected
    # and deliberately ignored; any other exception (presumably from caches
    # left in a bad state by the first attempt) fails the test.
    for _attempt in range(2):
      try:
        with jax.debug_nans(True):
          jax.grad(f)(0.)
      except FloatingPointError:
        pass

  def testDebugNansDoesntReturnDeoptimizedResult(self):
    @jax.jit
    def f(x):
      y = x + 2  # avoid trivial dispatch path by adding some eqn
      return jnp.nan, y

    # The NaN is a literal output rather than the result of an op; the call
    # must raise with a message mentioning the literal instead of returning.
    with self.assertRaisesRegex(
        FloatingPointError,
        "the de-optimized function did not .*literal"), jax.debug_nans(True):
      f(3)


if __name__ == "__main__":
  # Hand absltest JAX's own loader, so that tests are discovered and
  # loaded the way the JAX test suite expects.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)