Add eval_shape example for function with static arguments

Improve wording and formating of dynamic eval_shape example
This commit is contained in:
Axel Donath 2023-12-17 10:44:43 -05:00 committed by Axel Donath
parent e0cc9879d5
commit e8330b5fc5

View File

@ -2814,6 +2814,24 @@ def eval_shape(fun: Callable, *args, **kwargs):
(2000, 1000)
>>> print(out.dtype)
float32
All arguments passed via :func:`eval_shape` will be treated as dynamic;
static arguments can be included via closure, for example using :func:`functools.partial`:
>>> import jax
>>> from jax import lax
>>> from functools import partial
>>> import jax.numpy as jnp
>>>
>>> x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32)
>>> kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32)
>>>
>>> conv_same = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME")
>>> out = jax.eval_shape(conv_same, x, kernel)
>>> print(out.shape)
(1, 32, 28, 28)
>>> print(out.dtype)
float32
"""
args_flat, in_tree = tree_flatten((args, kwargs))
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)