Allow unhashable callables in jax.eval_shape.

PiperOrigin-RevId: 599691923
This commit is contained in:
Matthew Johnson 2024-01-18 19:14:38 -08:00 committed by jax authors
parent 6d21b498c0
commit 0d4f200e08

View File

@ -2835,6 +2835,9 @@ def eval_shape(fun: Callable, *args, **kwargs):
>>> print(out.dtype)
float32
"""
# Workaround to support unhashable callables.
try: hash(fun)
except TypeError: fun = partial(fun)
# The traced_for name is `jit` so as to get maximum tracing cache hits.
# Eventually, we should deprecate `eval_shape` and expose it like AOT style.
f, dbg, res_paths, args_flat, _, out_tree, _, _ = pjit.get_wrapped_fun(