mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Don't monkey-patch functions in test_util to count events for tests.
Monkey-patching has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism;
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_util registers a callback that is invoked whenever an event is emitted. This avoids the need to make thread-unsafe global monkey patches.
This commit is contained in:
parent
3630756e87
commit
62e66b684b
@ -37,10 +37,10 @@ class AttrsTests(jtu.JaxTestCase):
|
||||
jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,))
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
jit_array_attr(5)
|
||||
self.assertEqual(count[0], 1) # compiles once the first time
|
||||
self.assertEqual(count(), 1) # compiles once the first time
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
jit_array_attr(5)
|
||||
self.assertEqual(count[0], 0) # cache hit
|
||||
self.assertEqual(count(), 0) # cache hit
|
||||
|
||||
def test_array_attr_no_jit(self):
|
||||
with jax.disable_jit():
|
||||
|
@ -33,6 +33,7 @@ from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import errors
|
||||
from jax._src import profiler
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
@ -1131,6 +1132,7 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
|
||||
|
||||
|
||||
def _array_shard_arg(xs, shardings, layouts, copy_semantics):
|
||||
util.test_event("_array_shard_arg")
|
||||
results = []
|
||||
batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
|
||||
batch_cs = []
|
||||
@ -1168,6 +1170,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
|
||||
results.append(
|
||||
shard_sharded_device_array_slow_path(x, devices, indices, sharding))
|
||||
|
||||
util.test_event("batched_copy_array")
|
||||
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
|
||||
batch_xs, batch_devs, batch_shardings, batch_cs)
|
||||
for i, copy_out in safe_zip(batch_indices, copy_outs):
|
||||
|
@ -94,6 +94,7 @@ def apply_primitive(prim, *args, **params):
|
||||
|
||||
@util.cache()
|
||||
def xla_primitive_callable(prim: core.Primitive, **params):
|
||||
util.test_event("xla_primitive_callable_cache_miss")
|
||||
def prim_fun(*args):
|
||||
with config.eager_constant_folding(False):
|
||||
return prim.bind(*args, **params)
|
||||
|
@ -1085,6 +1085,7 @@ def lower_jaxpr_to_module(
|
||||
Handles the quirks of the argument/return value passing conventions of the
|
||||
runtime.
|
||||
"""
|
||||
util.test_event("lower_jaxpr_to_module")
|
||||
platforms = tuple(map(xb.canonicalize_platform, platforms))
|
||||
|
||||
in_avals = (jaxpr.in_avals if arg_shardings is None else
|
||||
@ -1378,6 +1379,7 @@ def lower_jaxpr_to_fun(
|
||||
Returns:
|
||||
MLIR func op
|
||||
"""
|
||||
util.test_event("lower_jaxpr_to_fun", name)
|
||||
|
||||
# The first dimension variable may be the platform index
|
||||
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
|
||||
|
@ -231,16 +231,20 @@ shard_arg_handlers[core.MutableArray] = _shard_mutable_array
|
||||
def batched_device_put(aval: core.ShapedArray,
|
||||
sharding: JSharding, xs: Sequence[Any],
|
||||
devices: Sequence[jax.Device], committed: bool = True):
|
||||
from jax._src import array
|
||||
util.test_event("batched_device_put_start")
|
||||
try:
|
||||
from jax._src import array
|
||||
|
||||
bufs = [x for x, d in safe_zip(xs, devices)
|
||||
if (isinstance(x, array.ArrayImpl) and
|
||||
dispatch.is_single_device_sharding(x.sharding) and
|
||||
x.devices() == {d})]
|
||||
if len(bufs) == len(xs):
|
||||
return array.ArrayImpl(
|
||||
aval, sharding, bufs, committed=committed, _skip_checks=True)
|
||||
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
|
||||
bufs = [x for x, d in safe_zip(xs, devices)
|
||||
if (isinstance(x, array.ArrayImpl) and
|
||||
dispatch.is_single_device_sharding(x.sharding) and
|
||||
x.devices() == {d})]
|
||||
if len(bufs) == len(xs):
|
||||
return array.ArrayImpl(
|
||||
aval, sharding, bufs, committed=committed, _skip_checks=True)
|
||||
return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
|
||||
finally:
|
||||
util.test_event("batched_device_put_end")
|
||||
|
||||
def _shard_aval(size, axis: int, aval):
|
||||
try:
|
||||
@ -2850,6 +2854,7 @@ class UnloadedMeshExecutable:
|
||||
mesh = i.mesh
|
||||
break
|
||||
|
||||
util.test_event("pxla_cached_compilation")
|
||||
xla_executable = _cached_compilation(
|
||||
hlo, name, mesh, spmd_lowering,
|
||||
tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
|
||||
|
@ -549,6 +549,7 @@ def _infer_params_impl(
|
||||
kwargs: dict[str, Any],
|
||||
in_avals: tuple[core.AbstractValue, ...] | None,
|
||||
) -> tuple[PjitParams, list[Any]]:
|
||||
util.test_event("pjit._infer_params_impl", fun)
|
||||
have_kwargs = bool(kwargs)
|
||||
if have_kwargs and ji.user_specified_in_shardings:
|
||||
raise ValueError(
|
||||
@ -1297,6 +1298,7 @@ def _create_pjit_jaxpr(
|
||||
ignored_inline: IgnoreKey
|
||||
) -> tuple[core.ClosedJaxpr, list[Any], list[core.AbstractValue],
|
||||
list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
|
||||
util.test_event("create_pjit_jaxpr")
|
||||
del ignored_inline # just for explain_cache_miss
|
||||
if config.no_tracing.value:
|
||||
raise RuntimeError(f"re-tracing function {fun.f} for `jit`, but "
|
||||
@ -1784,6 +1786,7 @@ def _pjit_lower(
|
||||
lowering_platforms: tuple[str, ...] | None,
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
pgle_profiler: profiler.PGLEProfiler | None):
|
||||
util.test_event("pjit_lower")
|
||||
if config.sharding_in_types.value:
|
||||
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
|
||||
else:
|
||||
|
@ -533,6 +533,7 @@ class Compiled(Stage):
|
||||
|
||||
@staticmethod
|
||||
def call(*args, **kwargs):
|
||||
util.test_event("stages_compiled_call")
|
||||
# This is because `__call__` passes in `self._params` as the first argument.
|
||||
# Instead of making the call signature `call(params, *args, **kwargs)`
|
||||
# extract it from args because `params` can be passed as a kwarg by users
|
||||
|
@ -30,6 +30,7 @@ import re
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import threading
|
||||
from typing import Any, TextIO
|
||||
import unittest
|
||||
import warnings
|
||||
@ -40,22 +41,17 @@ from absl.testing import parameterized
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import api
|
||||
from jax._src import array
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax._src import lib as _jaxlib
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import monitoring
|
||||
from jax._src import pjit as pjit_lib
|
||||
from jax._src import stages
|
||||
from jax._src import xla_bridge
|
||||
from jax._src import util
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
|
||||
from jax._src.public_test_util import ( # noqa: F401
|
||||
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
|
||||
@ -235,32 +231,68 @@ capture_stdout = partial(_capture_output, sys.stdout)
|
||||
capture_stderr = partial(_capture_output, sys.stderr)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_device_put():
|
||||
batched_device_put = pxla.batched_device_put
|
||||
count = [0]
|
||||
class EventThreadLocalState(threading.local):
|
||||
def __init__(self):
|
||||
self.counts = {} # Mapping from string name to count.
|
||||
self.nested_device_put_count = 0 # Number of recursive calls to device_put
|
||||
|
||||
def make_fn_and_count(fn):
|
||||
def fn_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
# device_put handlers might call `dispatch.device_put` (e.g. on an
|
||||
# underlying payload or several). We only want to count these
|
||||
# recursive puts once, so we skip counting more than the outermost
|
||||
# one in such a call stack.
|
||||
pxla.batched_device_put = batched_device_put
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
finally:
|
||||
pxla.batched_device_put = batched_device_put_and_count
|
||||
return fn_and_count
|
||||
# Per-function counts
|
||||
self.infer_params_fun_counts = None
|
||||
self.lower_jaxpr_to_fun_counts = None
|
||||
|
||||
batched_device_put_and_count = make_fn_and_count(batched_device_put)
|
||||
thread_local_state = EventThreadLocalState()
|
||||
|
||||
pxla.batched_device_put = batched_device_put_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
pxla.batched_device_put = batched_device_put
|
||||
|
||||
def event_listener(name, *args):
  """Records one occurrence of the test event `name` for the current thread.

  Every event bumps its own entry in `thread_local_state.counts`. A few
  events get extra bookkeeping:

  * "batched_device_put_start" / "batched_device_put_end" maintain a nesting
    depth so that only the outermost device put of a call stack is added to
    the derived "batched_device_put" count.
  * "pjit._infer_params_impl" and "lower_jaxpr_to_fun" carry exactly one
    extra argument, which is used as the key of a per-key tally. These
    tallies are only updated while the matching attribute on
    `thread_local_state` is not None.

  Args:
    name: the event name.
    *args: event-specific payload; unused except for the two keyed events.
  """
  state = thread_local_state
  tallies = state.counts
  tallies[name] = tallies.get(name, 0) + 1

  if name == "batched_device_put_start":
    # device_put handlers may themselves call `dispatch.device_put` (e.g. on
    # one or more underlying payloads). Such recursive puts must be counted
    # once, so only a start event seen at depth zero is counted.
    depth = state.nested_device_put_count
    if depth == 0:
      tallies["batched_device_put"] = tallies.get("batched_device_put", 0) + 1
    state.nested_device_put_count = depth + 1
    return

  if name == "batched_device_put_end":
    state.nested_device_put_count -= 1
    return

  # Keyed events: pick the tally that belongs to this event, if any.
  if name == "pjit._infer_params_impl":
    keyed_tally = state.infer_params_fun_counts
  elif name == "lower_jaxpr_to_fun":
    keyed_tally = state.lower_jaxpr_to_fun_counts
  else:
    return

  # The keyed tallies are None unless something has installed a mapping.
  if keyed_tally is None:
    return
  (key,) = args
  keyed_tally[key] += 1
|
||||
|
||||
|
||||
util.test_event_listener = event_listener
|
||||
|
||||
|
||||
def count_events(event):
  """Builds a context-manager factory that counts occurrences of `event`.

  Calling the returned factory gives a context manager. Entering it yields a
  zero-argument function that returns how many times `event` has been
  recorded in `thread_local_state.counts` since the context was entered.

  Args:
    event: name of the test event to count.

  Returns:
    A function that creates the counting context manager.
  """
  def recorded_so_far():
    return thread_local_state.counts.get(event, 0)

  @contextmanager
  def count_event():
    baseline = recorded_so_far()

    def events_since_entry():
      return recorded_so_far() - baseline

    yield events_since_entry

  return count_event
|
||||
|
||||
count_device_put = count_events("batched_device_put")
|
||||
count_device_put_fast_path_hit = count_events("batched_copy_array")
|
||||
count_pjit_cpp_cache_miss = count_events("pjit_lower")
|
||||
count_jit_tracing_cache_miss = count_events("create_pjit_jaxpr")
|
||||
count_aot_jit_cpp_cache_miss = count_events("stages_compiled_call")
|
||||
count_jit_and_pmap_lowerings = count_events("lower_jaxpr_to_module")
|
||||
count_jit_compilation_cache_miss = count_events("pxla_cached_compilation")
|
||||
count_jax_array_shard_arg_calls = count_events("_array_shard_arg")
|
||||
|
||||
|
||||
@contextmanager
|
||||
@ -269,189 +301,39 @@ def count_primitive_compiles():
|
||||
|
||||
count = [-1]
|
||||
try:
|
||||
yield count
|
||||
yield lambda: count[0]
|
||||
finally:
|
||||
count[0] = dispatch.xla_primitive_callable.cache_info().misses
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_device_put_fast_path_hit():
|
||||
original_fn = xc.batched_copy_array_to_devices_with_sharding
|
||||
count = [0]
|
||||
|
||||
def batched_copy_array_to_devices_with_sharding_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return original_fn(*args, **kwargs)
|
||||
|
||||
xc.batched_copy_array_to_devices_with_sharding = batched_copy_array_to_devices_with_sharding_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
xc.batched_copy_array_to_devices_with_sharding = original_fn
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_pjit_cpp_cache_miss():
|
||||
original_pjit_lower = pjit_lib._pjit_lower
|
||||
count = [0]
|
||||
|
||||
def pjit_lower_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return original_pjit_lower(*args, **kwargs)
|
||||
|
||||
pjit_lib._pjit_lower = pjit_lower_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
pjit_lib._pjit_lower = original_pjit_lower
|
||||
|
||||
@contextmanager
|
||||
def count_cached_compilation_cache_miss():
|
||||
original_cached_compilation = pxla._cached_compilation
|
||||
count = [0]
|
||||
|
||||
def cached_compilation_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return original_cached_compilation(*args, **kwargs)
|
||||
|
||||
pxla._cached_compilation = cached_compilation_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
pxla._cached_compilation = original_cached_compilation
|
||||
|
||||
@contextmanager
|
||||
def count_jit_tracing_cache_miss():
|
||||
original_create_pjit_jaxpr = pjit_lib._create_pjit_jaxpr
|
||||
count = [0]
|
||||
|
||||
@lu.cache
|
||||
def create_pjit_jaxpr_and_count(*args):
|
||||
count[0] += 1
|
||||
return original_create_pjit_jaxpr(*args)
|
||||
|
||||
pjit_lib._create_pjit_jaxpr = create_pjit_jaxpr_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
pjit_lib._create_pjit_jaxpr = original_create_pjit_jaxpr
|
||||
|
||||
@contextmanager
|
||||
def count_jit_infer_params_cache_miss():
|
||||
original_infer_params_impl = pjit_lib._infer_params_impl
|
||||
count = collections.defaultdict(int)
|
||||
|
||||
def infer_params_impl_and_count(fun, *args, **kw):
|
||||
count[fun] += 1
|
||||
return original_infer_params_impl(fun, *args, **kw)
|
||||
|
||||
pjit_lib._infer_params_impl = infer_params_impl_and_count
|
||||
assert thread_local_state.infer_params_fun_counts is None
|
||||
counts = collections.defaultdict(int)
|
||||
thread_local_state.infer_params_fun_counts = counts
|
||||
try:
|
||||
yield count
|
||||
yield counts
|
||||
finally:
|
||||
pjit_lib._infer_params_impl = original_infer_params_impl
|
||||
|
||||
thread_local_state.infer_params_fun_counts = None
|
||||
|
||||
@contextmanager
|
||||
def count_aot_jit_cpp_cache_miss():
|
||||
original_call = stages.Compiled.call
|
||||
count = [0]
|
||||
|
||||
def compiled_call_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return original_call(*args, **kwargs)
|
||||
|
||||
stages.Compiled.call = compiled_call_count
|
||||
def count_subjaxpr_to_hlo_conversion(fun_name):
|
||||
assert thread_local_state.lower_jaxpr_to_fun_counts is None
|
||||
counts = collections.defaultdict(int)
|
||||
thread_local_state.lower_jaxpr_to_fun_counts = counts
|
||||
try:
|
||||
yield count
|
||||
yield lambda: counts[fun_name]
|
||||
finally:
|
||||
stages.Compiled.call = original_call
|
||||
thread_local_state.lower_jaxpr_to_fun_counts = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_jit_and_pmap_lowerings():
|
||||
# No need to clear any caches since we generally jit and pmap fresh callables
|
||||
# in tests.
|
||||
|
||||
mlir_lower = mlir.lower_jaxpr_to_module
|
||||
count = [0]
|
||||
|
||||
def mlir_lower_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return mlir_lower(*args, **kwargs)
|
||||
|
||||
mlir.lower_jaxpr_to_module = mlir_lower_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
mlir.lower_jaxpr_to_module = mlir_lower
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_jax_array_shard_arg_calls():
|
||||
# No need to clear any caches since we generally jit and pmap fresh callables
|
||||
# in tests.
|
||||
|
||||
array_shard_arg = array._array_shard_arg
|
||||
count = [0]
|
||||
|
||||
def array_shard_arg_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return array_shard_arg(*args, **kwargs)
|
||||
|
||||
pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_jit_compilation_cache_miss():
|
||||
# No need to clear any caches since we generally jit and pmap fresh callables
|
||||
# in tests.
|
||||
|
||||
jit_compilation = pxla._cached_compilation
|
||||
count = [0]
|
||||
|
||||
def compile_and_count(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return jit_compilation(*args, **kwargs)
|
||||
|
||||
pxla._cached_compilation = compile_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
pxla._cached_compilation = jit_compilation
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_subjaxpr_to_hlo_conversion(fun_name: str):
|
||||
# No need to clear any caches since we generally jit and pmap fresh callables
|
||||
# in tests.
|
||||
|
||||
mlir_lower = mlir.lower_jaxpr_to_fun
|
||||
count = [0]
|
||||
|
||||
def mlir_lower_and_count(ctx, name, *args, **kwargs):
|
||||
if name == fun_name:
|
||||
count[0] += 1
|
||||
return mlir_lower(ctx, name, *args, **kwargs)
|
||||
|
||||
mlir.lower_jaxpr_to_fun = mlir_lower_and_count
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
mlir.lower_jaxpr_to_fun = mlir_lower
|
||||
|
||||
|
||||
@contextmanager
|
||||
def assert_num_jit_and_pmap_compilations(times):
|
||||
with count_jit_and_pmap_lowerings() as count:
|
||||
yield
|
||||
if count[0] != times:
|
||||
if count() != times:
|
||||
raise AssertionError(f"Expected exactly {times} XLA compilations, "
|
||||
f"but executed {count[0]}")
|
||||
f"but executed {count()}")
|
||||
|
||||
|
||||
def jaxlib_version() -> tuple[int, ...]:
|
||||
|
@ -676,3 +676,12 @@ class StrictABCMeta(abc.ABCMeta):
|
||||
|
||||
class StrictABC(metaclass=StrictABCMeta):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
|
||||
test_event_listener: Callable | None = None
|
||||
|
||||
def test_event(name: str, *args) -> None:
  """Forwards a test event to the registered listener, if there is one.

  Does nothing while the module-level `test_event_listener` is unset (or
  otherwise falsy).

  Args:
    name: the event name.
    *args: extra positional payload passed through to the listener unchanged.
  """
  if test_event_listener:
    test_event_listener(name, *args)
|
||||
|
@ -24,6 +24,7 @@ from typing import Any, Callable, Sequence
|
||||
import jax
|
||||
from jax._src import api
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.traceback_util import api_boundary
|
||||
@ -283,6 +284,7 @@ def _get_specialized_func(
|
||||
info: FunctionInfo, specialization: Specialization
|
||||
) -> Callable[..., Any]:
|
||||
"""Returns a specialized function for the given specialization."""
|
||||
util.test_event("colocated_python_func._get_specialized_func")
|
||||
assert specialization.in_specs_treedef is not None
|
||||
assert specialization.in_specs_leaves is not None
|
||||
assert specialization.devices is not None
|
||||
|
@ -1081,7 +1081,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
args = range(num_args)
|
||||
with jtu.count_device_put() as count:
|
||||
np.testing.assert_allclose(f_pruned(*args), 3)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def testBuffersAreFreedPromptly(self):
|
||||
# Regression test for a bug where garbage collection was delayed too long
|
||||
@ -1246,7 +1246,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
jitted_f = jit(lambda x, y: x, keep_unused=True)
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
_ = jitted_f(1, 2)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_jit_lower_compile_compiler_ir(self):
|
||||
f = jit(lambda x: x + 4).lower(1.).compile()
|
||||
@ -1428,7 +1428,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
with jtu.count_jit_compilation_cache_miss() as count:
|
||||
jit(f, compiler_options={"xla_embed_ir_in_executable": True})(1.)
|
||||
jit(f, compiler_options={"xla_embed_ir_in_executable": False})(1.)
|
||||
self.assertEqual(count[0], 2)
|
||||
self.assertEqual(count(), 2)
|
||||
|
||||
# We should still error on invalid options after some valid compiles
|
||||
with self.assertRaisesRegex(
|
||||
@ -1511,7 +1511,7 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
expected = f()
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
ans = jax.vmap(f, axis_size=2, out_axes=None)()
|
||||
self.assertEqual(count[0], 0) # no compiles
|
||||
self.assertEqual(count(), 0) # no compiles
|
||||
self.assertArraysAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
def test_cache_key_defaults(self):
|
||||
@ -2737,7 +2737,7 @@ class APITest(jtu.JaxTestCase):
|
||||
jax.eval_shape(f, inp)
|
||||
jax.jit(f)(inp)
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_jit_infer_params_cache(self):
|
||||
def f(x):
|
||||
@ -3384,12 +3384,12 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
_ = jax.grad(f)(3.)
|
||||
self.assertEqual(count[0], 2) # one for fwd, one for bwd
|
||||
self.assertEqual(count(), 2) # one for fwd, one for bwd
|
||||
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
_ = jax.grad(f)(3.)
|
||||
_ = jax.grad(f)(4.)
|
||||
self.assertEqual(count[0], 0) # cache hits on both fwd and bwd
|
||||
self.assertEqual(count(), 0) # cache hits on both fwd and bwd
|
||||
|
||||
def test_grad_does_not_unflatten_tree_with_none(self):
|
||||
# https://github.com/jax-ml/jax/issues/7546
|
||||
@ -3458,7 +3458,7 @@ class APITest(jtu.JaxTestCase):
|
||||
with jtu.count_primitive_compiles() as count:
|
||||
lax.add(1, 2)
|
||||
lax.add(2, 3)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_arange_jit(self):
|
||||
# see https://github.com/jax-ml/jax/issues/553
|
||||
@ -4021,7 +4021,7 @@ class APITest(jtu.JaxTestCase):
|
||||
jax.eval_shape(jax.numpy.array, 1)
|
||||
out = jax.eval_shape(jax.numpy.array, 1)
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
self.assertTrue(out.weak_type)
|
||||
self.assertEqual(out.weak_type, arr.weak_type)
|
||||
|
||||
@ -4296,19 +4296,19 @@ class APITest(jtu.JaxTestCase):
|
||||
for _ in range(5):
|
||||
jax.hessian(jf)(x).block_until_ready()
|
||||
|
||||
n = count[0]
|
||||
n = count()
|
||||
# The exact number of compilations may vary depending on the number of
|
||||
# jit decorators in the function above, but it should not grow after an
|
||||
# initial warmup phase.
|
||||
for _ in range(5):
|
||||
jax.hessian(jf)(x).block_until_ready()
|
||||
|
||||
self.assertEqual(count[0], n)
|
||||
self.assertEqual(count(), n)
|
||||
|
||||
def test_jnp_array_doesnt_device_put(self):
|
||||
with jtu.count_device_put() as count:
|
||||
api.make_jaxpr(lambda: jnp.array(3))()
|
||||
self.assertEqual(count[0], 0)
|
||||
self.assertEqual(count(), 0)
|
||||
|
||||
def test_rank_promotion_forces_retrace(self):
|
||||
num_traces = 0
|
||||
@ -5910,7 +5910,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
for _ in range(20):
|
||||
f_lin(1.).block_until_ready()
|
||||
self.assertEqual(count[0], 1) # cached after first execution
|
||||
self.assertEqual(count(), 1) # cached after first execution
|
||||
|
||||
def test_vjp_caching(self):
|
||||
# https://github.com/jax-ml/jax/issues/9661
|
||||
@ -5919,7 +5919,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841
|
||||
for _ in range(20):
|
||||
f_vjp(1.)[0].block_until_ready()
|
||||
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd
|
||||
self.assertEqual(count(), 2) # fwd execute_trivial, backward_pass on bwd
|
||||
|
||||
def test_vjp_caching_static_argnums(self):
|
||||
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
|
||||
@ -5928,7 +5928,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
for _ in range(20):
|
||||
f_vjp(1.)[0].block_until_ready()
|
||||
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd
|
||||
self.assertEqual(count(), 2) # fwd execute_trivial, backward_pass on bwd
|
||||
|
||||
def test_fwd_caching(self):
|
||||
# see above test also
|
||||
@ -5937,7 +5937,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
for _ in range(20):
|
||||
y, _ = jax.vjp(identity, 1.)
|
||||
y.block_until_ready()
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_fwd_caching_static_argnums(self):
|
||||
# see above test also
|
||||
@ -5946,7 +5946,7 @@ class RematTest(jtu.JaxTestCase):
|
||||
for _ in range(20):
|
||||
y = identity(1.)
|
||||
y.block_until_ready()
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"{suffix}", "remat": remat}
|
||||
|
@ -858,7 +858,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
_ = f(3.)
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
_ = f(3.)
|
||||
self.assertEqual(count[0], 0)
|
||||
self.assertEqual(count(), 0)
|
||||
|
||||
def test_goodfellow_custom_jvp(self):
|
||||
def h(fext):
|
||||
|
@ -24,7 +24,6 @@ from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
|
||||
from jax.experimental import colocated_python
|
||||
from jax.experimental.colocated_python import func as colocated_python_func
|
||||
from jax.experimental.colocated_python import serialization
|
||||
from jax.extend.ifrt_programs import ifrt_programs
|
||||
import jax.numpy as jnp
|
||||
@ -52,23 +51,8 @@ def _colocated_cpu_devices(
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _count_colocated_python_specialization_cache_miss() -> list[int]:
|
||||
"""Counts the number of cache misses for colocated_python specialization."""
|
||||
original_get_specialized_func = colocated_python_func._get_specialized_func
|
||||
count = [0]
|
||||
|
||||
@jax.util.cache(max_size=None)
|
||||
def get_specialized_func(*args, **kwargs):
|
||||
count[0] += 1
|
||||
return original_get_specialized_func(*args, **kwargs)
|
||||
|
||||
colocated_python_func._get_specialized_func = get_specialized_func
|
||||
try:
|
||||
yield count
|
||||
finally:
|
||||
colocated_python_func._get_specialized_func = original_get_specialized_func
|
||||
|
||||
_count_colocated_python_specialization_cache_miss = jtu.count_events(
|
||||
"colocated_python_func._get_specialized_func")
|
||||
|
||||
_exit_stack = contextlib.ExitStack()
|
||||
|
||||
@ -117,12 +101,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, np.array(2))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, np.array(2))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def testSimpleFunctioWithTree(self):
|
||||
@colocated_python.colocated_python
|
||||
@ -137,12 +121,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def testEmptyInputFailsWithoutSpecialization(self):
|
||||
@colocated_python.colocated_python
|
||||
@ -168,12 +152,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
out = make_zero()
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, np.array(0))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
out = make_zero()
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, np.array(0))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def testInputPolymorphismWithoutOutSpecsFn(self):
|
||||
@colocated_python.colocated_python
|
||||
@ -188,12 +172,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, np.array(2))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, np.array(2))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
# Different input tree structure and dtype/shape.
|
||||
x = [np.array(1), (np.array(2), {"v": np.array(3)})]
|
||||
@ -202,12 +186,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
|
||||
self.assertEqual(count[0], 2)
|
||||
self.assertEqual(count(), 2)
|
||||
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
|
||||
self.assertEqual(count[0], 2)
|
||||
self.assertEqual(count(), 2)
|
||||
|
||||
def testInputPolymorphismAllowedWithOutSpecsFn(self):
|
||||
@colocated_python.colocated_python
|
||||
@ -223,12 +207,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, np.array(2))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, np.array(2))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
# Different input tree structure and dtype/shape.
|
||||
x = [np.array(1), (np.array(2), {"v": np.array(3)})]
|
||||
@ -237,12 +221,12 @@ class ColocatedPythonTest(jtu.JaxTestCase):
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
|
||||
self.assertEqual(count[0], 2)
|
||||
self.assertEqual(count(), 2)
|
||||
|
||||
out = add_one(x)
|
||||
out = jax.device_get(out)
|
||||
self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
|
||||
self.assertEqual(count[0], 2)
|
||||
self.assertEqual(count(), 2)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("on_main_thread", True),
|
||||
|
@ -2621,7 +2621,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
lax.cond(x, f, g, x)
|
||||
# Should observe a maximum of 4 compiles: convert_element_type, f, g, cond
|
||||
# In #14058, this was observed to be 31 compiles.
|
||||
self.assertLess(count[0], 5)
|
||||
self.assertLess(count(), 5)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_dtype={dtype.__name__}", "dtype": dtype}
|
||||
|
@ -87,7 +87,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
with jtu.count_aot_jit_cpp_cache_miss() as init_count:
|
||||
init_out = init_compiled(arr1, arr2)
|
||||
init_compiled(arr1, arr2)
|
||||
self.assertEqual(init_count[0], 1)
|
||||
self.assertEqual(init_count(), 1)
|
||||
|
||||
self.assertEqual(init_out[0].layout, init_compiled.output_layouts[0])
|
||||
self.assertEqual(init_out[1].layout, init_compiled.output_layouts[1])
|
||||
@ -95,7 +95,7 @@ class LayoutTest(jtu.JaxTestCase):
|
||||
with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
|
||||
apply_out = compiled_apply(*init_out)
|
||||
compiled_apply(*init_out)
|
||||
self.assertEqual(apply_count[0], 1)
|
||||
self.assertEqual(apply_count(), 1)
|
||||
|
||||
self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts[0])
|
||||
self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts[1])
|
||||
|
@ -1308,7 +1308,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
out = f(inp)
|
||||
out2 = f(inp2)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
self.assertArraysEqual(out, np_inp @ np_inp.T)
|
||||
self.assertArraysEqual(out2, np_inp @ np_inp.T)
|
||||
@ -1329,8 +1329,8 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
jtu.count_jit_and_pmap_lowerings() as compile_count):
|
||||
f(inp)
|
||||
f(inp2)
|
||||
self.assertEqual(cpp_count[0], 2)
|
||||
self.assertEqual(compile_count[0], 1)
|
||||
self.assertEqual(cpp_count(), 2)
|
||||
self.assertEqual(compile_count(), 1)
|
||||
|
||||
def test_jit_cpp_cache_output_hit(self):
|
||||
_, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device")
|
||||
@ -1342,7 +1342,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
out = mul_two(inp)
|
||||
mul_two(out)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_jit_cache_hit_with_default_and_specified_mem_kind(self):
|
||||
_, s, np_inp, _ = _create_inputs((8, 2), P("x", "y"))
|
||||
@ -1357,7 +1357,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
out = f(np_inp)
|
||||
out2 = g(np_inp2)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
self.assertArraysEqual(out, np_inp @ np_inp.T)
|
||||
self.assertArraysEqual(out2, np_inp2 @ np_inp2.T)
|
||||
@ -1633,7 +1633,7 @@ class ComputeOffload(jtu.BufferDonationTestCase):
|
||||
f(inp)
|
||||
|
||||
# 2 for `f` and `2` for `mul` (compute type changes for `mul`)
|
||||
self.assertEqual(count[0], 4)
|
||||
self.assertEqual(count(), 4)
|
||||
|
||||
def test_offload_take_host(self):
|
||||
@compute_on('device_host')
|
||||
|
@ -141,7 +141,7 @@ class MultiDeviceTest(jtu.JaxTestCase):
|
||||
y = lax.add(x, x)
|
||||
z = lax.add(y, y)
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
self.assert_committed_to_device(y, devices[1])
|
||||
self.assert_committed_to_device(z, devices[1])
|
||||
|
||||
|
@ -149,14 +149,14 @@ class PgleTest(jtu.JaxTestCase):
|
||||
with config.pgle_profiling_runs(2), config.enable_pgle(True):
|
||||
# Run 1: Module should be compiled without FDO. Two modules are expected
|
||||
# One is the function f, the other one is multi slice module
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
|
||||
self.assertArraysEqual(f(x), expected)
|
||||
self.assertEqual(cache_miss_count[0], 2)
|
||||
self.assertEqual(cache_miss_count(), 2)
|
||||
|
||||
# Run 2: Second PGLE run. Profile should be empty.
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
|
||||
self.assertArraysEqual(f(x), expected)
|
||||
self.assertEqual(cache_miss_count[0], 2)
|
||||
self.assertEqual(cache_miss_count(), 2)
|
||||
fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)
|
||||
# One for before and one for after optimization.
|
||||
self.assertLen(fdo_profiles_before_pgle, 2)
|
||||
@ -165,9 +165,9 @@ class PgleTest(jtu.JaxTestCase):
|
||||
os.path.getsize(os.path.join(dump_dir, fdo_profiles_before_pgle[0])), 0)
|
||||
|
||||
# Run 3: The module should be recompiled with FDO profiles
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
|
||||
self.assertArraysEqual(f(x), expected)
|
||||
self.assertEqual(cache_miss_count[0], 2)
|
||||
self.assertEqual(cache_miss_count(), 2)
|
||||
fdo_profiles_after_pgle = self.get_fdo_profiles(dump_dir)
|
||||
# One for before and one for after optimization.
|
||||
self.assertLen(fdo_profiles_after_pgle, 4)
|
||||
@ -179,9 +179,9 @@ class PgleTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
# Run 4: Fast-path should be used after PGLE is done
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
|
||||
self.assertArraysEqual(f(x), expected)
|
||||
self.assertLess(cache_miss_count[0], 2)
|
||||
self.assertLess(cache_miss_count(), 2)
|
||||
|
||||
def testAutoPgleWithAot(self):
|
||||
@jax.jit
|
||||
@ -197,14 +197,14 @@ class PgleTest(jtu.JaxTestCase):
|
||||
|
||||
with config.pgle_profiling_runs(1), config.enable_pgle(True):
|
||||
# Run 1
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
|
||||
self.assertArraysEqual(compiled(x), expected)
|
||||
self.assertEqual(cache_miss_count[0], 0)
|
||||
self.assertEqual(cache_miss_count(), 0)
|
||||
|
||||
# Run 2
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
|
||||
self.assertArraysEqual(compiled(x), expected)
|
||||
self.assertEqual(cache_miss_count[0], 0)
|
||||
self.assertEqual(cache_miss_count(), 0)
|
||||
|
||||
def testAutoPgleWithPersistentCache(self):
|
||||
its = 50
|
||||
@ -243,24 +243,24 @@ class PgleTest(jtu.JaxTestCase):
|
||||
cc.reset_cache()
|
||||
cc.set_cache_dir(cache_dir)
|
||||
# Run 1: Module should be compiled without FDO
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
|
||||
f(x)
|
||||
self.assertGreater(cache_miss_count[0], 0)
|
||||
self.assertGreater(cache_miss_count(), 0)
|
||||
|
||||
# Non-pgle profiled version of module should be saved
|
||||
non_pgle_profiled_files = os.listdir(cache_dir)
|
||||
self.assertNotEmpty(non_pgle_profiled_files)
|
||||
|
||||
# Run 2: Compilation should not be called
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
|
||||
f(x)
|
||||
self.assertGreater(cache_miss_count[0], 0)
|
||||
self.assertGreater(cache_miss_count(), 0)
|
||||
|
||||
fdo_profiles_before_pgle = self.get_fdo_profiles(dump_dir)
|
||||
# Run 3: Module should be compiled with FDO and stored to persistent cache
|
||||
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
|
||||
with jtu.count_jit_compilation_cache_miss() as cache_miss_count:
|
||||
f(x)
|
||||
self.assertGreater(cache_miss_count[0], 0)
|
||||
self.assertGreater(cache_miss_count(), 0)
|
||||
|
||||
# Check if FDO profile file of the biggest module is not empty
|
||||
fdo_profiles_after_pgle = [
|
||||
|
@ -660,7 +660,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
jax.grad(f)(x) # Warm up the cache.
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
jax.grad(f)(x)
|
||||
self.assertEqual(count[0], 0) # no cache miss i.e. cache hit
|
||||
self.assertEqual(count(), 0) # no cache miss i.e. cache hit
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||
def testEvalJaxpr(self):
|
||||
@ -2200,7 +2200,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
out = f(a)
|
||||
_ = f(out)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_pjit_different_device_recompilation(self):
|
||||
if jax.device_count() < 2:
|
||||
@ -2217,7 +2217,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jtu.count_jit_compilation_cache_miss() as count:
|
||||
out1 = f(a)
|
||||
out2 = f(b)
|
||||
self.assertEqual(count[0], 2)
|
||||
self.assertEqual(count(), 2)
|
||||
|
||||
self.assertArraysEqual(out1, val1)
|
||||
self.assertArraysEqual(out2, val2)
|
||||
@ -2645,7 +2645,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
inp_data, jax.sharding.NamedSharding(mesh, P('x')))
|
||||
with mesh:
|
||||
f(arr1)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_single_device_add_single_compile(self):
|
||||
f1 = pjit(lambda x, y: x + y)
|
||||
@ -2657,7 +2657,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
for _ in range(2):
|
||||
f1(a, b)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_global_array_to_host_local_array_already_host_local(self):
|
||||
inp_shape = (8, 2)
|
||||
@ -2791,7 +2791,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
f_names = pjit(f, static_argnames='x')
|
||||
f_names(y, x='foo')
|
||||
f_names(y, x='foo')
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_new_static_argnum_on_keyword_arguments(self):
|
||||
f = pjit(lambda x: x, static_argnums=0)
|
||||
@ -2834,7 +2834,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
# The count here is 0 because before `count_pjit_cpp_cache_miss`, `f` was
|
||||
# called with `system_default_device` and `test_device` so it was added
|
||||
# to the cache. Subsequent calls hit the C++ cache.
|
||||
self.assertEqual(count[0], 0)
|
||||
self.assertEqual(count(), 0)
|
||||
|
||||
def test_pjit_with_mismatched_static_argnames(self):
|
||||
x_is_tracer, y_is_tracer = False, False
|
||||
@ -3287,7 +3287,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
for _ in range(10):
|
||||
pjit(f)(inp)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_pjit_no_global_cache_hit_axis_resources(self):
|
||||
mesh = jtu.create_mesh((1,), ('x',))
|
||||
@ -3297,20 +3297,20 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
for _ in range(10):
|
||||
pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)(inp)
|
||||
self.assertEqual(count[0], 10)
|
||||
self.assertEqual(count(), 10)
|
||||
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
for _ in range(10):
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
pjit(lambda x: x * 2, device=jax.devices()[0])(inp)
|
||||
self.assertEqual(count[0], 10)
|
||||
self.assertEqual(count(), 10)
|
||||
|
||||
pf = pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
for _ in range(10):
|
||||
pf(inp)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
with jtu.ignore_warning(category=DeprecationWarning,
|
||||
message="backend and device argument"):
|
||||
@ -3318,7 +3318,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
for _ in range(10):
|
||||
pf1(inp)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_with_sharding_constraint_spmd_axis_name(self):
|
||||
mesh = jtu.create_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
|
||||
@ -3484,7 +3484,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
out = f(arr)
|
||||
self.assertIsInstance(out.sharding, NamedSharding)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
out2 = f2(arr)
|
||||
@ -3493,7 +3493,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
out2 = f2(arr)
|
||||
self.assertIsInstance(out2.sharding, PositionalSharding)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
||||
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
||||
@ -3530,7 +3530,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(out2.sharding, NamedSharding)
|
||||
|
||||
# Drops out of C++ cache i.e. cache miss
|
||||
self.assertEqual(count[0], 2)
|
||||
self.assertEqual(count(), 2)
|
||||
# Still gets a hit on pjit_lower cache.
|
||||
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
||||
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
||||
@ -3600,7 +3600,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
out3 = vf(out2)
|
||||
self.assertIsInstance(out3.sharding, NamedSharding)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_jit_mul_sum_sharding_preserved(self):
|
||||
if config.use_shardy_partitioner.value:
|
||||
@ -3625,7 +3625,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
# This will hit the cpp cache.
|
||||
out3 = f(out2)
|
||||
self.assertIsInstance(out3.sharding, PositionalSharding)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
|
||||
self.assertEqual(cache_info2.misses, cache_info1.misses)
|
||||
@ -3755,7 +3755,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
top(jnp.arange(8))
|
||||
|
||||
# The count should be 1 because `nest`'s lowering to MHLO should be cached.
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_wsc_eager(self):
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
@ -3792,7 +3792,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
out1 = core.jaxpr_as_fun(jaxpr)(inp)
|
||||
out2 = core.jaxpr_as_fun(jaxpr)(inp)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
self.assertArraysEqual(out1[0], inp * 2)
|
||||
self.assertArraysEqual(out2[0], inp * 2)
|
||||
|
||||
@ -3970,7 +3970,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
g(np.arange(8))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_lowering_cache_miss_different_devices_and_sharding(self):
|
||||
if jax.device_count() < 4:
|
||||
@ -3993,7 +3993,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
g(np.arange(8))
|
||||
self.assertEqual(count[0], 2)
|
||||
self.assertEqual(count(), 2)
|
||||
|
||||
def test_single_device_named_sharding_preserved(self):
|
||||
mesh = jax.sharding.Mesh([jax.devices()[0]], 'x')
|
||||
@ -4021,7 +4021,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
# that is expected.
|
||||
with jtu.count_device_put_fast_path_hit() as count:
|
||||
out = jax.device_put(arr1, NamedSharding(mesh2, P('x')))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
self.assertTupleEqual(out.sharding._device_assignment,
|
||||
mesh2._flat_devices_tuple)
|
||||
self.assertArraysEqual(out, inp)
|
||||
@ -4536,9 +4536,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
# same num_devices but different devices.
|
||||
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
|
||||
f(b) # tracing and lowering cache *hit*
|
||||
self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin`
|
||||
self.assertEqual(lowering_count[0], 1)
|
||||
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.
|
||||
self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin`
|
||||
self.assertEqual(lowering_count(), 1)
|
||||
self.assertEqual(compilation_count(), 2) # 2 misses since devices differ.
|
||||
|
||||
def test_wsc_abstract_mesh(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
@ -4621,7 +4621,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
jax.jit(f, out_shardings=s)(np.arange(8))
|
||||
jax.jit(f, out_shardings=s)(np.arange(8))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_input_shardings_single_device(self):
|
||||
@jax.jit
|
||||
|
@ -1287,7 +1287,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
x = jnp.arange(device_count)
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
ans = f(x)
|
||||
# self.assertEqual(count[0], 0) # TODO(mattjj): fix this
|
||||
# self.assertEqual(count(), 0) # TODO(mattjj): fix this
|
||||
expected = np.repeat(3, device_count)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ -1308,7 +1308,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
x = jnp.arange(len(devices))
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
ans = f(x)
|
||||
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
|
||||
# self.assertEqual(count(), 0) # TODO(mattjj): don't compile for constants
|
||||
expected = np.repeat(3, len(devices))
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ -1344,7 +1344,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
x = jnp.arange(math.prod(shape)).reshape(shape)
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
ans = f(x)
|
||||
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
|
||||
# self.assertEqual(count(), 0) # TODO(mattjj): don't compile for constants
|
||||
expected = 3 * np.ones(shape[:2])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ -1370,7 +1370,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
x = jnp.arange(math.prod(shape)).reshape(shape)
|
||||
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
|
||||
ans = f(x)
|
||||
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
|
||||
# self.assertEqual(count(), 0) # TODO(mattjj): don't compile for constants
|
||||
expected = 3 * np.ones(shape[:2])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
|
||||
@ -2043,7 +2043,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
_, f_bwd2 = jax.vjp(f, x)
|
||||
_ = f_bwd(x)
|
||||
_ = f_bwd2(x)
|
||||
self.assertEqual(count[0], 0) # cache hits on fwd and bwd
|
||||
self.assertEqual(count(), 0) # cache hits on fwd and bwd
|
||||
|
||||
def testSizeOverflow(self):
|
||||
if config.disable_jit.value:
|
||||
|
@ -588,7 +588,7 @@ class PythonCallbackTest(jtu.JaxTestCase):
|
||||
with jtu.count_primitive_compiles() as count:
|
||||
for _ in range(3):
|
||||
self.assertAllClose(2 * x, fun(x))
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
|
||||
class PureCallbackTest(jtu.JaxTestCase):
|
||||
|
@ -1041,7 +1041,7 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
key = self.make_key(1).block_until_ready()
|
||||
with jtu.count_device_put() as count:
|
||||
jax.jit(random.split)(key)
|
||||
self.assertLessEqual(count[0], 1) # 1 for the argument device_put
|
||||
self.assertLessEqual(count(), 1) # 1 for the argument device_put
|
||||
|
||||
@jtu.sample_product(dtype=int_dtypes + uint_dtypes)
|
||||
def test_randint_bounds(self, dtype):
|
||||
|
@ -709,7 +709,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
f(key).block_until_ready()
|
||||
f(key).block_until_ready()
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
# TODO(jakevdp) remove this decorator when reuse checks move to C++
|
||||
@jax.debug_key_reuse(False)
|
||||
@ -726,7 +726,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
f(key).block_until_ready()
|
||||
f(key).block_until_ready()
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_cpp_dispatch_aot_normal(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling an
|
||||
@ -739,7 +739,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
f(key).block_until_ready()
|
||||
f(key).block_until_ready()
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
def test_cpp_dispatch_aot_split(self):
|
||||
# Ensure we stay on the C++ dispatch path when calling an
|
||||
@ -753,7 +753,7 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
f(key).block_until_ready()
|
||||
f(key).block_until_ready()
|
||||
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
|
||||
# -- prng primitives
|
||||
|
||||
|
@ -201,7 +201,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
f(np_inp)
|
||||
out1, out2 = f(np_inp)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
self.assertTrue(s.is_equivalent_to(out1.sharding, np_inp.ndim))
|
||||
self.assertTrue(s.is_equivalent_to(out2.sharding, np_inp.ndim))
|
||||
|
||||
@ -213,7 +213,7 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
with jtu.count_pjit_cpp_cache_miss() as count:
|
||||
g(arr)
|
||||
out3, out4 = g(arr)
|
||||
self.assertEqual(count[0], 1)
|
||||
self.assertEqual(count(), 1)
|
||||
self.assertEqual(out3.sharding, s)
|
||||
self.assertEqual(out4.sharding, s)
|
||||
|
||||
|
@ -825,9 +825,9 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
b = jax.device_put(out_a, NamedSharding(mesh2, P()))
|
||||
f(b) # tracing and lowering cache *hit*
|
||||
|
||||
self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin`
|
||||
self.assertEqual(lowering_count[0], 1)
|
||||
self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ.
|
||||
self.assertEqual(tracing_count(), 2) # 1 miss for `f` and 1 miss for `sin`
|
||||
self.assertEqual(lowering_count(), 1)
|
||||
self.assertEqual(compilation_count(), 2) # 2 misses since devices differ.
|
||||
|
||||
def test_shmap_abstract_mesh_errors(self):
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
|
Loading…
x
Reference in New Issue
Block a user