mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Roll back #23404, because it incorrectly casts numpy scalars to float
when dtype = None
For example: ``` >>> dtypes.coerce_to_array(np.complex64(1+1j)) jax/_src/dtypes.py:323: ComplexWarning: Casting complex values to real discards the imaginary part return np.array(x).astype(dtype) array(1.) ``` Reverts 3672b633c30fe82ef94d6cb83889894bdda64295 PiperOrigin-RevId: 671439898
This commit is contained in:
parent
dba674153e
commit
65b1b0bd95
@ -320,7 +320,7 @@ def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray:
|
||||
"""
|
||||
if dtype is None and type(x) in python_scalar_dtypes:
|
||||
dtype = _scalar_type_to_dtype(type(x), x)
|
||||
return np.array(x).astype(dtype)
|
||||
return np.asarray(x, dtype)
|
||||
|
||||
iinfo = ml_dtypes.iinfo
|
||||
finfo = ml_dtypes.finfo
|
||||
|
Loading…
x
Reference in New Issue
Block a user