mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Allow unhashable callables in jax.eval_shape.
PiperOrigin-RevId: 599691923
This commit is contained in:
parent
6d21b498c0
commit
0d4f200e08
@ -2835,6 +2835,9 @@ def eval_shape(fun: Callable, *args, **kwargs):
|
||||
>>> print(out.dtype)
|
||||
float32
|
||||
"""
|
||||
# Workaround to support unhashable callables.
|
||||
try: hash(fun)
|
||||
except TypeError: fun = partial(fun)
|
||||
# The traced_for name is `jit` so as to get maximum tracing cache hits.
|
||||
# Eventually, we should deprecate `eval_shape` and expose it like AOT style.
|
||||
f, dbg, res_paths, args_flat, _, out_tree, _, _ = pjit.get_wrapped_fun(
|
||||
|
Loading…
x
Reference in New Issue
Block a user