mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Better docs for jnp.issubdtype & jnp.result_type
This commit is contained in:
parent
e2d3bd866a
commit
4495daee11
@ -255,7 +255,7 @@ array_repr = np.array_repr
|
||||
save = np.save
|
||||
savez = np.savez
|
||||
|
||||
@util.implements(np.dtype)
|
||||
|
||||
def _jnp_dtype(obj: DTypeLike | None, *, align: bool = False,
|
||||
copy: bool = False) -> DType:
|
||||
"""Similar to np.dtype, but respects JAX dtype defaults."""
|
||||
@ -436,8 +436,50 @@ def fmax(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
"""
|
||||
return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2)
|
||||
|
||||
@util.implements(np.issubdtype)
|
||||
|
||||
def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool:
|
||||
"""Return True if arg1 is equal or lower than arg2 in the type hierarchy.
|
||||
|
||||
JAX implementation of :func:`numpy.issubdtype`.
|
||||
|
||||
The main difference in JAX's implementation is that it properly handles
|
||||
dtype extensions such as :code:`bfloat16`.
|
||||
|
||||
Args:
|
||||
arg1: dtype-like object. In typical usage, this will be a dtype specifier,
|
||||
such as ``"float32"`` (i.e. a string), ``np.dtype('int32')`` (i.e. an
|
||||
instance of :class:`numpy.dtype`), ``jnp.complex64`` (i.e. a JAX scalar
|
||||
constructor), or ``np.uint8`` (i.e. a NumPy scalar type).
|
||||
arg2: dtype-like object. In typical usage, this will be a generic scalar
|
||||
type, such as ``jnp.integer``, ``jnp.floating``, or ``jnp.complexfloating``.
|
||||
|
||||
Returns:
|
||||
True if arg1 represents a dtype that is equal or lower in the type
|
||||
hierarchy than arg2.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.isdtype`: similar function aligning with the array API standard.
|
||||
|
||||
Examples:
|
||||
>>> jnp.issubdtype('uint32', jnp.unsignedinteger)
|
||||
True
|
||||
>>> jnp.issubdtype(np.int32, jnp.integer)
|
||||
True
|
||||
>>> jnp.issubdtype(jnp.bfloat16, jnp.floating)
|
||||
True
|
||||
>>> jnp.issubdtype(np.dtype('complex64'), jnp.complexfloating)
|
||||
True
|
||||
>>> jnp.issubdtype('complex64', jnp.integer)
|
||||
False
|
||||
|
||||
Be aware that while this is very similar to :func:`numpy.issubdtype`, the
|
||||
results of these differ in the case of JAX's custom floating point types:
|
||||
|
||||
>>> np.issubdtype('bfloat16', np.floating)
|
||||
False
|
||||
>>> jnp.issubdtype('bfloat16', jnp.floating)
|
||||
True
|
||||
"""
|
||||
return dtypes.issubdtype(arg1, arg2)
|
||||
|
||||
@util.implements(np.isscalar)
|
||||
@ -448,8 +490,47 @@ def isscalar(element: Any) -> bool:
|
||||
|
||||
iterable = np.iterable
|
||||
|
||||
@util.implements(np.result_type)
|
||||
|
||||
def result_type(*args: Any) -> DType:
|
||||
"""Return the result of applying JAX promotion rules to the inputs.
|
||||
|
||||
JAX implementation of :func:`numpy.result_type`.
|
||||
|
||||
JAX's dtype promotion behavior is described in :ref:`type-promotion`.
|
||||
|
||||
Args:
|
||||
args: one or more arrays or dtype-like objects.
|
||||
|
||||
Returns:
|
||||
A :class:`numpy.dtype` instance representing the result of type
|
||||
promotion for the inputs.
|
||||
|
||||
Examples:
|
||||
Inputs can be dtype specifiers:
|
||||
|
||||
>>> jnp.result_type('int32', 'float32')
|
||||
dtype('float32')
|
||||
>>> jnp.result_type(np.uint16, np.dtype('int32'))
|
||||
dtype('int32')
|
||||
|
||||
Inputs may also be scalars or arrays:
|
||||
|
||||
>>> jnp.result_type(1.0, jnp.bfloat16(2))
|
||||
dtype(bfloat16)
|
||||
>>> jnp.result_type(jnp.arange(4), jnp.zeros(4))
|
||||
dtype('float32')
|
||||
|
||||
Be aware that the result type will be canonicalized based on the state
|
||||
of the ``jax_enable_x64`` configuration flag, meaning that 64-bit types
|
||||
may be downcast to 32-bit:
|
||||
|
||||
>>> jnp.result_type('float64')
|
||||
dtype('float32')
|
||||
|
||||
For details on 64-bit values, refer to `Sharp bits - double precision`_:
|
||||
|
||||
.. _Sharp bits - double precision: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
|
||||
"""
|
||||
return dtypes.result_type(*args)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user