mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #19559 from jakevdp:key-reuse-shape-poly
PiperOrigin-RevId: 602503831
This commit is contained in:
commit
52b16867a5
@ -19,10 +19,10 @@ from functools import reduce
|
||||
from typing import Any, Callable, NamedTuple
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax._src import prng
|
||||
@ -195,9 +195,14 @@ def _slice_signature(eqn, args_consumed):
|
||||
limit_indices = eqn.params['limit_indices']
|
||||
strides = eqn.params['strides'] or (1,) * len(start_indices)
|
||||
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
|
||||
mask = np.zeros(in_aval.shape, dtype=bool)
|
||||
mask[idx] = True
|
||||
return KeyReuseSignatureWithForwards([Sink(0, mask)], [Source(0)])
|
||||
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
|
||||
sink = True
|
||||
else:
|
||||
# TODO(jakevdp): should we avoid constructing the mask array if the input
|
||||
# does not have a key dtype?
|
||||
sink = np.zeros(in_aval.shape, dtype=bool)
|
||||
sink[idx] = True
|
||||
return KeyReuseSignatureWithForwards([Sink(0, sink)], [Source(0)])
|
||||
|
||||
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
|
||||
|
||||
|
@ -19,10 +19,10 @@ from functools import reduce
|
||||
from typing import Any, Callable
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax._src import prng
|
||||
@ -166,9 +166,14 @@ def _slice_signature(eqn, args_consumed):
|
||||
limit_indices = eqn.params['limit_indices']
|
||||
strides = eqn.params['strides'] or (1,) * len(start_indices)
|
||||
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
|
||||
mask = np.zeros(in_aval.shape, dtype=bool)
|
||||
mask[idx] = True
|
||||
return KeyReuseSignature([Sink(0, mask)], [Source(0)])
|
||||
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
|
||||
sink = True
|
||||
else:
|
||||
# TODO(jakevdp): should we avoid constructing the mask array if the input
|
||||
# does not have a key dtype?
|
||||
sink = np.zeros(in_aval.shape, dtype=bool)
|
||||
sink[idx] = True
|
||||
return KeyReuseSignature([Sink(0, sink)], [Source(0)])
|
||||
|
||||
key_reuse_signatures_dynamic[lax.slice_p] = _slice_signature
|
||||
|
||||
|
@ -106,7 +106,7 @@ def _stop_profile(tst: jtu.JaxTestCase):
|
||||
p.sort_stats("cumtime").print_stats(.2)
|
||||
p.print_callers(.2)
|
||||
|
||||
@jtu.with_config(jax_enable_key_reuse_checks=False)
|
||||
|
||||
class DimExprTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -1074,7 +1074,6 @@ def check_shape_poly(tst, f_jax: Callable, *,
|
||||
return h.run_test(tst)
|
||||
|
||||
|
||||
@jtu.with_config(jax_enable_key_reuse_checks=False)
|
||||
class ShapePolyTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -1475,6 +1474,7 @@ class ShapePolyTest(jtu.JaxTestCase):
|
||||
polymorphic_shapes=["(b,)"])
|
||||
self.assertAllClose(f(x), res_tf)
|
||||
|
||||
@jax.enable_key_reuse_checks(False)
|
||||
def test_prng(self):
|
||||
# The PRNG implementation uses opaque types, test shape polymorphism
|
||||
with config.enable_custom_prng(True):
|
||||
@ -2908,7 +2908,6 @@ def _flatten_harnesses(harnesses):
|
||||
return res
|
||||
|
||||
|
||||
@jtu.with_config(jax_enable_key_reuse_checks=False)
|
||||
class ShapePolyHarnessesTest(jtu.JaxTestCase):
|
||||
"""This test runs for all _POLY_SHAPE_PRIMITIVE_HARNESSES."""
|
||||
|
||||
@ -2987,16 +2986,19 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
|
||||
if harness.group_name == "eig" and not jtu.test_device_matches(["cpu"]):
|
||||
raise unittest.SkipTest("JAX implements eig only on CPU.")
|
||||
|
||||
prev_jax_config_flags = {
|
||||
fname: getattr(jax.config, fname)
|
||||
for fname, fvalue in harness.override_jax_config_flags.items()
|
||||
}
|
||||
config_flags = harness.override_jax_config_flags
|
||||
# Update this here rather than in harness object because vmap_random_gamma is derived
|
||||
# from test_harnesses.all_harnesses, which strips override_jax_config_flags.
|
||||
if "random_gamma" in harness.group_name:
|
||||
config_flags = {**config_flags, "jax_enable_key_reuse_checks": False}
|
||||
|
||||
prev_jax_config_flags = {fname: getattr(jax.config, fname) for fname in config_flags}
|
||||
try:
|
||||
for fname, fvalue in harness.override_jax_config_flags.items():
|
||||
for fname, fvalue in config_flags.items():
|
||||
jax.config.update(fname, fvalue)
|
||||
harness.run_test(self)
|
||||
finally:
|
||||
for fname, _ in harness.override_jax_config_flags.items():
|
||||
for fname, _ in config_flags.items():
|
||||
jax.config.update(fname, prev_jax_config_flags[fname])
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user