Don't monkey-patch functions in test_util to count events for tests.

This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_util registers a callback that is invoked whenever an event is received. This avoids the need to make thread-unsafe global monkey-patches.
This commit is contained in:
Peter Hawkins 2024-12-11 16:54:52 -05:00
parent 3630756e87
commit 62e66b684b
25 changed files with 216 additions and 324 deletions

View File

@ -37,10 +37,10 @@ class AttrsTests(jtu.JaxTestCase):
jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,))
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 1) # compiles once the first time
self.assertEqual(count(), 1) # compiles once the first time
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 0) # cache hit
self.assertEqual(count(), 0) # cache hit
def test_array_attr_no_jit(self):
with jax.disable_jit():

View File

@ -33,6 +33,7 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import errors
from jax._src import profiler
from jax._src import util
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
@ -1131,6 +1132,7 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
def _array_shard_arg(xs, shardings, layouts, copy_semantics):
util.test_event("_array_shard_arg")
results = []
batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
batch_cs = []
@ -1168,6 +1170,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
results.append(
shard_sharded_device_array_slow_path(x, devices, indices, sharding))
util.test_event("batched_copy_array")
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
batch_xs, batch_devs, batch_shardings, batch_cs)
for i, copy_out in safe_zip(batch_indices, copy_outs):

View File

@ -94,6 +94,7 @@ def apply_primitive(prim, *args, **params):
@util.cache()
def xla_primitive_callable(prim: core.Primitive, **params):
util.test_event("xla_primitive_callable_cache_miss")
def prim_fun(*args):
with config.eager_constant_folding(False):
return prim.bind(*args, **params)

View File

