mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Add some missing jax.numpy documentation
This commit is contained in:
parent
5a41093970
commit
adf1492843
@ -651,7 +651,8 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy
|
||||
def promote_types(a: DTypeLike, b: DTypeLike) -> DType:
|
||||
"""Returns the type to which a binary operation should cast its arguments.
|
||||
|
||||
For details of JAX's type promotion semantics, see :ref:`type-promotion`.
|
||||
JAX implementation of :func:`numpy.promote_types`. For details of JAX's
|
||||
type promotion semantics, see :ref:`type-promotion`.
|
||||
|
||||
Args:
|
||||
a: a :class:`numpy.dtype` or a dtype specifier.
|
||||
@ -659,6 +660,35 @@ def promote_types(a: DTypeLike, b: DTypeLike) -> DType:
|
||||
|
||||
Returns:
|
||||
A :class:`numpy.dtype` object.
|
||||
|
||||
Examples:
|
||||
Type specifiers may be strings, dtypes, or scalar types, and the return
|
||||
value is always a dtype:
|
||||
|
||||
>>> jnp.promote_types('int32', 'float32') # strings
|
||||
dtype('float32')
|
||||
>>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32')) # dtypes
|
||||
dtype('float32')
|
||||
>>> jnp.promote_types(jnp.int32, jnp.float32) # scalar types
|
||||
dtype('float32')
|
||||
|
||||
Built-in scalar types (:type:`int`, :type:`float`, or :type:`complex`) are
|
||||
treated as weakly-typed and will not change the bit width of a strongly-typed
|
||||
counterpart (see discussion in :ref:`type-promotion`):
|
||||
|
||||
>>> jnp.promote_types('uint8', int)
|
||||
dtype('uint8')
|
||||
>>> jnp.promote_types('float16', float)
|
||||
dtype('float16')
|
||||
|
||||
This differs from the NumPy version of this function, which treats built-in scalar
|
||||
types as equivalent to 64-bit types:
|
||||
|
||||
>>> import numpy
|
||||
>>> numpy.promote_types('uint8', int)
|
||||
dtype('int64')
|
||||
>>> numpy.promote_types('float16', float)
|
||||
dtype('float64')
|
||||
"""
|
||||
# Note: we deliberately avoid `if a in _weak_types` here because we want to check
|
||||
# object identity, not object equality, due to the behavior of np.dtype.__eq__
|
||||
|
@ -194,6 +194,12 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta:
|
||||
meta = _ScalarMeta(np_scalar_type.__name__, (object,),
|
||||
{"dtype": np.dtype(np_scalar_type)})
|
||||
meta.__module__ = _PUBLIC_MODULE_NAME
|
||||
meta.__doc__ =\
|
||||
f"""A JAX scalar constructor of type {np_scalar_type.__name__}.
|
||||
|
||||
While NumPy defines scalar types for each data type, JAX represents
|
||||
scalars as zero-dimensional arrays.
|
||||
"""
|
||||
return meta
|
||||
|
||||
bool_ = _make_scalar_type(np.bool_)
|
||||
|
@ -598,5 +598,28 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int,
|
||||
|
||||
Returns:
|
||||
wrapped : jax.numpy.ufunc wrapper of func.
|
||||
|
||||
Examples:
|
||||
Here is an example of creating a ufunc similar to :obj:`jax.numpy.add`:
|
||||
|
||||
>>> import operator
|
||||
>>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0)
|
||||
|
||||
Now all the standard :class:`jax.numpy.ufunc` methods are available:
|
||||
|
||||
>>> x = jnp.arange(4)
|
||||
>>> add(x, 10)
|
||||
Array([10, 11, 12, 13], dtype=int32)
|
||||
>>> add.outer(x, x)
|
||||
Array([[0, 1, 2, 3],
|
||||
[1, 2, 3, 4],
|
||||
[2, 3, 4, 5],
|
||||
[3, 4, 5, 6]], dtype=int32)
|
||||
>>> add.reduce(x)
|
||||
Array(6, dtype=int32)
|
||||
>>> add.accumulate(x)
|
||||
Array([0, 1, 3, 6], dtype=int32)
|
||||
>>> add.at(x, 1, 10, inplace=False)
|
||||
Array([ 0, 11, 2, 3], dtype=int32)
|
||||
"""
|
||||
return ufunc(func, nin, nout, identity=identity)
|
||||
|
Loading…
x
Reference in New Issue
Block a user