mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Rename count_jit_and_pmap_compiles
to count_jit_and_pmap_lowerings
PiperOrigin-RevId: 661496993
This commit is contained in:
parent
7a75c96aa9
commit
abc9ba00e9
@ -365,7 +365,7 @@ def count_aot_jit_cpp_cache_miss():
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_jit_and_pmap_compiles():
|
||||
def count_jit_and_pmap_lowerings():
|
||||
# No need to clear any caches since we generally jit and pmap fresh callables
|
||||
# in tests.
|
||||
|
||||
@ -405,7 +405,7 @@ def count_subjaxpr_to_hlo_conversion(fun_name: str):
|
||||
|
||||
@contextmanager
|
||||
def assert_num_jit_and_pmap_compilations(times):
|
||||
with count_jit_and_pmap_compiles() as count:
|
||||
with count_jit_and_pmap_lowerings() as count:
|
||||
yield
|
||||
if count[0] != times:
|
||||
raise AssertionError(f"Expected exactly {times} XLA compilations, "
|
||||
|
@ -1442,7 +1442,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
# https://github.com/google/jax/issues/9187
|
||||
f = jax.jit(lambda: jnp.sin(1))
|
||||
expected = f()
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
ans = jax.vmap(f, axis_size=2, out_axes=None)()
|
||||
self.assertEqual(count[0], 0) # no compiles
|
||||
self.assertArraysAllClose(ans, expected, check_dtypes=True)
|
||||
@ -3433,11 +3433,11 @@ class APITest(jtu.JaxTestCase):
|
||||
def f(x):
|
||||
return jnp.sin(x)
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
_ = jax.grad(f)(3.)
|
||||
self.assertEqual(count[0], 2) # one for fwd, one for bwd
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
_ = jax.grad(f)(3.)
|
||||
_ = jax.grad(f)(4.)
|
||||
self.assertEqual(count[0], 0) # cache hits on both fwd and bwd
|
||||
@ -4352,7 +4352,7 @@ class APITest(jtu.JaxTestCase):
|
||||
jf = jax.jit(f)
|
||||
x = jax.random.uniform(jax.random.key(0), shape=(8, 4))
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
for _ in range(5):
|
||||
jax.hessian(jf)(x).block_until_ready()
|
||||
|
||||
@ -5929,7 +5929,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
# https://github.com/google/jax/issues/9661
|
||||
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
|
||||
_, f_lin = jax.linearize(identity, 1.)
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
for _ in range(20):
|
||||
f_lin(1.).block_until_ready()
|
||||
self.assertEqual(count[0], 1) # cached after first execution
|
||||
@ -5947,7 +5947,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
|
||||
static_argnums=(1,))
|
||||
_, f_vjp = jax.vjp(lambda x: identity(x, True), 1.)
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
for _ in range(20):
|
||||
f_vjp(1.)[0].block_until_ready()
|
||||
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd
|
||||
@ -5955,7 +5955,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
def test_fwd_caching(self):
|
||||
# see above test also
|
||||
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
for _ in range(20):
|
||||
y, _ = jax.vjp(identity, 1.)
|
||||
y.block_until_ready()
|
||||
@ -5964,7 +5964,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
def test_fwd_caching_static_argnums(self):
|
||||
# see above test also
|
||||
identity = jax.checkpoint(jax.jit(lambda x: 2 * x), static_argnums=(0,))
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
for _ in range(20):
|
||||
y = identity(1.)
|
||||
y.block_until_ready()
|
||||
|
@ -815,7 +815,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def test_retracing(self):
|
||||
f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2))
|
||||
_ = f(3.)
|
||||
with jtu.count_jit_and_pmap_compiles() as count:
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
_ = f(3.)
|
||||
self.assertEqual(count[0], 0)
|
||||
|
||||
|
@ -2590,7 +2590,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
def g(x):
|
||||
return x + 2
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count:
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
for x in range(10):
|
||||
lax.cond(x, f, g, x)
|
||||
# Should observe a maximum of 4 compiles: convert_element_type, f, g, cond
|
||||
|
@ -1200,7 +1200,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
f = jax.jit(mul, in_shardings=s)
|
||||
g = jax.jit(mul, in_shardings=s2)
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count:
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
out = f(np_inp)
|
||||
out2 = g(np_inp2)
|
||||
self.assertEqual(count[0], 1)
|
||||
|
@ -3523,7 +3523,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
arr = jax.device_put(
|
||||
np.arange(16).reshape(8, 2), NamedSharding(mesh, P(None, 'x')))
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count:
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
vf = jax.vmap(pjit(lambda x: x * 2, in_shardings=ns))
|
||||
out = vf(arr)
|
||||
self.assertIsInstance(out.sharding, NamedSharding)
|
||||
@ -3867,7 +3867,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
b = jax.device_put(out_a, NamedSharding(mesh2, P('x')))
|
||||
f(b) # lowering cache *hit*
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count:
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
g(np.arange(8))
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
@ -3890,7 +3890,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
|
||||
f(b) # lowering cache *miss*
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count:
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
g(np.arange(8))
|
||||
self.assertEqual(count[0], 2)
|
||||
|
||||
|
@ -1285,7 +1285,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
device_count = jax.device_count()
|
||||
f = self.pmap(lambda x: 3)
|
||||
x = jnp.arange(device_count)
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
ans = f(x)
|
||||
# self.assertEqual(count[0], 0) # TODO(mattjj): fix this
|
||||
expected = np.repeat(3, device_count)
|
||||
@ -1306,7 +1306,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
shuffle(devices)
|
||||
f = self.pmap(lambda x: 3, devices=devices)
|
||||
x = jnp.arange(len(devices))
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
ans = f(x)
|
||||
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
|
||||
expected = np.repeat(3, len(devices))
|
||||
@ -1342,7 +1342,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f = self.pmap(self.pmap(lambda x: 3))
|
||||
shape = (2, jax.device_count() // 2, 3)
|
||||
x = jnp.arange(math.prod(shape)).reshape(shape)
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
ans = f(x)
|
||||
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
|
||||
expected = 3 * np.ones(shape[:2])
|
||||
@ -1368,7 +1368,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f = self.pmap(self.pmap(lambda x: 3), devices=devices)
|
||||
shape = (2, len(devices) // 2, 3)
|
||||
x = jnp.arange(math.prod(shape)).reshape(shape)
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
ans = f(x)
|
||||
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
|
||||
expected = 3 * np.ones(shape[:2])
|
||||
@ -2039,7 +2039,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
_, f_bwd = jax.vjp(f, x)
|
||||
_ = f_bwd(x)
|
||||
|
||||
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
_, f_bwd2 = jax.vjp(f, x)
|
||||
_ = f_bwd(x)
|
||||
_ = f_bwd2(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user