mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Document ShapeDtypeStruct
This commit is contained in:
parent
b5c9c0f47e
commit
4a9ed3eaa8
@ -69,6 +69,7 @@ Just-in-time compilation (:code:`jit`)
|
||||
xla_computation
|
||||
make_jaxpr
|
||||
eval_shape
|
||||
ShapeDtypeStruct
|
||||
device_put
|
||||
device_put_replicated
|
||||
device_put_sharded
|
||||
|
@ -2704,6 +2704,16 @@ def device_get(x: Any):
|
||||
|
||||
|
||||
class ShapeDtypeStruct:
|
||||
"""A container for the shape, dtype, and other static attributes of an array.
|
||||
|
||||
``ShapeDtypeStruct`` is often used in conjunction with :func:`jax.eval_shape`.
|
||||
|
||||
Args:
|
||||
shape: a sequence of integers representing an array shape
|
||||
dtype: a dtype-like object
|
||||
named_shape: (optional) a dictionary representing a named shape
|
||||
sharding: (optional) a :class:`jax.Sharding` object
|
||||
"""
|
||||
__slots__ = ["shape", "dtype", "named_shape", "sharding"]
|
||||
def __init__(self, shape, dtype, named_shape=None, sharding=None):
|
||||
self.shape = tuple(shape)
|
||||
@ -2764,20 +2774,9 @@ def eval_shape(fun: Callable, *args, **kwargs):
|
||||
|
||||
def eval_shape(fun, *args, **kwargs):
|
||||
out = fun(*args, **kwargs)
|
||||
shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
|
||||
return jax.tree_util.tree_map(shape_dtype_struct, out)
|
||||
|
||||
def shape_dtype_struct(x):
|
||||
return ShapeDtypeStruct(x.shape, x.dtype)
|
||||
|
||||
class ShapeDtypeStruct:
|
||||
__slots__ = ["shape", "dtype"]
|
||||
def __init__(self, shape, dtype):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
In particular, the output is a pytree of objects that have ``shape`` and
|
||||
``dtype`` attributes, but nothing else about them is guaranteed by the API.
|
||||
|
||||
But instead of applying ``fun`` directly, which might be expensive, it uses
|
||||
JAX's abstract interpretation machinery to evaluate the shapes without doing
|
||||
any FLOPs.
|
||||
@ -2790,26 +2789,24 @@ def eval_shape(fun: Callable, *args, **kwargs):
|
||||
*args: a positional argument tuple of arrays, scalars, or (nested) standard
|
||||
Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
|
||||
those types. Since only the ``shape`` and ``dtype`` attributes are
|
||||
accessed, only values that duck-type arrays are required, rather than real
|
||||
ndarrays. The duck-typed objects cannot be namedtuples because those are
|
||||
treated as standard Python containers. See the example below.
|
||||
accessed, one can use :class:`jax.ShapeDtypeStruct` or another container
|
||||
that duck-types as ndarrays (note however that duck-typed objects cannot
|
||||
be namedtuples because those are treated as standard Python containers).
|
||||
**kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
|
||||
Python containers (pytrees) of those types. As in ``args``, array values
|
||||
need only be duck-typed to have ``shape`` and ``dtype`` attributes.
|
||||
|
||||
Returns:
|
||||
out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves.
|
||||
|
||||
For example:
|
||||
|
||||
>>> import jax
|
||||
>>> import jax.numpy as jnp
|
||||
>>>
|
||||
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
|
||||
>>> class MyArgArray(object):
|
||||
... def __init__(self, shape, dtype):
|
||||
... self.shape = shape
|
||||
... self.dtype = jnp.dtype(dtype)
|
||||
...
|
||||
>>> A = MyArgArray((2000, 3000), jnp.float32)
|
||||
>>> x = MyArgArray((3000, 1000), jnp.float32)
|
||||
>>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
|
||||
>>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
|
||||
>>> out = jax.eval_shape(f, A, x) # no FLOPs performed
|
||||
>>> print(out.shape)
|
||||
(2000, 1000)
|
||||
|
Loading…
x
Reference in New Issue
Block a user