Document ShapeDtypeStruct

This commit is contained in:
Jake VanderPlas 2023-03-21 13:53:20 -07:00
parent b5c9c0f47e
commit 4a9ed3eaa8
2 changed files with 20 additions and 22 deletions

View File

@ -69,6 +69,7 @@ Just-in-time compilation (:code:`jit`)
xla_computation
make_jaxpr
eval_shape
ShapeDtypeStruct
device_put
device_put_replicated
device_put_sharded

View File

@ -2704,6 +2704,16 @@ def device_get(x: Any):
class ShapeDtypeStruct:
"""A container for the shape, dtype, and other static attributes of an array.
``ShapeDtypeStruct`` is often used in conjunction with :func:`jax.eval_shape`.
Args:
shape: a sequence of integers representing an array shape
dtype: a dtype-like object
named_shape: (optional) a dictionary representing a named shape
sharding: (optional) a :class:`jax.Sharding` object
"""
__slots__ = ["shape", "dtype", "named_shape", "sharding"]
def __init__(self, shape, dtype, named_shape=None, sharding=None):
self.shape = tuple(shape)
@ -2764,20 +2774,9 @@ def eval_shape(fun: Callable, *args, **kwargs):
def eval_shape(fun, *args, **kwargs):
out = fun(*args, **kwargs)
shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
return jax.tree_util.tree_map(shape_dtype_struct, out)
def shape_dtype_struct(x):
return ShapeDtypeStruct(x.shape, x.dtype)
class ShapeDtypeStruct:
__slots__ = ["shape", "dtype"]
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
In particular, the output is a pytree of objects that have ``shape`` and
``dtype`` attributes, but nothing else about them is guaranteed by the API.
But instead of applying ``fun`` directly, which might be expensive, it uses
JAX's abstract interpretation machinery to evaluate the shapes without doing
any FLOPs.
@ -2790,26 +2789,24 @@ def eval_shape(fun: Callable, *args, **kwargs):
*args: a positional argument tuple of arrays, scalars, or (nested) standard
Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
those types. Since only the ``shape`` and ``dtype`` attributes are
accessed, only values that duck-type arrays are required, rather than real
ndarrays. The duck-typed objects cannot be namedtuples because those are
treated as standard Python containers. See the example below.
accessed, one can use :class:`jax.ShapeDtypeStruct` or another container
that duck-types as ndarrays (note however that duck-typed objects cannot
be namedtuples because those are treated as standard Python containers).
**kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
Python containers (pytrees) of those types. As in ``args``, array values
need only be duck-typed to have ``shape`` and ``dtype`` attributes.
Returns:
out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves.
For example:
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> class MyArgArray(object):
... def __init__(self, shape, dtype):
... self.shape = shape
... self.dtype = jnp.dtype(dtype)
...
>>> A = MyArgArray((2000, 3000), jnp.float32)
>>> x = MyArgArray((3000, 1000), jnp.float32)
>>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
>>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x) # no FLOPs performed
>>> print(out.shape)
(2000, 1000)