Rename count_jit_and_pmap_compiles to count_jit_and_pmap_lowerings

PiperOrigin-RevId: 661496993
This commit is contained in:
Yash Katariya 2024-08-09 20:03:06 -07:00 committed by jax authors
parent 7a75c96aa9
commit abc9ba00e9
7 changed files with 21 additions and 21 deletions

View File

@ -365,7 +365,7 @@ def count_aot_jit_cpp_cache_miss():
@contextmanager
def count_jit_and_pmap_compiles():
def count_jit_and_pmap_lowerings():
# No need to clear any caches since we generally jit and pmap fresh callables
# in tests.
@ -405,7 +405,7 @@ def count_subjaxpr_to_hlo_conversion(fun_name: str):
@contextmanager
def assert_num_jit_and_pmap_compilations(times):
with count_jit_and_pmap_compiles() as count:
with count_jit_and_pmap_lowerings() as count:
yield
if count[0] != times:
raise AssertionError(f"Expected exactly {times} XLA compilations, "

View File

@ -1442,7 +1442,7 @@ class JitTest(jtu.BufferDonationTestCase):
# https://github.com/google/jax/issues/9187
f = jax.jit(lambda: jnp.sin(1))
expected = f()
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
ans = jax.vmap(f, axis_size=2, out_axes=None)()
self.assertEqual(count[0], 0) # no compiles
self.assertArraysAllClose(ans, expected, check_dtypes=True)
@ -3433,11 +3433,11 @@ class APITest(jtu.JaxTestCase):
def f(x):
return jnp.sin(x)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
_ = jax.grad(f)(3.)
self.assertEqual(count[0], 2) # one for fwd, one for bwd
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
_ = jax.grad(f)(3.)
_ = jax.grad(f)(4.)
self.assertEqual(count[0], 0) # cache hits on both fwd and bwd
@ -4352,7 +4352,7 @@ class APITest(jtu.JaxTestCase):
jf = jax.jit(f)
x = jax.random.uniform(jax.random.key(0), shape=(8, 4))
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
for _ in range(5):
jax.hessian(jf)(x).block_until_ready()
@ -5929,7 +5929,7 @@ class RematTest(jtu.JaxTestCase):
# https://github.com/google/jax/issues/9661
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
_, f_lin = jax.linearize(identity, 1.)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
for _ in range(20):
f_lin(1.).block_until_ready()
self.assertEqual(count[0], 1) # cached after first execution
@ -5947,7 +5947,7 @@ class RematTest(jtu.JaxTestCase):
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
static_argnums=(1,))
_, f_vjp = jax.vjp(lambda x: identity(x, True), 1.)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd
@ -5955,7 +5955,7 @@ class RematTest(jtu.JaxTestCase):
def test_fwd_caching(self):
# see above test also
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
for _ in range(20):
y, _ = jax.vjp(identity, 1.)
y.block_until_ready()
@ -5964,7 +5964,7 @@ class RematTest(jtu.JaxTestCase):
def test_fwd_caching_static_argnums(self):
# see above test also
identity = jax.checkpoint(jax.jit(lambda x: 2 * x), static_argnums=(0,))
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
for _ in range(20):
y = identity(1.)
y.block_until_ready()

View File

@ -815,7 +815,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def test_retracing(self):
f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2))
_ = f(3.)
with jtu.count_jit_and_pmap_compiles() as count:
with jtu.count_jit_and_pmap_lowerings() as count:
_ = f(3.)
self.assertEqual(count[0], 0)

View File

@ -2590,7 +2590,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
def g(x):
return x + 2
with jtu.count_jit_and_pmap_compiles() as count:
with jtu.count_jit_and_pmap_lowerings() as count:
for x in range(10):
lax.cond(x, f, g, x)
# Should observe a maximum of 4 compiles: convert_element_type, f, g, cond

View File

@ -1200,7 +1200,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
f = jax.jit(mul, in_shardings=s)
g = jax.jit(mul, in_shardings=s2)
with jtu.count_jit_and_pmap_compiles() as count:
with jtu.count_jit_and_pmap_lowerings() as count:
out = f(np_inp)
out2 = g(np_inp2)
self.assertEqual(count[0], 1)

View File

@ -3523,7 +3523,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
arr = jax.device_put(
np.arange(16).reshape(8, 2), NamedSharding(mesh, P(None, 'x')))
with jtu.count_jit_and_pmap_compiles() as count:
with jtu.count_jit_and_pmap_lowerings() as count:
vf = jax.vmap(pjit(lambda x: x * 2, in_shardings=ns))
out = vf(arr)
self.assertIsInstance(out.sharding, NamedSharding)
@ -3867,7 +3867,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
b = jax.device_put(out_a, NamedSharding(mesh2, P('x')))
f(b) # lowering cache *hit*
with jtu.count_jit_and_pmap_compiles() as count:
with jtu.count_jit_and_pmap_lowerings() as count:
g(np.arange(8))
self.assertEqual(count[0], 1)
@ -3890,7 +3890,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
f(b) # lowering cache *miss*
with jtu.count_jit_and_pmap_compiles() as count:
with jtu.count_jit_and_pmap_lowerings() as count:
g(np.arange(8))
self.assertEqual(count[0], 2)

View File

@ -1285,7 +1285,7 @@ class PythonPmapTest(jtu.JaxTestCase):
device_count = jax.device_count()
f = self.pmap(lambda x: 3)
x = jnp.arange(device_count)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): fix this
expected = np.repeat(3, device_count)
@ -1306,7 +1306,7 @@ class PythonPmapTest(jtu.JaxTestCase):
shuffle(devices)
f = self.pmap(lambda x: 3, devices=devices)
x = jnp.arange(len(devices))
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
expected = np.repeat(3, len(devices))
@ -1342,7 +1342,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2, 3)
x = jnp.arange(math.prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
expected = 3 * np.ones(shape[:2])
@ -1368,7 +1368,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(self.pmap(lambda x: 3), devices=devices)
shape = (2, len(devices) // 2, 3)
x = jnp.arange(math.prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
expected = 3 * np.ones(shape[:2])
@ -2039,7 +2039,7 @@ class PythonPmapTest(jtu.JaxTestCase):
_, f_bwd = jax.vjp(f, x)
_ = f_bwd(x)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
_, f_bwd2 = jax.vjp(f, x)
_ = f_bwd(x)
_ = f_bwd2(x)