Fix some numpy 2.0 incompatibilities

This commit is contained in:
Jake VanderPlas 2023-09-21 10:24:52 -07:00
parent 9ca133b5da
commit 4edb74ba7b
5 changed files with 5 additions and 5 deletions

View File

@ -3683,7 +3683,7 @@ def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
raise NotImplementedError # loop and stack
def _reduce_jvp(reducer, init_values, primals, tangents, axes):
input_shape = np.array(primals[0].shape, dtype=np.int_)
input_shape = np.array(primals[0].shape, dtype=int)
n = np.prod(input_shape[list(axes)])
non_axes = np.delete(np.arange(len(input_shape)), axes)

View File

@ -2729,7 +2729,7 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
if dtype is None:
dtype = _dtype(a)
if issubdtype(dtype, integer):
default_int = dtypes.canonicalize_dtype(np.int_)
default_int = dtypes.canonicalize_dtype(int)
if iinfo(dtype).bits < iinfo(default_int).bits:
dtype = default_int

View File

@ -514,7 +514,7 @@ def count_nonzero(a: ArrayLike, axis: Axis = None,
keepdims: bool = False) -> Array:
check_arraylike("count_nonzero", a)
return sum(lax.ne(a, _lax_const(a, 0)), axis=axis,
dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)
dtype=dtypes.canonicalize_dtype(int), keepdims=keepdims)
def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],

View File

@ -404,7 +404,7 @@ class TestPromotionTables(jtu.JaxTestCase):
self.assertEqual(result1.aval, result2.aval)
def testResultTypeNone(self):
# This matches the behavior of np.result_type(None) => np.float_
# This matches the behavior of np.result_type(None) => np.float64
self.assertEqual(dtypes.result_type(None), dtypes.canonicalize_dtype(dtypes.float_))
def testResultTypeWeakFlag(self):

View File

@ -29,7 +29,7 @@ config.parse_flags_with_absl()
def random_inputs(rng, input_shape):
if type(input_shape) is tuple:
return rng.randn(*input_shape).astype(dtypes.canonicalize_dtype(np.float_))
return rng.randn(*input_shape).astype(dtypes.canonicalize_dtype(float))
elif type(input_shape) is list:
return [random_inputs(rng, shape) for shape in input_shape]
else: