1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Roll back , because it incorrectly casts numpy scalars to float when dtype = None

For example:
```
>>> dtypes.coerce_to_array(np.complex64(1+1j))
jax/_src/dtypes.py:323: ComplexWarning: Casting complex values to real discards the imaginary part
  return np.array(x).astype(dtype)
array(1.)
```

Reverts 3672b633c30fe82ef94d6cb83889894bdda64295

PiperOrigin-RevId: 671439898
This commit is contained in:
Jake VanderPlas 2024-09-05 11:13:02 -07:00 committed by jax authors
parent dba674153e
commit 65b1b0bd95

@ -320,7 +320,7 @@ def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray:
"""
if dtype is None and type(x) in python_scalar_dtypes:
dtype = _scalar_type_to_dtype(type(x), x)
return np.array(x).astype(dtype)
return np.asarray(x, dtype)
iinfo = ml_dtypes.iinfo
finfo = ml_dtypes.finfo