mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fix some numpy 2.0 incompatibilities
This commit is contained in:
parent
9ca133b5da
commit
4edb74ba7b
@ -3683,7 +3683,7 @@ def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr,
|
||||
raise NotImplementedError # loop and stack
|
||||
|
||||
def _reduce_jvp(reducer, init_values, primals, tangents, axes):
|
||||
input_shape = np.array(primals[0].shape, dtype=np.int_)
|
||||
input_shape = np.array(primals[0].shape, dtype=int)
|
||||
|
||||
n = np.prod(input_shape[list(axes)])
|
||||
non_axes = np.delete(np.arange(len(input_shape)), axes)
|
||||
|
@ -2729,7 +2729,7 @@ def trace(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1,
|
||||
if dtype is None:
|
||||
dtype = _dtype(a)
|
||||
if issubdtype(dtype, integer):
|
||||
default_int = dtypes.canonicalize_dtype(np.int_)
|
||||
default_int = dtypes.canonicalize_dtype(int)
|
||||
if iinfo(dtype).bits < iinfo(default_int).bits:
|
||||
dtype = default_int
|
||||
|
||||
|
@ -514,7 +514,7 @@ def count_nonzero(a: ArrayLike, axis: Axis = None,
|
||||
keepdims: bool = False) -> Array:
|
||||
check_arraylike("count_nonzero", a)
|
||||
return sum(lax.ne(a, _lax_const(a, 0)), axis=axis,
|
||||
dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)
|
||||
dtype=dtypes.canonicalize_dtype(int), keepdims=keepdims)
|
||||
|
||||
|
||||
def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
|
||||
|
@ -404,7 +404,7 @@ class TestPromotionTables(jtu.JaxTestCase):
|
||||
self.assertEqual(result1.aval, result2.aval)
|
||||
|
||||
def testResultTypeNone(self):
|
||||
# This matches the behavior of np.result_type(None) => np.float_
|
||||
# This matches the behavior of np.result_type(None) => np.float64
|
||||
self.assertEqual(dtypes.result_type(None), dtypes.canonicalize_dtype(dtypes.float_))
|
||||
|
||||
def testResultTypeWeakFlag(self):
|
||||
|
@ -29,7 +29,7 @@ config.parse_flags_with_absl()
|
||||
|
||||
def random_inputs(rng, input_shape):
|
||||
if type(input_shape) is tuple:
|
||||
return rng.randn(*input_shape).astype(dtypes.canonicalize_dtype(np.float_))
|
||||
return rng.randn(*input_shape).astype(dtypes.canonicalize_dtype(float))
|
||||
elif type(input_shape) is list:
|
||||
return [random_inputs(rng, shape) for shape in input_shape]
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user