diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 82be38d1c..d2a55933c 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -651,7 +651,8 @@ def _least_upper_bound(jax_numpy_dtype_promotion: str, *nodes: JAXType) -> JAXTy def promote_types(a: DTypeLike, b: DTypeLike) -> DType: """Returns the type to which a binary operation should cast its arguments. - For details of JAX's type promotion semantics, see :ref:`type-promotion`. + JAX implementation of :func:`numpy.promote_types`. For details of JAX's + type promotion semantics, see :ref:`type-promotion`. Args: a: a :class:`numpy.dtype` or a dtype specifier. @@ -659,6 +660,35 @@ def promote_types(a: DTypeLike, b: DTypeLike) -> DType: Returns: A :class:`numpy.dtype` object. + + Examples: + Type specifiers may be strings, dtypes, or scalar types, and the return + value is always a dtype: + + >>> jnp.promote_types('int32', 'float32') # strings + dtype('float32') + >>> jnp.promote_types(jnp.dtype('int32'), jnp.dtype('float32')) # dtypes + dtype('float32') + >>> jnp.promote_types(jnp.int32, jnp.float32) # scalar types + dtype('float32') + + Built-in scalar types (:type:`int`, :type:`float`, or :type:`complex`) are + treated as weakly-typed and will not change the bit width of a strongly-typed + counterpart (see discussion in :ref:`type-promotion`): + + >>> jnp.promote_types('uint8', int) + dtype('uint8') + >>> jnp.promote_types('float16', float) + dtype('float16') + + This differs from the NumPy version of this function, which treats built-in scalar + types as equivalent to 64-bit types: + + >>> import numpy + >>> numpy.promote_types('uint8', int) + dtype('int64') + >>> numpy.promote_types('float16', float) + dtype('float64') """ # Note: we deliberately avoid `if a in _weak_types` here because we want to check # object identity, not object equality, due to the behavior of np.dtype.__eq__ diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ffdeca84a..ee33be8a1 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -194,6 +194,12 @@ def _make_scalar_type(np_scalar_type: type) -> _ScalarMeta: meta = _ScalarMeta(np_scalar_type.__name__, (object,), {"dtype": np.dtype(np_scalar_type)}) meta.__module__ = _PUBLIC_MODULE_NAME + meta.__doc__ =\ + f"""A JAX scalar constructor of type {np_scalar_type.__name__}. + + While NumPy defines scalar types for each data type, JAX represents + scalars as zero-dimensional arrays. + """ return meta bool_ = _make_scalar_type(np.bool_) diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 3473e8a74..27e2973b2 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -598,5 +598,28 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, Returns: wrapped : jax.numpy.ufunc wrapper of func. + + Examples: + Here is an example of creating a ufunc similar to :obj:`jax.numpy.add`: + + >>> import operator + >>> add = frompyfunc(operator.add, nin=2, nout=1, identity=0) + + Now all the standard :class:`jax.numpy.ufunc` methods are available: + + >>> x = jnp.arange(4) + >>> add(x, 10) + Array([10, 11, 12, 13], dtype=int32) + >>> add.outer(x, x) + Array([[0, 1, 2, 3], + [1, 2, 3, 4], + [2, 3, 4, 5], + [3, 4, 5, 6]], dtype=int32) + >>> add.reduce(x) + Array(6, dtype=int32) + >>> add.accumulate(x) + Array([0, 1, 3, 6], dtype=int32) + >>> add.at(x, 1, 10, inplace=False) + Array([ 0, 11, 2, 3], dtype=int32) """ return ufunc(func, nin, nout, identity=identity)