Add precision warning and workaround to jnp.arange documentation

This commit is contained in:
alhridoy 2023-10-20 15:01:24 -06:00
parent d631fb10fd
commit 63f7cfe04c

View File

@ -2317,7 +2317,16 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
return eye(n, dtype=dtype)
@util._wraps(np.arange)
@util._wraps(np.arange,lax_description= """
.. note::
Using ``arange`` with the ``step`` argument can lead to precision errors,
especially with lower-precision data types like ``fp8`` and ``bf16``.
For more details, see the docstring of :func:`numpy.arange`.
To avoid precision errors, consider using an expression like
``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision
and then convert it to the desired lower precision.
""")
def arange(start: DimSize, stop: Optional[DimSize] = None,
step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array:
dtypes.check_user_dtype_supported(dtype, "arange")