@ -1085,6 +1085,7 @@ def lower_jaxpr_to_module(
Handles the quirks of the argument/return value passing conventions of the
runtime.
"""
util.test_event("lower_jaxpr_to_module")
platforms = tuple(map(xb.canonicalize_platform, platforms))
in_avals = (jaxpr.in_avals if arg_shardings is None else
@ -1378,6 +1379,7 @@ def lower_jaxpr_to_fun(
Returns:
MLIR func op
"""
util.test_event("lower_jaxpr_to_fun", name)
# The first dimension variable may be the platform index
num_dim_vars = len(ctx.shape_poly_state.dim_vars)

View File

@ -231,16 +231,20 @@ shard_arg_handlers[core.MutableArray] = _shard_mutable_array
def batched_device_put(aval: core.ShapedArray,
sharding: JSharding, xs: Sequence[Any],
devices: Sequence[jax.Device], committed: bool = True):
from jax._src import array
util.test_event("batched_device_put_start")
try:
from jax._src import array
bufs = [x for x, d in safe_zip(xs, devices)
if (isinstance(x, array.ArrayImpl) and
dispatch.is_single_device_sharding(x.sharding) and
x.devices() == {d})]
if len(bufs) == len(xs):
return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True)
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
bufs = [x for x, d in safe_zip(xs, devices)
if (isinstance(x, array.ArrayImpl) and
dispatch.is_single_device_sharding(x.sharding) and
x.devices() == {d})]
if len(bufs) == len(xs):
return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True)
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
finally:
util.test_event("batched_device_put_end")
def _shard_aval(size, axis: int, aval):
try:
@ -2850,6 +2854,7 @@ class UnloadedMeshExecutable:
mesh = i.mesh
break
util.test_event("pxla_cached_compilation")
xla_executable = _cached_compilation(
hlo, name, mesh, spmd_lowering,
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,

View File

@ -549,6 +549,7 @@ def _infer_params_impl(
kwargs: dict[str, Any],
in_avals: tuple[core.AbstractValue, ...] | None,
) -> tuple[PjitParams, list[Any]]:
util.test_event("pjit._infer_params_impl", fun)
have_kwargs = bool(kwargs)
if have_kwargs and ji.user_specified_in_shardings:
raise ValueError(
@ -1297,6 +1298,7 @@ def _create_pjit_jaxpr(
ignored_inline: IgnoreKey
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
util.test_event("create_pjit_jaxpr")
del ignored_inline # just for explain_cache_miss
if config.no_tracing.value:
raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but "
@ -1784,6 +1786,7 @@ def _pjit_lower(
lowering_platforms: tuple[str, ...] | None,
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
util.test_event("pjit_lower")
if config.sharding_in_types.value:
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
else:

View File

@ -533,6 +533,7 @@ class Compiled(Stage):
@staticmethod
def call(*args, **kwargs):
util.test_event("stages_compiled_call")
# This is because `__call__` passes in `self._params` as the first argument.
# Instead of making the call signature `call(params, *args, **kwargs)`
# extract it from args because `params` can be passed as a kwarg by users

View File

@ -30,6 +30,7 @@ import re
import sys
import tempfile
import textwrap
import threading
from typing import Any, TextIO
import unittest
import warnings
@ -40,22 +41,17 @@ from absl.testing import parameterized
import jax
from jax import lax
from jax._src import api
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes as _dtypes
from jax._src import lib as _jaxlib
from jax._src import linear_util as lu
from jax._src import monitoring
from jax._src import pjit as pjit_lib
from jax._src import stages
from jax._src import xla_bridge
from jax._src import util
from jax._src import mesh as mesh_lib
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
@ -235,32 +231,68 @@ capture_stdout = partial(_capture_output, sys.stdout)
capture_stderr = partial(_capture_output, sys.stderr)
@contextmanager
def count_device_put():
batched_device_put = pxla.batched_device_put
count = [0]
class EventThreadLocalState(threading.local):
def __init__(self):
self.counts = {} # Mapping from string name to count.
self.nested_device_put_count = 0 # Number of recursive calls to device_put
def make_fn_and_count(fn):
def fn_and_count(*args, **kwargs):
count[0] += 1
# device_put handlers might call `dispatch.device_put` (e.g. on an
# underlying payload or several). We only want to count these
# recursive puts once, so we skip counting more than the outermost
# one in such a call stack.
pxla.batched_device_put = batched_device_put
try:
return fn(*args, **kwargs)
finally:
pxla.batched_device_put = batched_device_put_and_count
return fn_and_count
# Per-function counts
self.infer_params_fun_counts = None
self.lower_jaxpr_to_fun_counts = None
batched_device_put_and_count = make_fn_and_count(batched_device_put)
thread_local_state = EventThreadLocalState()
pxla.batched_device_put = batched_device_put_and_count
try:
yield count
finally:
pxla.batched_device_put = batched_device_put
def event_listener(name, *args):
  """Records one test event in the calling thread's counters.

  Every event bumps ``thread_local_state.counts[name]``. A few events get
  extra bookkeeping:

  * ``batched_device_put_start`` / ``batched_device_put_end`` maintain a
    nesting depth so that only the outermost put is counted under the
    synthetic ``"batched_device_put"`` key.
  * ``pjit._infer_params_impl`` and ``lower_jaxpr_to_fun`` carry exactly one
    extra argument, which is used as the key into a per-key counter. Those
    counters are only updated while the corresponding thread-local mapping is
    not None.

  Args:
    name: the event name.
    *args: event payload; only inspected for the per-key events above.
  """
  state = thread_local_state
  tally = state.counts
  tally[name] = tally.get(name, 0) + 1

  if name == "batched_device_put_start":
    # device_put handlers may themselves call `dispatch.device_put` (e.g. on
    # one or more underlying payloads). Such recursive puts must be counted
    # once, so only the outermost call in a call stack is recorded.
    depth = state.nested_device_put_count
    if depth == 0:
      tally["batched_device_put"] = tally.get("batched_device_put", 0) + 1
    state.nested_device_put_count = depth + 1
    return

  if name == "batched_device_put_end":
    state.nested_device_put_count -= 1
    return

  # Per-key data is collected for the two events below, and only while the
  # matching thread-local mapping has been installed (i.e. is not None).
  if name == "pjit._infer_params_impl":
    per_key = state.infer_params_fun_counts
  elif name == "lower_jaxpr_to_fun":
    per_key = state.lower_jaxpr_to_fun_counts
  else:
    return

  if per_key is not None:
    (key,) = args
    per_key[key] += 1
# Install the listener so that util.test_event(...) calls are forwarded to
# event_listener.
util.test_event_listener = event_listener
def count_events(event):
  """Builds a context manager that counts occurrences of one test event.

  Args:
    event: key to look up in the thread-local ``counts`` mapping.

  Returns:
    A context-manager factory. Entering the context yields a zero-argument
    callable that returns how far the current thread's count for ``event`` has
    advanced since the context was entered.
  """
  @contextmanager
  def count_event():
    # Snapshot the count on entry; the reader reports the difference from it.
    baseline = thread_local_state.counts.get(event, 0)

    def events_since_entry():
      return thread_local_state.counts.get(event, 0) - baseline

    yield events_since_entry

  return count_event
# Ready-made counters, one per instrumented event. Each name is a
# context-manager factory produced by count_events; the value bound by `with`
# is a zero-argument callable returning the number of matching events seen on
# the current thread since the context was entered.

# "batched_device_put" is bumped by event_listener for the outermost
# batched_device_put_start only; nested puts are not counted again.
count_device_put = count_events("batched_device_put")
# Emitted in _array_shard_arg just before the batched copy call.
count_device_put_fast_path_hit = count_events("batched_copy_array")
# Emitted on entry to _pjit_lower.
count_pjit_cpp_cache_miss = count_events("pjit_lower")
# Emitted on entry to _create_pjit_jaxpr.
count_jit_tracing_cache_miss = count_events("create_pjit_jaxpr")
# Emitted on entry to stages.Compiled.call.
count_aot_jit_cpp_cache_miss = count_events("stages_compiled_call")
# Emitted on entry to mlir.lower_jaxpr_to_module.
count_jit_and_pmap_lowerings = count_events("lower_jaxpr_to_module")
# Emitted just before pxla's _cached_compilation is invoked.
count_jit_compilation_cache_miss = count_events("pxla_cached_compilation")
# Emitted on entry to array._array_shard_arg.
count_jax_array_shard_arg_calls = count_events("_array_shard_arg")
@contextmanager
@ -269,189 +301,39 @@ def count_primitive_compiles():
count = [-1]
try:
yield count
yield lambda: count[0]
finally:
count[0] = dispatch.xla_primitive_callable.cache_info().misses
@contextmanager
def count_device_put_fast_path_hit():
original_fn = xc.batched_copy_array_to_devices_with_sharding
count = [0]
def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
count[0] += 1
return original_fn(*args, **kwargs)
xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count
try:
yield count
finally:
xc.batched_copy_array_to_devices_with_sharding = original_fn
@contextmanager
def count_pjit_cpp_cache_miss():
original_pjit_lower = pjit_lib._pjit_lower
count = [0]
def pjit_lower_and_count(*args, **kwargs):
count[0] += 1
return original_pjit_lower(*args, **kwargs)
pjit_lib._pjit_lower = pjit_lower_and_count
try:
yield count
finally:
pjit_lib._pjit_lower = original_pjit_lower
@contextmanager
def count_cached_compilation_cache_miss():
original_cached_compilation = pxla._cached_compilation
count = [0]
def cached_compilation_and_count(*args, **kwargs):
count[0] += 1
return original_cached_compilation(*args, **kwargs)
pxla._cached_compilation = cached_compilation_and_count
try:
yield count
finally:
pxla._cached_compilation = original_cached_compilation
@contextmanager
def count_jit_tracing_cache_miss():
original_create_pjit_jaxpr = pjit_lib._create_pjit_jaxpr
count = [0]
@lu.cache
def create_pjit_jaxpr_and_count(*args):
count[0] += 1
return original_create_pjit_jaxpr(*args)
pjit_lib._create_pjit_jaxpr = create_pjit_jaxpr_and_count
try:
yield count
finally:
pjit_lib._create_pjit_jaxpr = original_create_pjit_jaxpr
@contextmanager
def count_jit_infer_params_cache_miss():
original_infer_params_impl = pjit_lib._infer_params_impl
count = collections.defaultdict(int)
def infer_params_impl_and_count(fun, *args, **kw):
count[fun] += 1
return original_infer_params_impl(fun, *args, **kw)
pjit_lib._infer_params_impl = infer_params_impl_and_count
assert thread_local_state.infer_params_fun_counts is None
counts = collections.defaultdict(int)
thread_local_state.infer_params_fun_counts = counts
try:
yield count
yield counts
finally:
pjit_lib._infer_params_impl = original_infer_params_impl
thread_local_state.infer_params_fun_counts = None
@contextmanager
def count_aot_jit_cpp_cache_miss():
original_call = stages.Compiled.call
count = [0]
def compiled_call_count(*args, **kwargs):
count[0] += 1
return original_call(*args, **kwargs)
stages.Compiled.call = compiled_call_count
def count_subjaxpr_to_hlo_conversion(fun_name):
assert thread_local_state.lower_jaxpr_to_fun_counts is None
counts = collections.defaultdict(int)
thread_local_state.lower_jaxpr_to_fun_counts = counts
try:
yield count
yield lambda: counts[fun_name]
finally:
stages.Compiled.call = original_call
thread_local_state.lower_jaxpr_to_fun_counts = None
@contextmanager
def count_jit_and_pmap_lowerings():
# No need to clear any caches since we generally jit and pmap fresh callables
# in tests.
mlir_lower = mlir.lower_jaxpr_to_module
count = [0]
def mlir_lower_and_count(*args, **kwargs):
count[0] += 1
return mlir_lower(*args, **kwargs)
mlir.lower_jaxpr_to_module = mlir_lower_and_count
try:
yield count
finally:
mlir.lower_jaxpr_to_module = mlir_lower
@contextmanager
def count_jax_array_shard_arg_calls():
# No need to clear any caches since we generally jit and pmap fresh callables
# in tests.
array_shard_arg = array._array_shard_arg
count = [0]
def array_shard_arg_and_count(*args, **kwargs):
count[0] += 1
return array_shard_arg(*args, **kwargs)
pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg_and_count
try:
yield count
finally:
pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg
@contextmanager
def count_jit_compilation_cache_miss():
# No need to clear any caches since we generally jit and pmap fresh callables
# in tests.
jit_compilation = pxla._cached_compilation
count = [0]
def compile_and_count(*args, **kwargs):
count[0] += 1
return jit_compilation(*args, **kwargs)
pxla._cached_compilation = compile_and_count
try:
yield count
finally:
pxla._cached_compilation = jit_compilation
@contextmanager
def count_subjaxpr_to_hlo_conversion(fun_name: str):
# No need to clear any caches since we generally jit and pmap fresh callables
# in tests.
mlir_lower = mlir.lower_jaxpr_to_fun
count = [0]
def mlir_lower_and_count(ctx, name, *args, **kwargs):
if name == fun_name:
count[0] += 1
return mlir_lower(ctx, name, *args, **kwargs)
mlir.lower_jaxpr_to_fun = mlir_lower_and_count
try:
yield count
finally:
mlir.lower_jaxpr_to_fun = mlir_lower
@contextmanager
def assert_num_jit_and_pmap_compilations(times):
with count_jit_and_pmap_lowerings() as count:
yield
if count[0] != times:
if count() != times:
raise AssertionError(f"Expected exactly {times} XLA compilations, "
f"but executed {count[0]}")
f"but executed {count()}")
def jaxlib_version() -> tuple[int, ...]:

View File

@ -676,3 +676,12 @@ class StrictABCMeta(abc.ABCMeta):
class StrictABC(metaclass=StrictABCMeta):
__slots__ = ()
test_event_listener: Callable | None = None
def test_event(name: str, *args) -> None:
if not test_event_listener:
return
test_event_listener(name, *args)

View File

@ -24,6 +24,7 @@ from typing import Any, Callable, Sequence
import jax
from jax._src import api
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.traceback_util import api_boundary
@ -283,6 +284,7 @@ def _get_specialized_func(
info: FunctionInfo, specialization: Specialization
) -> Callable[..., Any]:
"""Returns a specialized function for the given specialization."""
util.test_event("colocated_python_func._get_specialized_func")
assert specialization.in_specs_treedef is not None
assert specialization.in_specs_leaves is not None
assert specialization.devices is not None

View File

@ -1081,7 +1081,7 @@ class JitTest(jtu.BufferDonationTestCase):
args = range(num_args)
with jtu.count_device_put() as count:
np.testing.assert_allclose(f_pruned(*args), 3)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def testBuffersAreFreedPromptly(self):
# Regression test for a bug where garbage collection was delayed too long
@ -1246,7 +1246,7 @@ class JitTest(jtu.BufferDonationTestCase):
jitted_f = jit(lambda x, y: x, keep_unused=True)
with jtu.count_pjit_cpp_cache_miss() as count:
_ = jitted_f(1, 2)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_jit_lower_compile_compiler_ir(self):
f = jit(lambda x: x + 4).lower(1.).compile()
@ -1428,7 +1428,7 @@ class JitTest(jtu.BufferDonationTestCase):
with jtu.count_jit_compilation_cache_miss() as count:
jit(f, compiler_options={"xla_embed_ir_in_executable": True})(1.)
jit(f, compiler_options={"xla_embed_ir_in_executable": False})(1.)
self.assertEqual(count[0], 2)
self.assertEqual(count(), 2)
# We should still error on invalid options after some valid compiles
with self.assertRaisesRegex(
@ -1511,7 +1511,7 @@ class JitTest(jtu.BufferDonationTestCase):
expected = f()
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
ans = jax.vmap(f, axis_size=2, out_axes=None)()
self.assertEqual(count[0], 0) # no compiles
self.assertEqual(count(), 0) # no compiles
self.assertArraysAllClose(ans, expected, check_dtypes=True)
def test_cache_key_defaults(self):
@ -2737,7 +2737,7 @@ class APITest(jtu.JaxTestCase):
jax.eval_shape(f, inp)
jax.jit(f)(inp)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_jit_infer_params_cache(self):
def f(x):
@ -3384,12 +3384,12 @@ class APITest(jtu.JaxTestCase):
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
_ = jax.grad(f)(3.)
self.assertEqual(count[0], 2) # one for fwd, one for bwd
self.assertEqual(count(), 2) # one for fwd, one for bwd
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
_ = jax.grad(f)(3.)
_ = jax.grad(f)(4.)
self.assertEqual(count[0], 0) # cache hits on both fwd and bwd
self.assertEqual(count(), 0) # cache hits on both fwd and bwd
def test_grad_does_not_unflatten_tree_with_none(self):
# https://github.com/jax-ml/jax/issues/7546
@ -3458,7 +3458,7 @@ class APITest(jtu.JaxTestCase):
with jtu.count_primitive_compiles() as count:
lax.add(1, 2)
lax.add(2, 3)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_arange_jit(self):
# see https://github.com/jax-ml/jax/issues/553
@ -4021,7 +4021,7 @@ class APITest(jtu.JaxTestCase):
jax.eval_shape(jax.numpy.array, 1)
out = jax.eval_shape(jax.numpy.array, 1)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
self.assertTrue(out.weak_type)
self.assertEqual(out.weak_type, arr.weak_type)
@ -4296,19 +4296,19 @@ class APITest(jtu.JaxTestCase):
for _ in range(5):
jax.hessian(jf)(x).block_until_ready()
n = count[0]
n = count()
# The exact number of compilations may vary depending on the number of
# jit decorators in the function above, but it should not grow after an
# initial warmup phase.
for _ in range(5):
jax.hessian(jf)(x).block_until_ready()
self.assertEqual(count[0], n)
self.assertEqual(count(), n)
def test_jnp_array_doesnt_device_put(self):
with jtu.count_device_put() as count:
api.make_jaxpr(lambda: jnp.array(3))()
self.assertEqual(count[0], 0)
self.assertEqual(count(), 0)
def test_rank_promotion_forces_retrace(self):
num_traces = 0
@ -5910,7 +5910,7 @@ class RematTest(jtu.JaxTestCase):
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
for _ in range(20):
f_lin(1.).block_until_ready()
self.assertEqual(count[0], 1) # cached after first execution
self.assertEqual(count(), 1) # cached after first execution
def test_vjp_caching(self):
# https://github.com/jax-ml/jax/issues/9661
@ -5919,7 +5919,7 @@ class RematTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd
self.assertEqual(count(), 2) # fwd execute_trivial, backward_pass on bwd
def test_vjp_caching_static_argnums(self):
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
@ -5928,7 +5928,7 @@ class RematTest(jtu.JaxTestCase):
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd
self.assertEqual(count(), 2) # fwd execute_trivial, backward_pass on bwd
def test_fwd_caching(self):
# see above test also
@ -5937,7 +5937,7 @@ class RematTest(jtu.JaxTestCase):
for _ in range(20):
y, _ = jax.vjp(identity, 1.)
y.block_until_ready()
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_fwd_caching_static_argnums(self):
# see above test also
@ -5946,7 +5946,7 @@ class RematTest(jtu.JaxTestCase):
for _ in range(20):
y = identity(1.)
y.block_until_ready()
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
@parameterized.named_parameters(
{"testcase_name": f"{suffix}", "remat": remat}

View File

@ -858,7 +858,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
_ = f(3.)
with jtu.count_jit_and_pmap_lowerings() as count:
_ = f(3.)
self.assertEqual(count[0], 0)
self.assertEqual(count(), 0)
def test_goodfellow_custom_jvp(self):
def h(fext):

View File

@ -24,7 +24,6 @@ from jax._src import config
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
from jax.experimental import colocated_python
from jax.experimental.colocated_python import func as colocated_python_func
from jax.experimental.colocated_python import serialization
from jax.extend.ifrt_programs import ifrt_programs
import jax.numpy as jnp
@ -52,23 +51,8 @@ def _colocated_cpu_devices(
]
@contextlib.contextmanager
def _count_colocated_python_specialization_cache_miss() -> list[int]:
"""Counts the number of cache misses for colocated_python specialization."""
original_get_specialized_func = colocated_python_func._get_specialized_func
count = [0]
@jax.util.cache(max_size=None)
def get_specialized_func(*args, **kwargs):
count[0] += 1
return original_get_specialized_func(*args, **kwargs)
colocated_python_func._get_specialized_func = get_specialized_func
try:
yield count
finally:
colocated_python_func._get_specialized_func = original_get_specialized_func
_count_colocated_python_specialization_cache_miss = jtu.count_events(
"colocated_python_func._get_specialized_func")
_exit_stack = contextlib.ExitStack()
@ -117,12 +101,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def testSimpleFunctioWithTree(self):
@colocated_python.colocated_python
@ -137,12 +121,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def testEmptyInputFailsWithoutSpecialization(self):
@colocated_python.colocated_python
@ -168,12 +152,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
out = make_zero()
out = jax.device_get(out)
self.assertEqual(out, np.array(0))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
out = make_zero()
out = jax.device_get(out)
self.assertEqual(out, np.array(0))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def testInputPolymorphismWithoutOutSpecsFn(self):
@colocated_python.colocated_python
@ -188,12 +172,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
# Different input tree structure and dtype/shape.
x = [np.array(1), (np.array(2), {"v": np.array(3)})]
@ -202,12 +186,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)
self.assertEqual(count(), 2)
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)
self.assertEqual(count(), 2)
def testInputPolymorphismAllowedWithOutSpecsFn(self):
@colocated_python.colocated_python
@ -223,12 +207,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, np.array(2))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
# Different input tree structure and dtype/shape.
x = [np.array(1), (np.array(2), {"v": np.array(3)})]
@ -237,12 +221,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)
self.assertEqual(count(), 2)
out = add_one(x)
out = jax.device_get(out)
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
self.assertEqual(count[0], 2)
self.assertEqual(count(), 2)
@parameterized.named_parameters(
("on_main_thread", True),

View File

@ -2621,7 +2621,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lax.cond(x, f, g, x)
# Should observe a maximum of 4 compiles: convert_element_type, f, g, cond
# In #14058, this was observed to be 31 compiles.
self.assertLess(count[0], 5)
self.assertLess(count(), 5)
@parameterized.named_parameters(
{"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}

View File

@ -87,7 +87,7 @@ class LayoutTest(jtu.JaxTestCase):
with jtu.count_aot_jit_cpp_cache_miss() as init_count:
init_out = init_compiled(arr1, arr2)
init_compiled(arr1, arr2)
self.assertEqual(init_count[0], 1)
self.assertEqual(init_count(), 1)
self.assertEqual(init_out[0].layout, init_compiled.output_layouts[0])
self.assertEqual(init_out[1].layout, init_compiled.output_layouts[1])
@ -95,7 +95,7 @@ class LayoutTest(jtu.JaxTestCase):
with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
apply_out = compiled_apply(*init_out)
compiled_apply(*init_out)
self.assertEqual(apply_count[0], 1)
self.assertEqual(apply_count(), 1)
self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts[0])
self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts[1])

View File

@ -1308,7 +1308,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
out = f(inp)
out2 = f(inp2)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
self.assertArraysEqual(out, np_inp @ np_inp.T)
self.assertArraysEqual(out2, np_inp @ np_inp.T)
@ -1329,8 +1329,8 @@ class ComputeOffload(jtu.BufferDonationTestCase):
jtu.count_jit_and_pmap_lowerings() as compile_count):
f(inp)
f(inp2)
self.assertEqual(cpp_count[0], 2)
self.assertEqual(compile_count[0], 1)
self.assertEqual(cpp_count(), 2)
self.assertEqual(compile_count(), 1)
def test_jit_cpp_cache_output_hit(self):
_, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device")
@ -1342,7 +1342,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
out = mul_two(inp)
mul_two(out)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_jit_cache_hit_with_default_and_specified_mem_kind(self):
_, s, np_inp, _ = _create_inputs((8, 2), P("x", "y"))
@ -1357,7 +1357,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
with jtu.count_jit_and_pmap_lowerings() as count:
out = f(np_inp)
out2 = g(np_inp2)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
self.assertArraysEqual(out, np_inp @ np_inp.T)
self.assertArraysEqual(out2, np_inp2 @ np_inp2.T)
@ -1633,7 +1633,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
f(inp)
# 2 for `f` and `2` for `mul` (compute type changes for `mul`)
self.assertEqual(count[0], 4)
self.assertEqual(count(), 4)
def test_offload_take_host(self):
@compute_on('device_host')

View File

@ -141,7 +141,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
y = lax.add(x, x)
z = lax.add(y, y)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
self.assert_committed_to_device(y, devices[1])
self.assert_committed_to_device(z, devices[1])

View File

@ -149,14 +149,14 @@ class PgleTest(jtu.JaxTestCase):
with config.pgle_profiling_runs(2), config.enable_pgle(True):
# Run 1: Module should be compiled without FDO. Two modules are expected
# One is the function f, the other one is multi slice module
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 2)
self.assertEqual(cache_miss_count(), 2)
# Run 2: Second PGLE run. Profile should be empty.
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 2)
self.assertEqual(cache_miss_count(), 2)
fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)
# One for before and one for after optimization.
self.assertLen(fdo_profiles_before_pgle, 2)
@ -165,9 +165,9 @@ class PgleTest(jtu.JaxTestCase):
os.path.getsize(os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0)
# Run 3: The module should be recompiled with FDO profiles
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 2)
self.assertEqual(cache_miss_count(), 2)
fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir)
# One for before and one for after optimization.
self.assertLen(fdo_profiles_after_pgle, 4)
@ -179,9 +179,9 @@ class PgleTest(jtu.JaxTestCase):
)
# Run 4: Fast-path should be used after PGLE is done
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertLess(cache_miss_count[0], 2)
self.assertLess(cache_miss_count(), 2)
def testAutoPgleWithAot(self):
@jax.jit
@ -197,14 +197,14 @@ class PgleTest(jtu.JaxTestCase):
with config.pgle_profiling_runs(1), config.enable_pgle(True):
# Run 1
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(compiled(x), expected)
self.assertEqual(cache_miss_count[0], 0)
self.assertEqual(cache_miss_count(), 0)
# Run 2
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(compiled(x), expected)
self.assertEqual(cache_miss_count[0], 0)
self.assertEqual(cache_miss_count(), 0)
def testAutoPgleWithPersistentCache(self):
its = 50
@ -243,24 +243,24 @@ class PgleTest(jtu.JaxTestCase):
cc.reset_cache()
cc.set_cache_dir(cache_dir)
# Run 1: Module should be compiled without FDO
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
f(x)
self.assertGreater(cache_miss_count[0], 0)
self.assertGreater(cache_miss_count(), 0)
# Non-pgle profiled version of module should be saved
non_pgle_profiled_files = os.listdir(cache_dir)
self.assertNotEmpty(non_pgle_profiled_files)
# Run 2: Compilation should not be called
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
f(x)
self.assertGreater(cache_miss_count[0], 0)
self.assertGreater(cache_miss_count(), 0)
fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)
# Run 3: Module should be compiled with FDO and stored to persistent cache
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
f(x)
self.assertGreater(cache_miss_count[0], 0)
self.assertGreater(cache_miss_count(), 0)
# Check if FDO profile file of the biggest module is not empty
fdo_profiles_after_pgle = [

View File

@ -660,7 +660,7 @@ class PJitTest(jtu.BufferDonationTestCase):
jax.grad(f)(x) # Warm up the cache.
with jtu.count_pjit_cpp_cache_miss() as count:
jax.grad(f)(x)
self.assertEqual(count[0], 0) # no cache miss i.e. cache hit
self.assertEqual(count(), 0) # no cache miss i.e. cache hit
@jtu.with_mesh([('x', 2), ('y', 1)])
def testEvalJaxpr(self):
@ -2200,7 +2200,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
out = f(a)
_ = f(out)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_pjit_different_device_recompilation(self):
if jax.device_count() < 2:
@ -2217,7 +2217,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_jit_compilation_cache_miss() as count:
out1 = f(a)
out2 = f(b)
self.assertEqual(count[0], 2)
self.assertEqual(count(), 2)
self.assertArraysEqual(out1, val1)
self.assertArraysEqual(out2, val2)
@ -2645,7 +2645,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
inp_data, jax.sharding.NamedSharding(mesh, P('x')))
with mesh:
f(arr1)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_single_device_add_single_compile(self):
f1 = pjit(lambda x, y: x + y)
@ -2657,7 +2657,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(2):
f1(a, b)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_global_array_to_host_local_array_already_host_local(self):
inp_shape = (8, 2)
@ -2791,7 +2791,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
f_names = pjit(f, static_argnames='x')
f_names(y, x='foo')
f_names(y, x='foo')
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_new_static_argnum_on_keyword_arguments(self):
f = pjit(lambda x: x, static_argnums=0)
@ -2834,7 +2834,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
# The count here is 0 because before `count_pjit_cpp_cache_miss`, `f` was
# called with `system_default_device` and `test_device` so it was added
# to the cache. Subsequent calls hit the C++ cache.
self.assertEqual(count[0], 0)
self.assertEqual(count(), 0)
def test_pjit_with_mismatched_static_argnames(self):
x_is_tracer, y_is_tracer = False, False
@ -3287,7 +3287,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pjit(f)(inp)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_pjit_no_global_cache_hit_axis_resources(self):
mesh = jtu.create_mesh((1,), ('x',))
@ -3297,20 +3297,20 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)(inp)
self.assertEqual(count[0], 10)
self.assertEqual(count(), 10)
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
pjit(lambda x: x * 2, device=jax.devices()[0])(inp)
self.assertEqual(count[0], 10)
self.assertEqual(count(), 10)
pf = pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pf(inp)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
with jtu.ignore_warning(category=DeprecationWarning,
message="backend and device argument"):
@ -3318,7 +3318,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pf1(inp)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_with_sharding_constraint_spmd_axis_name(self):
mesh = jtu.create_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
@ -3484,7 +3484,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out = f(arr)
self.assertIsInstance(out.sharding, NamedSharding)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
with jtu.count_pjit_cpp_cache_miss() as count:
out2 = f2(arr)
@ -3493,7 +3493,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out2 = f2(arr)
self.assertIsInstance(out2.sharding, PositionalSharding)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
@ -3530,7 +3530,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out2.sharding, NamedSharding)
# Drops out of C++ cache i.e. cache miss
self.assertEqual(count[0], 2)
self.assertEqual(count(), 2)
# Still gets a hit on pjit_lower cache.
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
@ -3600,7 +3600,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
out3 = vf(out2)
self.assertIsInstance(out3.sharding, NamedSharding)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_jit_mul_sum_sharding_preserved(self):
if config.use_shardy_partitioner.value:
@ -3625,7 +3625,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
# This will hit the cpp cache.
out3 = f(out2)
self.assertIsInstance(out3.sharding, PositionalSharding)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
self.assertEqual(cache_info2.misses, cache_info1.misses)
@ -3755,7 +3755,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
top(jnp.arange(8))
# The count should be 1 because `nest`'s lowering to MHLO should be cached.
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_wsc_eager(self):
mesh = jtu.create_mesh((2,), ('x',))
@ -3792,7 +3792,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
out1 = core.jaxpr_as_fun(jaxpr)(inp)
out2 = core.jaxpr_as_fun(jaxpr)(inp)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
self.assertArraysEqual(out1[0], inp * 2)
self.assertArraysEqual(out2[0], inp * 2)
@ -3970,7 +3970,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_jit_and_pmap_lowerings() as count:
g(np.arange(8))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_lowering_cache_miss_different_devices_and_sharding(self):
if jax.device_count() < 4:
@ -3993,7 +3993,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_jit_and_pmap_lowerings() as count:
g(np.arange(8))
self.assertEqual(count[0], 2)
self.assertEqual(count(), 2)
def test_single_device_named_sharding_preserved(self):
mesh = jax.sharding.Mesh([jax.devices()[0]], 'x')
@ -4021,7 +4021,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
# that is expected.
with jtu.count_device_put_fast_path_hit() as count:
out = jax.device_put(arr1, NamedSharding(mesh2, P('x')))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
self.assertTupleEqual(out.sharding._device_assignment,
mesh2._flat_devices_tuple)
self.assertArraysEqual(out, inp)
@ -4536,9 +4536,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
# same num_devices but different devices.
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
f(b) # tracing and lowering cache *hit*
self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin`
self.assertEqual(lowering_count[0], 1)
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.
self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin`
self.assertEqual(lowering_count(), 1)
self.assertEqual(compilation_count(), 2) # 2 misses since devices differ.
def test_wsc_abstract_mesh(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
@ -4621,7 +4621,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
jax.jit(f, out_shardings=s)(np.arange(8))
jax.jit(f, out_shardings=s)(np.arange(8))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_input_shardings_single_device(self):
@jax.jit

View File

@ -1287,7 +1287,7 @@ class PythonPmapTest(jtu.JaxTestCase):
x = jnp.arange(device_count)
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): fix this
# self.assertEqual(count(), 0) # TODO(mattjj): fix this
expected = np.repeat(3, device_count)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -1308,7 +1308,7 @@ class PythonPmapTest(jtu.JaxTestCase):
x = jnp.arange(len(devices))
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
# self.assertEqual(count(), 0) # TODO(mattjj): don't compile for constants
expected = np.repeat(3, len(devices))
self.assertAllClose(ans, expected, check_dtypes=False)
@ -1344,7 +1344,7 @@ class PythonPmapTest(jtu.JaxTestCase):
x = jnp.arange(math.prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
# self.assertEqual(count(), 0) # TODO(mattjj): don't compile for constants
expected = 3 * np.ones(shape[:2])
self.assertAllClose(ans, expected, check_dtypes=False)
@ -1370,7 +1370,7 @@ class PythonPmapTest(jtu.JaxTestCase):
x = jnp.arange(math.prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
# self.assertEqual(count(), 0) # TODO(mattjj): don't compile for constants
expected = 3 * np.ones(shape[:2])
self.assertAllClose(ans, expected, check_dtypes=False)
@ -2043,7 +2043,7 @@ class PythonPmapTest(jtu.JaxTestCase):
_, f_bwd2 = jax.vjp(f, x)
_ = f_bwd(x)
_ = f_bwd2(x)
self.assertEqual(count[0], 0) # cache hits on fwd and bwd
self.assertEqual(count(), 0) # cache hits on fwd and bwd
def testSizeOverflow(self):
if config.disable_jit.value:

View File

@ -588,7 +588,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
with jtu.count_primitive_compiles() as count:
for _ in range(3):
self.assertAllClose(2 * x, fun(x))
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
class PureCallbackTest(jtu.JaxTestCase):

View File

@ -1041,7 +1041,7 @@ class LaxRandomTest(jtu.JaxTestCase):
key = self.make_key(1).block_until_ready()
with jtu.count_device_put() as count:
jax.jit(random.split)(key)
self.assertLessEqual(count[0], 1) # 1 for the argument device_put
self.assertLessEqual(count(), 1) # 1 for the argument device_put
@jtu.sample_product(dtype=int_dtypes + uint_dtypes)
def test_randint_bounds(self, dtype):

View File

@ -709,7 +709,7 @@ class KeyArrayTest(jtu.JaxTestCase):
f(key).block_until_ready()
f(key).block_until_ready()
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
# TODO(jakevdp) remove this decorator when reuse checks move to C++
@jax.debug_key_reuse(False)
@ -726,7 +726,7 @@ class KeyArrayTest(jtu.JaxTestCase):
f(key).block_until_ready()
f(key).block_until_ready()
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_cpp_dispatch_aot_normal(self):
# Ensure we stay on the C++ dispatch path when calling an
@ -739,7 +739,7 @@ class KeyArrayTest(jtu.JaxTestCase):
f(key).block_until_ready()
f(key).block_until_ready()
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
def test_cpp_dispatch_aot_split(self):
# Ensure we stay on the C++ dispatch path when calling an
@ -753,7 +753,7 @@ class KeyArrayTest(jtu.JaxTestCase):
f(key).block_until_ready()
f(key).block_until_ready()
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
# -- prng primitives

View File

@ -201,7 +201,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
f(np_inp)
out1, out2 = f(np_inp)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
self.assertTrue(s.is_equivalent_to(out1.sharding, np_inp.ndim))
self.assertTrue(s.is_equivalent_to(out2.sharding, np_inp.ndim))
@ -213,7 +213,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
with jtu.count_pjit_cpp_cache_miss() as count:
g(arr)
out3, out4 = g(arr)
self.assertEqual(count[0], 1)
self.assertEqual(count(), 1)
self.assertEqual(out3.sharding, s)
self.assertEqual(out4.sharding, s)

View File

@ -825,9 +825,9 @@ class ShardMapTest(jtu.JaxTestCase):
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
f(b) # tracing and lowering cache *hit*
self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin`
self.assertEqual(lowering_count[0], 1)
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.
self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin`
self.assertEqual(lowering_count(), 1)
self.assertEqual(compilation_count(), 2) # 2 misses since devices differ.
def test_shmap_abstract_mesh_errors(self):
mesh = jtu.create_mesh((2,), ('x',))