mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix checkify caching with nested call primitives
This commit is contained in:
parent
f445c84ba4
commit
c231171fb6
@ -351,9 +351,8 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
|
||||
|
||||
# call_jaxpr handling
|
||||
call_jaxpr = params.pop('call_jaxpr')
|
||||
partial_checkify = lu.wrap_init(
|
||||
functools.partial(checkify_jaxpr_flat, call_jaxpr, (), enabled_errors,
|
||||
err_tree))
|
||||
partial_checkify = lu.hashable_partial(lu.wrap_init(
|
||||
checkify_jaxpr_flat), call_jaxpr, (), enabled_errors, err_tree)
|
||||
partial_checkify, metadata = _flatten_and_get_error_metadata_thunk(
|
||||
partial_checkify)
|
||||
|
||||
@ -688,6 +687,7 @@ error_checks[lax.scatter_max_p] = functools.partial(scatter_error_check,
|
||||
|
||||
# HOP error check rules
|
||||
|
||||
@weakref_lru_cache
|
||||
def jaxpr_to_checkify_jaxpr(
|
||||
jaxpr: core.ClosedJaxpr, enabled_errors, err_tree: PyTreeDef,
|
||||
*flat_err_and_in_vals) -> Tuple[core.ClosedJaxpr, PyTreeDef, Set[ErrorEffect]]:
|
||||
|
@ -775,7 +775,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def test_retracing(self):
|
||||
f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2))
|
||||
_ = f(3.)
|
||||
with jtu.count_primitive_compiles() as count:
|
||||
with jtu.count_jit_and_pmap_compiles() as count:
|
||||
_ = f(3.)
|
||||
self.assertEqual(count[0], 0)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user