mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add eval_shape example for function with static arguments
Improve wording and formating of dynamic eval_shape example
This commit is contained in:
parent
e0cc9879d5
commit
e8330b5fc5
@ -2814,6 +2814,24 @@ def eval_shape(fun: Callable, *args, **kwargs):
|
||||
(2000, 1000)
|
||||
>>> print(out.dtype)
|
||||
float32
|
||||
|
||||
All arguments passed via :func:`eval_shape` will be treated as dynamic;
|
||||
static arguments can be included via closure, for example using :func:`functools.partial`:
|
||||
|
||||
>>> import jax
|
||||
>>> from jax import lax
|
||||
>>> from functools import partial
|
||||
>>> import jax.numpy as jnp
|
||||
>>>
|
||||
>>> x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32)
|
||||
>>> kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32)
|
||||
>>>
|
||||
>>> conv_same = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME")
|
||||
>>> out = jax.eval_shape(conv_same, x, kernel)
|
||||
>>> print(out.shape)
|
||||
(1, 32, 28, 28)
|
||||
>>> print(out.dtype)
|
||||
float32
|
||||
"""
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
|
Loading…
x
Reference in New Issue
Block a user