diff --git a/docs/jax.rst b/docs/jax.rst index b3ec03a58..51ce45d44 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -69,6 +69,7 @@ Just-in-time compilation (:code:`jit`) xla_computation make_jaxpr eval_shape + ShapeDtypeStruct device_put device_put_replicated device_put_sharded diff --git a/jax/_src/api.py b/jax/_src/api.py index 8476fe6ac..c614d0914 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2704,6 +2704,16 @@ def device_get(x: Any): class ShapeDtypeStruct: + """A container for the shape, dtype, and other static attributes of an array. + + ``ShapeDtypeStruct`` is often used in conjunction with :func:`jax.eval_shape`. + + Args: + shape: a sequence of integers representing an array shape + dtype: a dtype-like object + named_shape: (optional) a dictionary representing a named shape + sharding: (optional) a :class:`jax.sharding.Sharding` object + """ __slots__ = ["shape", "dtype", "named_shape", "sharding"] def __init__(self, shape, dtype, named_shape=None, sharding=None): self.shape = tuple(shape) @@ -2764,20 +2774,9 @@ def eval_shape(fun: Callable, *args, **kwargs): def eval_shape(fun, *args, **kwargs): out = fun(*args, **kwargs) + shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype) return jax.tree_util.tree_map(shape_dtype_struct, out) - def shape_dtype_struct(x): - return ShapeDtypeStruct(x.shape, x.dtype) - - class ShapeDtypeStruct: - __slots__ = ["shape", "dtype"] - def __init__(self, shape, dtype): - self.shape = shape - self.dtype = dtype - - In particular, the output is a pytree of objects that have ``shape`` and - ``dtype`` attributes, but nothing else about them is guaranteed by the API. - But instead of applying ``fun`` directly, which might be expensive, it uses JAX's abstract interpretation machinery to evaluate the shapes without doing any FLOPs. 
@@ -2790,26 +2789,24 @@ def eval_shape(fun: Callable, *args, **kwargs): *args: a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the ``shape`` and ``dtype`` attributes are - accessed, only values that duck-type arrays are required, rather than real - ndarrays. The duck-typed objects cannot be namedtuples because those are - treated as standard Python containers. See the example below. + accessed, one can use :class:`jax.ShapeDtypeStruct` or another container + that duck-types as an ndarray (note, however, that duck-typed objects cannot + be namedtuples because those are treated as standard Python containers). **kwargs: a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in ``args``, array values need only be duck-typed to have ``shape`` and ``dtype`` attributes. + Returns: + out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves. + For example: >>> import jax >>> import jax.numpy as jnp >>> >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x)) - >>> class MyArgArray(object): - ... def __init__(self, shape, dtype): - ... self.shape = shape - ... self.dtype = jnp.dtype(dtype) - ... - >>> A = MyArgArray((2000, 3000), jnp.float32) - >>> x = MyArgArray((3000, 1000), jnp.float32) + >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32) + >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32) >>> out = jax.eval_shape(f, A, x) # no FLOPs performed >>> print(out.shape) (2000, 1000)