diff --git a/jax/typing.py b/jax/typing.py index c75e0567e..89efa1f2c 100644 --- a/jax/typing.py +++ b/jax/typing.py @@ -24,7 +24,7 @@ The currently-available types are: - :obj:`jax.typing.ArrayLike`: annotation for any value that is safe to implicitly cast to a JAX array; this includes :class:`jax.Array`, :class:`numpy.ndarray`, as well as Python builtin numeric values (e.g. :class:`int`, :class:`float`, etc.) and numpy scalar values - (e.g. :class:`numpy.int32`, :class:`numpy.flota64`, etc.) + (e.g. :class:`numpy.int32`, :class:`numpy.float64`, etc.) - :obj:`jax.typing.DTypeLike`: annotation for any value that can be cast to a JAX-compatible dtype; this includes strings (e.g. `'float32'`, `'int32'`), scalar types (e.g. `float`, `np.float32`), dtypes (e.g. `np.dtype('float32')`), or objects with a dtype attribute