Merge pull request #19559 from jakevdp:key-reuse-shape-poly

PiperOrigin-RevId: 602503831
This commit is contained in:
jax authors 2024-01-29 14:38:54 -08:00
commit 52b16867a5
3 changed files with 29 additions and 17 deletions

View File

@ -19,10 +19,10 @@ from functools import reduce
from typing import Any, Callable, NamedTuple
import jax
from jax import core
from jax import lax
from jax import tree_util
from jax._src import api_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import prng
@ -195,9 +195,14 @@ def _slice_signature(eqn, args_consumed):
limit_indices = eqn.params['limit_indices']
strides = eqn.params['strides'] or (1,) * len(start_indices)
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
mask = np.zeros(in_aval.shape, dtype=bool)
mask[idx] = True
return KeyReuseSignatureWithForwards([Sink(0, mask)], [Source(0)])
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
sink = True
else:
# TODO(jakevdp): should we avoid constructing the mask array if the input
# does not have a key dtype?
sink = np.zeros(in_aval.shape, dtype=bool)
sink[idx] = True
return KeyReuseSignatureWithForwards([Sink(0, sink)], [Source(0)])
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature

View File

@ -19,10 +19,10 @@ from functools import reduce
from typing import Any, Callable
import jax
from jax import core
from jax import lax
from jax import tree_util
from jax._src import api_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import prng
@ -166,9 +166,14 @@ def _slice_signature(eqn, args_consumed):
limit_indices = eqn.params['limit_indices']
strides = eqn.params['strides'] or (1,) * len(start_indices)
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
mask = np.zeros(in_aval.shape, dtype=bool)
mask[idx] = True
return KeyReuseSignature([Sink(0, mask)], [Source(0)])
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
sink = True
else:
# TODO(jakevdp): should we avoid constructing the mask array if the input
# does not have a key dtype?
sink = np.zeros(in_aval.shape, dtype=bool)
sink[idx] = True
return KeyReuseSignature([Sink(0, sink)], [Source(0)])
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature

View File

@ -106,7 +106,7 @@ def _stop_profile(tst: jtu.JaxTestCase):
p.sort_stats("cumtime").print_stats(.2)
p.print_callers(.2)
@jtu.with_config(jax_enable_key_reuse_checks=False)
class DimExprTest(jtu.JaxTestCase):
def setUp(self):
@ -1074,7 +1074,6 @@ def check_shape_poly(tst, f_jax: Callable, *,
return h.run_test(tst)
@jtu.with_config(jax_enable_key_reuse_checks=False)
class ShapePolyTest(jtu.JaxTestCase):
def setUp(self):
@ -1475,6 +1474,7 @@ class ShapePolyTest(jtu.JaxTestCase):
polymorphic_shapes=["(b,)"])
self.assertAllClose(f(x), res_tf)
@jax.enable_key_reuse_checks(False)
def test_prng(self):
# The PRNG implementation uses opaque types, test shape polymorphism
with config.enable_custom_prng(True):
@ -2908,7 +2908,6 @@ def _flatten_harnesses(harnesses):
return res
@jtu.with_config(jax_enable_key_reuse_checks=False)
class ShapePolyHarnessesTest(jtu.JaxTestCase):
"""This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES."""
@ -2987,16 +2986,19 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
raise unittest.SkipTest("JAX implements eig only on CPU.")
prev_jax_config_flags = {
fname: getattr(jax.config, fname)
for fname, fvalue in harness.override_jax_config_flags.items()
}
config_flags = harness.override_jax_config_flags
# Update this here rather than in harness object because vmap_random_gamma is derived
# from test_harnesses.all_harnesses, which strips override_jax_config_flags.
if "random_gamma" in harness.group_name:
config_flags = {**config_flags, "jax_enable_key_reuse_checks": False}
prev_jax_config_flags = {fname: getattr(jax.config, fname) for fname in config_flags}
try:
for fname, fvalue in harness.override_jax_config_flags.items():
for fname, fvalue in config_flags.items():
jax.config.update(fname, fvalue)
harness.run_test(self)
finally:
for fname, _ in harness.override_jax_config_flags.items():
for fname, _ in config_flags.items():
jax.config.update(fname, prev_jax_config_flags[fname])
if __name__ == "__main__":