mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Avoid references to symbols removed in numpy 2.0
This commit is contained in:
parent
8a56a202e3
commit
505f03b40f
@ -222,16 +222,16 @@ def _jnp_dtype(obj: DTypeLike | None, *, align: bool = False,
|
||||
if obj is None:
|
||||
obj = dtypes.float_
|
||||
elif isinstance(obj, type) and obj in dtypes.python_scalar_dtypes:
|
||||
obj = _DEFAULT_TYPEMAP[np.dtype(obj, align=align, copy=copy).type]
|
||||
obj = _DEFAULT_TYPEMAP[obj]
|
||||
return np.dtype(obj, align=align, copy=copy)
|
||||
|
||||
### utility functions
|
||||
|
||||
_DEFAULT_TYPEMAP: dict[type, _ScalarMeta] = {
|
||||
np.bool_: bool_,
|
||||
np.int_: int_,
|
||||
np.float_: float_,
|
||||
np.complex_: complex_
|
||||
bool: bool_,
|
||||
int: int_,
|
||||
float: float_,
|
||||
complex: complex_,
|
||||
}
|
||||
|
||||
_lax_const = lax_internal._const
|
||||
|
Loading…
x
Reference in New Issue
Block a user