Avoid references to symbols removed in numpy 2.0

This commit is contained in:
Jake VanderPlas 2023-09-19 11:50:21 -07:00
parent 8a56a202e3
commit 505f03b40f

View File

@ -222,16 +222,16 @@ def _jnp_dtype(obj: DTypeLike | None, *, align: bool = False,
if obj is None:
obj = dtypes.float_
elif isinstance(obj, type) and obj in dtypes.python_scalar_dtypes:
obj = _DEFAULT_TYPEMAP[np.dtype(obj, align=align, copy=copy).type]
obj = _DEFAULT_TYPEMAP[obj]
return np.dtype(obj, align=align, copy=copy)
### utility functions
_DEFAULT_TYPEMAP: dict[type, _ScalarMeta] = {
np.bool_: bool_,
np.int_: int_,
np.float_: float_,
np.complex_: complex_
bool: bool_,
int: int_,
float: float_,
complex: complex_,
}
_lax_const = lax_internal._const