Define jax.typing.DTypeLike

This commit is contained in:
Jake VanderPlas 2023-10-10 08:46:36 -07:00
parent f9190006a6
commit 117f4bdf9b
3 changed files with 12 additions and 2 deletions

View File

@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list
# jax 0.4.19
* New Features
* Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that
are convertible to JAX dtypes.
* Changes
* JAX now requires SciPy 1.9 or newer.

View File

@ -11,3 +11,4 @@ List of Members
:toctree: _autosummary
ArrayLike
DTypeLike

View File

@ -21,10 +21,14 @@ The currently-available types are:
- :class:`jax.Array`: annotation for any JAX array or tracer (i.e. representations of arrays
within JAX transforms).
- :class:`jax.typing.ArrayLike`: annotation for any value that is safe to implicitly cast to
- :obj:`jax.typing.ArrayLike`: annotation for any value that is safe to implicitly cast to
a JAX array; this includes :class:`jax.Array`, :class:`numpy.ndarray`, as well as Python
builtin numeric values (e.g. :class:`int`, :class:`float`, etc.) and numpy scalar values
(e.g. :class:`numpy.int32`, :class:`numpy.flota64`, etc.)
- :obj:`jax.typing.DTypeLike`: annotation for any value that can be cast to a JAX-compatible
dtype; this includes strings (e.g. `'float32'`, `'int32'`), scalar types (e.g. `float`,
`np.float32`), dtypes (e.g. `np.dtype('float32')`), or objects with a dtype attribute
(e.g. `jnp.float32`, `jnp.int32`).
We may add additional types here in future releases.
@ -66,5 +70,6 @@ see `Non-array inputs NumPy vs JAX`_
.. _Non-array inputs NumPy vs JAX: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#non-array-inputs-numpy-vs-jax
"""
from jax._src.typing import (
ArrayLike as ArrayLike
ArrayLike as ArrayLike,
DTypeLike as DTypeLike,
)