mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00

This allows deletion of a lot of code and leads to ~40% eager performance speedup. Benchmarks: ``` name old time/op new time/op delta eager_unary_dispatch 31.3µs ± 1% 19.4µs ± 6% -37.91% (p=0.016 n=4+5) eager_unary 32.1µs ± 0% 19.8µs ± 4% -38.26% (p=0.016 n=4+5) eager_binary_dispatch 35.9µs ± 1% 20.5µs ± 4% -42.93% (p=0.016 n=4+5) eager_binary 36.6µs ± 1% 21.1µs ± 4% -42.29% (p=0.016 n=4+5) jit_trivial_dispatch 3.87µs ± 2% 4.12µs ±25% ~ (p=1.000 n=5+5) jit_trivial 4.75µs ± 2% 4.82µs ±11% ~ (p=0.690 n=5+5) jit_simple_dispatch 2.95µs ± 2% 2.97µs ± 7% ~ (p=1.000 n=5+5) jit_simple 3.52µs ± 6% 3.51µs ± 5% ~ (p=0.841 n=5+5) jit_simple_dispatch_array 2.95µs ± 2% 2.96µs ± 6% ~ (p=1.000 n=5+5) jit_simple_array 3.46µs ± 2% 3.51µs ± 5% ~ (p=0.690 n=5+5) jit_small_matmul 3.01µs ± 1% 3.00µs ± 4% ~ (p=0.548 n=5+5) jit_big_matmul 34.0µs ±18% 35.5µs ±17% ~ (p=0.310 n=5+5) jit_simple_many_args_dispatch/num_args:10 6.93µs ± 6% 6.80µs ± 6% ~ (p=0.481 n=10+10) jit_simple_many_args_dispatch/num_args:100 47.7µs ± 7% 45.4µs ± 2% ~ (p=0.237 n=10+8) jit_simple_many_args_dispatch/num_args:1000 545µs ± 8% 516µs ± 2% ~ (p=0.101 n=10+8) jit_simple_many_args_dispatch/num_args:2000 1.12ms ± 7% 1.07ms ± 2% ~ (p=0.237 n=10+8) jit_simple_many_args/num_args:10 7.42µs ± 5% 7.23µs ± 2% ~ (p=0.173 n=10+8) jit_simple_many_args/num_args:100 48.4µs ± 7% 45.6µs ± 2% ~ (p=0.237 n=10+8) jit_simple_many_args/num_args:1000 542µs ± 6% 524µs ± 8% ~ (p=0.089 n=10+10) jit_simple_many_args/num_args:2000 1.12ms ± 7% 1.08ms ± 1% ~ (p=0.068 n=10+8) jit_simple_pruned_args_dispatch_10 4.79µs ± 8% 4.98µs ±10% ~ (p=0.421 n=5+5) jit_simple_pruned_args_10 5.32µs ± 6% 5.30µs ± 4% ~ (p=1.000 n=5+5) jit_simple_pruned_args_dispatch_100 24.7µs ± 6% 23.8µs ± 8% ~ (p=0.548 n=5+5) jit_simple_pruned_args_100 25.2µs ± 6% 24.4µs ± 8% ~ (p=0.690 n=5+5) jit_simple_pruned_args_dispatch_1000 238µs ± 7% 232µs ± 8% ~ (p=0.841 n=5+5) jit_simple_pruned_args_1000 240µs ± 7% 234µs ± 8% ~ (p=1.000 n=5+5) jit_simple_pruned_args_dispatch_2000 
516µs ± 6% 497µs ± 1% ~ (p=0.413 n=5+4) jit_simple_pruned_args_2000 517µs ± 6% 505µs ± 7% ~ (p=0.690 n=5+5) jit_dispatch_without_transfer 719µs ± 9% 751µs ± 8% ~ (p=0.222 n=5+5) jit_dispatch_with_transfer 799µs ±14% 793µs ± 9% ~ (p=1.000 n=5+5) pmap_trivial_2_devices 49.9µs ±40% 48.2µs ±42% ~ (p=0.841 n=5+5) pmap_trivial_dispatch_8_devices 74.5µs ±24% 78.9µs ±29% ~ (p=0.421 n=5+5) pmap_trivial_8_devices 79.3µs ± 6% 82.7µs ±20% ~ (p=0.841 n=5+5) pmap_simple_2_devices 47.1µs ±17% 49.1µs ±20% ~ (p=0.548 n=5+5) pmap_simple_dispatch_8_devices 73.4µs ±16% 76.8µs ±21% ~ (p=0.690 n=5+5) pmap_simple_8_devices 76.0µs ±10% 80.6µs ±29% ~ (p=1.000 n=5+5) pmap_simple_dispatch_8_devices_100_args 1.12ms ±22% 1.08ms ±42% ~ (p=0.841 n=5+5) pmap_simple_8_devices_100_args 12.5ms ± 8% 12.8ms ±10% ~ (p=1.000 n=5+5) sda_index_1 413µs ± 1% 686µs ± 4% +66.08% (p=0.008 n=5+5) sda_index_2 850µs ± 1% 1378µs ± 4% +62.02% (p=0.008 n=5+5) sda_index_8 3.60ms ± 1% 5.69ms ± 4% +58.00% (p=0.008 n=5+5) bench_shaped_abstractify 300µs ± 1% 305µs ± 3% ~ (p=0.056 n=5+5) bench_xla_abstractify_scalar_int 6.45µs ± 1% 6.50µs ± 3% ~ (p=0.548 n=5+5) bench_xla_abstractify_scalar_float 3.73µs ± 1% 3.73µs ± 3% ~ (p=0.690 n=5+5) bench_xla_abstractify_scalar_numpy_int32 4.97µs ± 1% 4.83µs ± 3% ~ (p=0.095 n=5+5) bench_xla_abstractify_scalar_numpy_uint32 4.91µs ± 1% 4.75µs ± 0% -3.30% (p=0.016 n=5+4) bench_xla_abstractify_numpy_random 4.34µs ± 2% 4.31µs ± 3% ~ (p=0.310 n=5+5) bench_xla_abstractify_numpy_arange_100_float32 3.94µs ± 1% 3.93µs ± 3% ~ (p=0.548 n=5+5) bench_xla_abstractify_enum 6.85µs ± 1% 7.06µs ± 7% +3.07% (p=0.032 n=5+5) bench_are_op_shardings_equal 26.9µs ± 2% 27.0µs ± 3% ~ (p=0.841 n=5+5) bench_pjit_check_aval_sharding 691µs ± 2% 711µs ±13% ~ (p=0.841 n=5+5) bench_addressable_shards_index 656ns ± 4% 688ns ± 9% ~ (p=0.095 n=5+5) bench_remat_eager_retracing_overheads 12.7ms ± 4% 10.7ms ± 1% -15.48% (p=0.016 n=5+4) bench_remat_eager_retracing_overheads_static_argnums 13.0ms ± 2% 11.3ms ± 6% -13.71% 
(p=0.008 n=5+5) bench_slicing_compilation 12.1ms ± 1% 12.3ms ± 4% ~ (p=0.690 n=5+5) bench_slicing_compilation2 11.3ms ± 0% 11.5ms ± 6% ~ (p=0.690 n=5+5) bench_repeated_static_indexing 62.5ms ± 2% 40.8ms ± 8% -34.77% (p=0.008 n=5+5) bench_repeated_static_slicing 46.7ms ± 1% 31.4ms ± 2% -32.76% (p=0.008 n=5+5) pjit_simple_1_device/num_args:1 2.72µs ± 2% 2.68µs ± 5% ~ (p=0.151 n=5+5) pjit_simple_1_device/num_args:10 12.6µs ± 7% 12.3µs ± 3% ~ (p=0.310 n=5+5) pjit_simple_1_device/num_args:100 109µs ± 3% 108µs ± 4% ~ (p=0.548 n=5+5) pjit_simple_4_device/num_args:1 38.0µs ±26% 36.8µs ±19% ~ (p=0.690 n=5+5) pjit_simple_4_device/num_args:10 93.3µs ±19% 96.6µs ±23% ~ (p=0.841 n=5+5) pjit_simple_4_device/num_args:100 730µs ±16% 698µs ±48% ~ (p=0.841 n=5+5) pjit_aot_1_device/num_args:1 3.29µs ± 2% 3.12µs ± 4% -5.24% (p=0.016 n=4+5) pjit_aot_1_device/num_args:10 13.0µs ± 1% 12.7µs ± 2% ~ (p=0.063 n=4+5) pjit_aot_1_device/num_args:100 111µs ± 5% 110µs ±11% ~ (p=0.421 n=5+5) pjit_aot_4_device/num_args:1 38.4µs ±19% 38.9µs ±24% ~ (p=1.000 n=5+5) pjit_aot_4_device/num_args:10 91.3µs ±15% 96.9µs ±29% ~ (p=0.548 n=5+5) pjit_aot_4_device/num_args:100 676µs ±20% 689µs ±41% ~ (p=0.841 n=5+5) host_local_array_to_global_array 196µs ± 6% 194µs ± 4% ~ (p=0.548 n=5+5) device_put 50.8µs ± 1% 50.7µs ± 4% ~ (p=0.413 n=4+5) device_put_sharded 176µs ± 0% 177µs ± 4% ~ (p=0.190 n=4+5) device_get_8_devices 3.96ms ± 4% 4.03ms ± 7% ~ (p=0.413 n=4+5) np_asarray_8_devices 3.34ms ±18% 3.30ms ±10% ~ (p=0.548 n=5+5) jax_array_arrays_8_devices 5.01ms ±10% 5.09ms ±21% ~ (p=0.421 n=5+5) batch_inplace_while_scatter 440µs ± 1% 439µs ± 1% ~ (p=0.421 n=5+5) batch_inplace_while_dynamic_update_slice 454µs ± 0% 457µs ± 1% ~ (p=0.905 n=4+5) serial_dot_products 4.51µs ± 3% 4.41µs ± 2% ~ (p=0.151 n=5+5) bench_make_array_from_callback_fully_replicated_sharding 26.6µs ± 1% 27.0µs ± 2% ~ (p=0.056 n=5+5) ``` PiperOrigin-RevId: 586505950
281 lines
8.2 KiB
Python
281 lines
8.2 KiB
Python
# Copyright 2019 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests for the jax_debug_nans and jax_debug_infs config options."""
|
|
|
|
from absl.testing import absltest
|
|
|
|
import jax
|
|
import numpy as np
|
|
from unittest import SkipTest
|
|
|
|
from jax._src import api
|
|
from jax._src import test_util as jtu
|
|
from jax import numpy as jnp
|
|
from jax.experimental import pjit, maps
|
|
|
|
from jax import config
|
|
# Parse absl command-line flags into JAX's config before any test runs.
config.parse_flags_with_absl()
|
|
|
|
|
|
class DebugNaNsTest(jtu.JaxTestCase):
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.cfg = config._read("jax_debug_nans")
|
|
config.update("jax_debug_nans", True)
|
|
|
|
def tearDown(self):
|
|
config.update("jax_debug_nans", self.cfg)
|
|
super().tearDown()
|
|
|
|
def testSinc(self):
|
|
# Regression test for #6936
|
|
self.assertEqual(jnp.sinc(0.0), 1.0)
|
|
|
|
def testSingleResultPrimitiveNoNaN(self):
|
|
A = jnp.array([[1., 2.], [2., 3.]])
|
|
ans = jnp.tanh(A)
|
|
ans.block_until_ready()
|
|
|
|
def testMultipleResultPrimitiveNoNaN(self):
|
|
A = jnp.array([[1., 2.], [2., 3.]])
|
|
ans, _ = jnp.linalg.eigh(A)
|
|
ans.block_until_ready()
|
|
|
|
def testJitComputationNoNaN(self):
|
|
A = jnp.array([[1., 2.], [2., 3.]])
|
|
ans = jax.jit(jnp.tanh)(A)
|
|
ans.block_until_ready()
|
|
|
|
def testJitComputationNaN(self):
|
|
A = jnp.array(0.)
|
|
with self.assertRaises(FloatingPointError):
|
|
ans = jax.jit(lambda x: 0. / x)(A)
|
|
ans.block_until_ready()
|
|
|
|
def testJitComputationNaNContextManager(self):
|
|
config.update("jax_debug_nans", False)
|
|
A = jnp.array(0.)
|
|
f = jax.jit(lambda x: 0. / x)
|
|
ans = f(A)
|
|
ans = f(A)
|
|
with self.assertRaises(FloatingPointError):
|
|
with jax.debug_nans(True):
|
|
ans = f(A)
|
|
ans.block_until_ready()
|
|
|
|
def testSingleResultPrimitiveNaN(self):
|
|
A = jnp.array(0.)
|
|
with self.assertRaises(FloatingPointError):
|
|
ans = 0. / A
|
|
ans.block_until_ready()
|
|
|
|
@jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
|
|
def testCallDeoptimized(self, jit):
|
|
@jit
|
|
def f(x):
|
|
return jax.lax.cond(
|
|
x == 1, lambda _: np.nan, lambda _: 2., operand=None)
|
|
|
|
# This makes sure, when using the C++ jit, that the Python code has been
|
|
# run to compile, and the next call won't go through `cache_miss`.
|
|
f(2)
|
|
# 'cond' not 'xla_call'
|
|
msg = r"invalid value \(nan\) encountered in .*cond.*"
|
|
with self.assertRaisesRegex(FloatingPointError, msg):
|
|
f(1)
|
|
|
|
def testPmap(self):
|
|
pmap_funcs = [api._cpp_pmap]
|
|
|
|
for pmap in pmap_funcs:
|
|
f = pmap(lambda x: 0. / x)
|
|
# For the Cpp pmap, the first execution always goes through Python.
|
|
f(jnp.array([1.]))
|
|
|
|
with self.assertRaisesRegex(
|
|
FloatingPointError,
|
|
r"invalid value \(nan\) encountered in parallel computation"):
|
|
ans = f(jnp.array([0.]))
|
|
ans.block_until_ready()
|
|
|
|
if jax.device_count() >= 2:
|
|
with self.assertRaisesRegex(
|
|
FloatingPointError,
|
|
r"invalid value \(nan\) encountered in parallel computation"):
|
|
ans = f(jnp.array([1., 0.]))
|
|
ans.block_until_ready()
|
|
|
|
def testPmapNoNaN(self):
|
|
ans = jax.pmap(lambda x: 0. / x)(jnp.array([1.]))
|
|
ans.block_until_ready()
|
|
|
|
@jtu.ignore_warning(message=".*is an experimental.*")
|
|
def testXmap(self):
|
|
|
|
f = maps.xmap(
|
|
lambda x: 0. / x,
|
|
in_axes=["i"],
|
|
out_axes=["i"],
|
|
axis_resources={"i": "x"})
|
|
|
|
with jax.sharding.Mesh(np.array(jax.local_devices()[:1]), ('x',)):
|
|
with self.assertRaisesRegex(
|
|
FloatingPointError,
|
|
r"invalid value \(nan\) encountered in xmap"):
|
|
ans = f(jnp.array([0.]))
|
|
ans.block_until_ready()
|
|
|
|
if jax.device_count() >= 2:
|
|
with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
|
|
with self.assertRaises(FloatingPointError):
|
|
ans = f(jnp.array([1., 0.]))
|
|
ans.block_until_ready()
|
|
|
|
@jtu.ignore_warning(message=".*is an experimental.*")
|
|
def testPjit(self):
|
|
if jax.device_count() < 2:
|
|
raise SkipTest("test requires >=2 devices")
|
|
|
|
p = jax.sharding.PartitionSpec('x')
|
|
f = pjit.pjit(lambda x: 0. / x, in_shardings=p, out_shardings=p)
|
|
|
|
with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
|
|
with self.assertRaises(FloatingPointError):
|
|
ans = f(jnp.array([0., 1.]))
|
|
ans.block_until_ready()
|
|
|
|
def testDebugNansJitWithDonation(self):
|
|
# https://github.com/google/jax/issues/12514
|
|
a = jnp.array(0.)
|
|
with self.assertRaises(FloatingPointError):
|
|
ans = jax.jit(lambda x: 0. / x, donate_argnums=(0,))(a)
|
|
ans.block_until_ready()
|
|
|
|
def testDebugNansPmapWithDonation(self):
|
|
a = jnp.zeros((1,))
|
|
with self.assertRaises(FloatingPointError):
|
|
ans = jax.pmap(lambda x: 0. / x, donate_argnums=(0,))(a)
|
|
ans.block_until_ready()
|
|
|
|
@jtu.ignore_warning(message=".*is an experimental.*")
|
|
def testDebugNansPjitWithDonation(self):
|
|
if jax.device_count() < 2:
|
|
raise SkipTest("test requires >=2 devices")
|
|
|
|
p = jax.sharding.PartitionSpec('x')
|
|
f = pjit.pjit(lambda x: 0. / x,
|
|
in_shardings=p,
|
|
out_shardings=p,
|
|
donate_argnums=(0,))
|
|
|
|
with jax.sharding.Mesh(np.array(jax.local_devices()[:2]), ('x',)):
|
|
with self.assertRaises(FloatingPointError):
|
|
ans = f(jnp.array([0., 1.]))
|
|
ans.block_until_ready()
|
|
|
|
def testDebugNansZeroDiv(self):
|
|
inp = jnp.zeros(())
|
|
def f(x, y):
|
|
return x / y
|
|
|
|
with self.assertRaisesRegex(
|
|
FloatingPointError,
|
|
r"invalid value \(nan\) encountered in jit\(true_divide\)"):
|
|
f(inp, inp)
|
|
|
|
# TODO(yashkatariya): Fix this and make true_divide appear in the name again.
|
|
# Instead of `f` showing up in the error, the name should be of the
|
|
# primitive (true_divide) in this case.
|
|
with self.assertRaisesRegex(
|
|
FloatingPointError,
|
|
r"invalid value \(nan\) encountered in jit\(f\)"):
|
|
jax.jit(f)(inp, inp)
|
|
|
|
|
|
class DebugInfsTest(jtu.JaxTestCase):
  """Tests run with the ``jax_debug_infs`` config option switched on.

  A few regression tests here additionally enable NaN checking locally
  through the ``jax.debug_nans`` context manager.
  """

  def setUp(self):
    super().setUp()
    previous = config._read("jax_debug_infs")
    self.cfg = previous  # put back by tearDown
    config.update("jax_debug_infs", True)

  def tearDown(self):
    saved = self.cfg
    config.update("jax_debug_infs", saved)
    super().tearDown()

  def testSingleResultPrimitiveNoInf(self):
    mat = jnp.array([[1., 2.], [2., 3.]])
    jnp.tanh(mat).block_until_ready()

  def testMultipleResultPrimitiveNoInf(self):
    mat = jnp.array([[1., 2.], [2., 3.]])
    eigvals, _ = jnp.linalg.eigh(mat)
    eigvals.block_until_ready()

  def testJitComputationNoInf(self):
    mat = jnp.array([[1., 2.], [2., 3.]])
    jitted_tanh = jax.jit(jnp.tanh)
    jitted_tanh(mat).block_until_ready()

  def testSingleResultPrimitiveInf(self):
    zero = jnp.array(0.)
    with self.assertRaises(FloatingPointError):
      (1. / zero).block_until_ready()

  @jtu.sample_product(jit=jtu.JIT_IMPLEMENTATION)
  def testCallDeoptimized(self, jit):
    # The error should name the 'cond' primitive, not 'xla_call'.
    expected = r"invalid value \(inf\) encountered in .*cond.*"

    @jit
    def f(x):
      return jax.lax.cond(
          x == 1, lambda _: np.inf, lambda _: 2., operand=None)

    # Warm-up call with a finite result: under the C++ jit this ensures the
    # Python path has already compiled f, so the next call does not go
    # through `cache_miss`.
    f(2)
    with self.assertRaisesRegex(FloatingPointError, expected):
      f(1)

  def testDebugNansDoesntCorruptCaches(self):
    # https://github.com/google/jax/issues/6614
    @jax.jit
    def f(x):
      return jnp.divide(x, x)

    # Run the same computation twice. A FloatingPointError is tolerated on
    # each pass; any other exception (e.g. from a corrupted cache on the
    # second pass) propagates and fails the test.
    for _attempt in range(2):
      try:
        with jax.debug_nans(True):
          jax.grad(f)(0.)
      except FloatingPointError:
        pass

  def testDebugNansDoesntReturnDeoptimizedResult(self):
    @jax.jit
    def f(x):
      y = x + 2  # avoid trivial dispatch path by adding some eqn
      return jnp.nan, y

    raises_deopt = self.assertRaisesRegex(FloatingPointError, "de-optimized")
    with raises_deopt, jax.debug_nans(True):
      f(3)
|
|
|
|
|
|
if __name__ == '__main__':
  # Run the suite with JAX's loader, which understands its parameterized tests.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|