mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #18000 from alhridoy:arange-precision-warning
PiperOrigin-RevId: 575328461
This commit is contained in:
commit
73a973eaa8
@ -2317,7 +2317,16 @@ def identity(n: DimSize, dtype: DTypeLike | None = None) -> Array:
|
||||
return eye(n, dtype=dtype)
|
||||
|
||||
|
||||
@util._wraps(np.arange)
|
||||
@util._wraps(np.arange,lax_description= """
|
||||
.. note::
|
||||
|
||||
Using ``arange`` with the ``step`` argument can lead to precision errors,
|
||||
especially with lower-precision data types like ``fp8`` and ``bf16``.
|
||||
For more details, see the docstring of :func:`numpy.arange`.
|
||||
To avoid precision errors, consider using an expression like
|
||||
``(jnp.arange(-600, 600) * .01).astype(jnp.bfloat16)`` to generate a sequence in a higher precision
|
||||
and then convert it to the desired lower precision.
|
||||
""")
|
||||
def arange(start: DimSize, stop: Optional[DimSize] = None,
|
||||
step: DimSize | None = None, dtype: DTypeLike | None = None) -> Array:
|
||||
dtypes.check_user_dtype_supported(dtype, "arange")
|
||||
|
Loading…
x
Reference in New Issue
Block a user