diff --git a/CHANGELOG.md b/CHANGELOG.md index 29110aca7..6c980dec7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list # jax 0.4.19 +* New Features + * Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that + are convertible to JAX dtypes. + * Changes * JAX now requires SciPy 1.9 or newer. diff --git a/docs/jax.typing.rst b/docs/jax.typing.rst index f3d76b37e..0d7c05da0 100644 --- a/docs/jax.typing.rst +++ b/docs/jax.typing.rst @@ -11,3 +11,4 @@ List of Members :toctree: _autosummary ArrayLike + DTypeLike diff --git a/jax/typing.py b/jax/typing.py index 816c2ee87..c75e0567e 100644 --- a/jax/typing.py +++ b/jax/typing.py @@ -21,10 +21,14 @@ The currently-available types are: - :class:`jax.Array`: annotation for any JAX array or tracer (i.e. representations of arrays within JAX transforms). -- :class:`jax.typing.ArrayLike`: annotation for any value that is safe to implicitly cast to +- :obj:`jax.typing.ArrayLike`: annotation for any value that is safe to implicitly cast to a JAX array; this includes :class:`jax.Array`, :class:`numpy.ndarray`, as well as Python builtin numeric values (e.g. :class:`int`, :class:`float`, etc.) and numpy scalar values (e.g. :class:`numpy.int32`, :class:`numpy.flota64`, etc.) +- :obj:`jax.typing.DTypeLike`: annotation for any value that can be cast to a JAX-compatible + dtype; this includes strings (e.g. `'float32'`, `'int32'`), scalar types (e.g. `float`, + `np.float32`), dtypes (e.g. `np.dtype('float32')`), or objects with a dtype attribute + (e.g. `jnp.float32`, `jnp.int32`). We may add additional types here in future releases. @@ -66,5 +70,6 @@ see `Non-array inputs NumPy vs JAX`_ .. _Non-array inputs NumPy vs JAX: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax """ from jax._src.typing import ( - ArrayLike as ArrayLike + ArrayLike as ArrayLike, + DTypeLike as DTypeLike, )