diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index a65ce5268..b4a336f71 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -15,6 +15,7 @@ import enum import functools +import math import operator import google_benchmark @@ -54,7 +55,7 @@ def required_devices(num_devices_required): def create_mesh(shape, axis_names, state): - size = np.prod(shape) + size = math.prod(shape) if len(jax.devices()) < size: state.skip_with_error(f"Requires {size} devices") return None @@ -421,7 +422,7 @@ def sda_index_8(state): def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False): shape = (2000, 2000) nse = 10000 - size = np.prod(shape) + size = math.prod(shape) rng = np.random.RandomState(1701) data = rng.randn(nse) indices = np.unravel_index( @@ -460,7 +461,7 @@ def sparse_bcoo_fromdense_compile(state): def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False): shape = (2000, 2000) nse = 10000 - size = np.prod(shape) + size = math.prod(shape) rng = np.random.RandomState(1701) data = rng.randn(nse) indices = np.unravel_index( @@ -600,7 +601,7 @@ def bench_addressable_shards_index(state): if mesh is None: return shape = (8, 2) - inp = np.arange(np.prod(shape)).reshape(shape) + inp = np.arange(math.prod(shape)).reshape(shape) s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')) arr = jax.device_put(inp, s) @@ -614,7 +615,7 @@ def bench_addressable_shards_replica_id(state): if mesh is None: return shape = (64, 32) - inp = np.arange(np.prod(shape)).reshape(shape) + inp = np.arange(math.prod(shape)).reshape(shape) s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')) arr = jax.device_put(inp, s) @@ -785,7 +786,7 @@ def pjit_aot_4000_device(state): def host_local_array_to_global_array(state): global_mesh = create_mesh((4, 2), ('x', 'y'), state) input_shape = (8, 2) - input_data = np.arange(np.prod(input_shape)).reshape(input_shape) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) in_pspec = jax.sharding.PartitionSpec('x', 'y') while state: diff --git a/jax/_src/core.py b/jax/_src/core.py index d011dce09..129369bf4 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1869,8 +1869,8 @@ class DimensionHandler: Raise InconclusiveDimensionOperation if there is no such integer for all contexts, """ - sz1 = int(np.prod(s1)) - sz2 = int(np.prod(s2)) + sz1 = math.prod(s1) + sz2 = math.prod(s2) if sz1 == 0 and sz2 == 0: return 1 if sz1 % sz2: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 841346a9d..92564b1d9 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1624,10 +1624,10 @@ def manual_proto( tad_perm = ([axis_order[a] for a in replicated_axes] + [axis_order[a] for a in manual_axes]) tad_shape = [1] * aval.ndim - tad_shape.append(int(np.prod([named_mesh_shape[a] for a in replicated_axes], dtype=int))) - tad_shape.append(int(np.prod([named_mesh_shape[a] for a in manual_axes], dtype=int))) + tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes])) + tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes])) - raw_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape) + raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape) proto = xc.OpSharding() proto.type = xc.OpSharding.Type.OTHER proto.tile_assignment_dimensions = tad_shape diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index a7cd4e1a9..d0c35ddfa 100644 --- a/jax/_src/lax/linalg.py +++ 
b/jax/_src/lax/linalg.py @@ -1256,7 +1256,7 @@ mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu') @partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)') def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array: m = lu.shape[0] - x = jnp.reshape(b, (m, np.prod(b.shape[1:]))) + x = jnp.reshape(b, (m, math.prod(b.shape[1:]))) if trans == 0: x = x[permutation, :] x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 22ef5bf65..fb232cc78 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -14,6 +14,7 @@ import enum from functools import partial +import math from typing import Callable, List, NamedTuple, Optional, Sequence, Tuple, Union import weakref @@ -1891,7 +1892,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, # on the IDs. ids_shape = np.array(updates.shape, dtype=np.int64) ids_shape[dnums.update_window_dims,] = 1 - num_ids = np.prod(ids_shape) + num_ids = math.prod(ids_shape) id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64 update_ids = lax.add(lax.reshape(lax.iota(id_dtype, num_ids), ids_shape), lax._ones(updates, dtype=id_dtype)) diff --git a/jax/_src/maps.py b/jax/_src/maps.py index ce8ce498f..81d2b3303 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -19,6 +19,7 @@ from collections import OrderedDict, abc from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set, NamedTuple, Union, Sequence, Mapping) from functools import wraps, partial, partialmethod, lru_cache +import math from jax import lax from jax import numpy as jnp @@ -1573,9 +1574,9 @@ def _get_axis_resource_count( if local_res_shape is None: nlocal = None else: - nlocal = int(np.prod(map(local_res_shape.get, resources), dtype=np.int64)) + nlocal = math.prod(map(local_res_shape.get, resources)) resource_count_map[axis] = ResourceCount( - int(np.prod(map(global_res_shape.get, resources), dtype=np.int64)), + math.prod(map(global_res_shape.get, resources)), nlocal, distributed) return resource_count_map diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 21894ffe1..51865130d 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -18,6 +18,7 @@ from __future__ import annotations import collections import contextlib import functools +import math import threading from typing import Any, Hashable, NamedTuple, Set, Sequence, Tuple, Union @@ -199,7 +200,7 @@ class Mesh(contextlib.ContextDecorator): @property def size(self): - return np.prod(list(self.shape.values())) + return math.prod(self.shape.values()) @property def empty(self): diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index a7e10c70c..8a445b33e 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -163,15 +163,15 @@ def _compute_fans(shape: core.NamedShape, if isinstance(in_axis, int): in_size = shape[in_axis] else: - in_size = int(np.prod([shape[i] for i in in_axis])) + in_size = math.prod([shape[i] for i in in_axis]) if isinstance(out_axis, int): out_size = shape[out_axis] else: - out_size = int(np.prod([shape[i] for i in out_axis])) + out_size = math.prod([shape[i] for i in out_axis]) if isinstance(batch_axis, int): batch_size = shape[batch_axis] else: - batch_size = int(np.prod([shape[i] for i in batch_axis])) + batch_size = math.prod([shape[i] for i in batch_axis]) receptive_field_size = shape.total / in_size / out_size / batch_size fan_in = in_size * receptive_field_size fan_out = 
out_size * receptive_field_size diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 049b7a069..23761b323 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -747,7 +747,7 @@ def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]], dimensions[i], dimensions[s] = dimensions[s], dimensions[i] do_not_touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx not in axis) touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx in axis) - a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions) + a = lax.reshape(a, do_not_touch_shape + (math.prod(touch_shape),), dimensions) axis = _canonicalize_axis(-1, a.ndim) else: axis = _canonicalize_axis(axis, a.ndim) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index f47c28666..4ebc95422 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -1043,7 +1043,7 @@ def iota_2x32_shape(shape): Setting aside representation, this function essentially computes the equivalent of:: - jax.lax.iota(dtype=np.uint64, size=np.prod(shape)).reshape(shape) + jax.lax.iota(dtype=np.uint64, size=math.prod(shape)).reshape(shape) However: @@ -1069,7 +1069,7 @@ def iota_2x32_shape(shape): [ 8, 9, 10, 11]], dtype=uint32)] >>> def reshaped_iota(shape): - ... return lax.iota(size=np.prod(shape), dtype=np.uint32).reshape(shape) + ... return lax.iota(size=math.prod(shape), dtype=np.uint32).reshape(shape) ... >>> reshaped_iota((3, 4)) Array([[ 0, 1, 2, 3], diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index 6e7829dbc..70b9f6084 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import partial +import math import operator from typing import Callable, Optional, Tuple, Union, Sequence import warnings @@ -232,7 +233,7 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array], else: step = nperseg - noverlap batch_shape = list(batch_shape) - x = x.reshape((int(np.prod(batch_shape)), signal_length))[..., np.newaxis] + x = x.reshape((math.prod(batch_shape)), signal_length)[..., np.newaxis] result = jax.lax.conv_general_dilated_patches( x, (nperseg,), (step,), 'VALID', @@ -579,7 +580,7 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: raise ValueError('Input must have (..., frames, frame_length) shape.') *batch_shape, nframes, segment_len = x.shape - flat_batchsize = np.prod(batch_shape, dtype=np.int64) + flat_batchsize = math.prod(batch_shape) x = x.reshape((flat_batchsize, nframes, segment_len)) output_size = step_size * (nframes - 1) + segment_len nstep_per_segment = 1 + (segment_len - 1) // step_size diff --git a/jax/_src/sharding_specs.py b/jax/_src/sharding_specs.py index e4c76ff0f..4b352538c 100644 --- a/jax/_src/sharding_specs.py +++ b/jax/_src/sharding_specs.py @@ -75,7 +75,7 @@ def _sharding_spec_mesh_shape(self): def get_logical_mesh_ids(mesh_shape): - return np.arange(np.prod(mesh_shape)).reshape(mesh_shape) + return np.arange(math.prod(mesh_shape)).reshape(mesh_shape) _MeshAxisName = Any @@ -123,7 +123,7 @@ def sharding_spec_sharding_proto( assert mesh_shape[maxis] == nchunks mesh_permutation.append(maxis) next_sharded_axis += 1 - new_mesh_shape.append(int(np.prod(sharding.chunks))) + new_mesh_shape.append(math.prod(sharding.chunks)) elif isinstance(sharding, Unstacked): raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding") else: @@ -187,7 +187,7 @@ def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> 
np.ndarray: axis_indices.append(range(axis_size)) shard_indices_shape.append(axis_size) elif isinstance(sharding, Chunked): - total_chunks = int(np.prod(sharding.chunks)) + total_chunks = math.prod(sharding.chunks) shard_size, ragged = divmod(axis_size, total_chunks) assert not ragged, (axis_size, total_chunks, dim) axis_indices.append([slice(i * shard_size, (i + 1) * shard_size) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 742fc7f6f..6be7cd86b 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -487,7 +487,7 @@ def rand_fullrange(rng, standardize_nans=False): """Random numbers that span the full range of available bits.""" def gen(shape, dtype, post=lambda x: x): dtype = np.dtype(dtype) - size = dtype.itemsize * np.prod(_dims_of_shape(shape), dtype=int) + size = dtype.itemsize * math.prod(_dims_of_shape(shape)) vals = rng.randint(0, np.iinfo(np.uint8).max, size=size, dtype=np.uint8) vals = post(vals).view(dtype) if shape is PYTHON_SCALAR_SHAPE: diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 8cb396339..e711b5e16 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -498,6 +498,7 @@ import atexit import functools import itertools import logging +import math import threading import traceback from typing import (Any, Callable, Dict, List, Optional, Sequence, @@ -1380,7 +1381,7 @@ def _add_transform(params: Dict, name: str, *transform_params) -> Dict: def _aval_is_empty(aval) -> bool: - return np.prod(aval.shape) == 0 + return math.prod(aval.shape) == 0 def _instantiate_zeros(tan, arg): """Turn special ad.zero tangents into arrays of 0s for sending to host. diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index b0551ed5a..08e963d89 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -15,6 +15,7 @@ import builtins import dataclasses from functools import partial, wraps +import math import string from typing import Any, Callable, Dict, Optional, Sequence, Tuple @@ -665,7 +666,7 @@ def _reduce_monoid(operand, window_dimensions, window_strides, padding, # TODO(marcvanzee): This may give very large deviations on TPU when using # floats as inputs. Alternatively, we could implement this using a # convolution with an all-1's kernel. - return tf.multiply(tf_pool(operand, "AVG"), np.prod(window_dimensions)) + return tf.multiply(tf_pool(operand, "AVG"), math.prod(window_dimensions)) def _reduce_window(*args, jaxpr, consts, window_dimensions, @@ -951,7 +952,7 @@ def _gather_generate_indices(shape: Tuple[int, ...]): """ Returns the indices of the according to `shape`: each element in the output is the index of an element of an array - of the provided shape. The result's shape is (np.prod(shape), len(shape)) + of the provided shape. The result's shape is (math.prod(shape), len(shape)) For example, given shape (2,2) it returns (0,0),(0,1),(1,0),(1,1) """ diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 21440bb1a..ad5c46408 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -14,6 +14,7 @@ """Provides JAX and TensorFlow interoperation APIs.""" from functools import partial import contextlib +import math import operator import os import re @@ -3094,7 +3095,7 @@ def split_to_logical_devices(tensor: TfVal, # of _shard_values. 
if partition_dimensions is None: return xla_sharding.replicate(tensor, use_sharding_op=True) - num_partition_splits = np.prod(partition_dimensions) + num_partition_splits = math.prod(partition_dimensions) tile_assignment = np.arange(num_partition_splits).reshape( partition_dimensions) return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py index aa2d93400..4f2c9d92b 100644 --- a/jax/experimental/jax2tf/shape_poly.py +++ b/jax/experimental/jax2tf/shape_poly.py @@ -304,7 +304,7 @@ class _DimMon(dict): assert a_l <= a_u bounds.append((a_l ** exp, a_u ** exp)) - candidates = [np.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)] + candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)] return (min(*candidates), max(*candidates)) # type: ignore @@ -696,8 +696,8 @@ class DimensionHandlerPoly(core.DimensionHandler): return _ensure_poly(d1, "ge") >= d2 def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize: - sz1 = np.prod(s1) - sz2 = np.prod(s2) + sz1 = math.prod(s1) + sz2 = math.prod(s2) if core.symbolic_equal_dim(sz1, sz2): # Takes care also of sz1 == sz2 == 0 return 1 err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}" diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index 59be029b7..af4118039 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -62,6 +62,7 @@ import dataclasses import datetime from functools import partial import itertools +import math import os import re import sys @@ -346,7 +347,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( @staticmethod def eigh_input(shape, dtype): # In order to keep inputs small, we construct the input programmatically - operand = jnp.reshape(jnp.arange(np.prod(shape), dtype=dtype), shape) + operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape) # Make operand self-adjoint operand = (operand + jnp.conj(jnp.swapaxes(operand, -1, -2))) / 2. 
return operand @@ -417,7 +418,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( @staticmethod def qr_harness(shape, dtype): # In order to keep inputs small, we construct the input programmatically - operand = jnp.reshape(jnp.arange(np.prod(shape), dtype=dtype), shape) + operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape) return lax.linalg.qr(operand, full_matrices=True) @parameterized.named_parameters( @@ -458,7 +459,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( @staticmethod def lu_harness(shape, dtype): - operand = jnp.reshape(jnp.arange(np.prod(shape), dtype=dtype), shape) + operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape) return lax.linalg.lu(operand) def test_tpu_Lu(self): diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 4c2682f0c..41a094f68 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -16,6 +16,7 @@ Specific JAX primitive conversion tests are in primitives_test.""" import collections import contextlib +import math import os import re from typing import Callable, Dict, Optional, Tuple @@ -1286,7 +1287,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase): return carry shape = (3, 2) - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) jax_comp = jax.xla_computation(f_while)(x) backend = jax._src.xla_bridge.get_backend() @@ -1723,7 +1724,7 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase): return x + y shape = (8, 10) - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) in_axis_resources = (P("x"), P("x")) out_axis_resources = None res_jax = pjit.pjit( diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 619dc1496..cf5687266 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for the shape-polymorphic jax2tf conversion.""" import contextlib +import math import unittest from absl.testing import absltest, parameterized @@ -1280,7 +1281,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): # https://github.com/google/jax/issues/7093 # Also issue #6975. 
x_shape = (2, 3, 4) - xi = np.arange(np.prod(x_shape), dtype=np.int16).reshape(x_shape) + xi = np.arange(math.prod(x_shape), dtype=np.int16).reshape(x_shape) yf = xi.astype(np.float32) xi_yf = (xi, yf) zb = np.array([True, False], dtype=np.bool_) @@ -1350,7 +1351,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"]) f_tf = tf.function(f_tf, autograph=False) x_shape = (2, 3, 4) - x = np.arange(np.prod(x_shape), dtype=np.int32).reshape(x_shape) + x = np.arange(math.prod(x_shape), dtype=np.int32).reshape(x_shape) # When saving the model with gradients, we trace the gradient function # and we used to get an error when creating zeros_like_aval for a @@ -1378,7 +1379,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase): jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],)), polymorphic_shapes=["(b, 4)"])(np.ones((3, 4))) - jax2tf.convert(lambda x: jnp.reshape(x, (np.prod(x.shape),)), + jax2tf.convert(lambda x: jnp.reshape(x, (math.prod(x.shape),)), polymorphic_shapes=["(b, 4)"])(np.ones((3, 4))) jax2tf.convert(lambda x: x + x.shape[0] + jnp.sin(x.shape[0]), diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index 693cf35fb..df8c8f3d5 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -628,7 +628,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash") mesh = Mesh(self.devices, axis_names=('x')) - a = np.arange(np.prod(4 * 4), dtype=np.float32).reshape((4, 4)) + a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4)) @partial(pjit.pjit, in_shardings=(P('x', None),), out_shardings=P(None, 'x')) @@ -680,7 +680,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): return lax.ppermute(b, 'x', perm=perm) with mesh: - a = np.arange(np.prod(4 * 4)).reshape((4, 4)) + a = np.arange(4 * 4).reshape((4, 4)) res_jax = f_jax(a) b0, b1 = np.split(a, 2, axis=0) # The shard_map splits on axis 0 b0, b1 = b1, b0 @@ -705,7 +705,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase): if poly is not None: raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement") mesh = Mesh(self.devices, axis_names=('x')) - a = np.arange(np.prod(4 * 4), dtype=np.float32).reshape((4, 4)) + a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4)) @partial(pjit.pjit, in_shardings=(P('x', None),), out_shardings=P('x', None)) diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 573fe6d86..adf77d787 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -1575,7 +1575,7 @@ def bcoo_update_layout(mat: BCOO, *, n_batch: Optional[int] = None, n_dense: Opt n = current_n_dense - n_dense @partial(nfold_vmap, N=current_n_batch + 1) def _update(d, i): - new_d = d.reshape(np.prod(d.shape[:n]), *d.shape[n:]) + new_d = d.reshape(math.prod(d.shape[:n]), *d.shape[n:]) meshes = jnp.meshgrid(*(jnp.arange(s, dtype=i.dtype) for s in d.shape[:n]), indexing='ij') new_i = jnp.column_stack([jnp.broadcast_to(i, (new_d.shape[0], i.size)), @@ -1583,10 +1583,10 @@ def bcoo_update_layout(mat: BCOO, *, n_batch: Optional[int] = None, n_dense: Opt return new_d, new_i new_data, new_indices = _update(new_data, new_indices) new_data = new_data.reshape(*new_data.shape[:current_n_batch], - np.prod(new_data.shape[current_n_batch:current_n_batch + 2]), + math.prod(new_data.shape[current_n_batch:current_n_batch + 2]), 
*new_data.shape[current_n_batch + 2:]) new_indices = new_indices.reshape(*new_indices.shape[:current_n_batch], - np.prod(new_indices.shape[current_n_batch: current_n_batch + 2]), + math.prod(new_indices.shape[current_n_batch: current_n_batch + 2]), *new_indices.shape[current_n_batch + 2:]) current_n_dense = n_dense @@ -1595,10 +1595,10 @@ def bcoo_update_layout(mat: BCOO, *, n_batch: Optional[int] = None, n_dense: Opt @partial(nfold_vmap, N=n_batch) def _update(d, i): nse = i.shape[-2] - new_d = d.reshape(np.prod(d.shape[:n + 1]), *d.shape[n + 1:]) + new_d = d.reshape(math.prod(d.shape[:n + 1]), *d.shape[n + 1:]) meshes = jnp.meshgrid(*(jnp.arange(d, dtype=i.dtype) for d in (*i.shape[:n], nse)), indexing='ij') - new_i = i.reshape(np.prod(i.shape[:n + 1]), *i.shape[n + 1:]) + new_i = i.reshape(math.prod(i.shape[:n + 1]), *i.shape[n + 1:]) new_i = jnp.column_stack((*(m.ravel() for m in meshes[:-1]), new_i)) return new_d, new_i new_data, new_indices = _update(new_data, new_indices) @@ -1669,7 +1669,7 @@ def _bcoo_broadcast_in_dim(data: Array, indices: Array, *, spinfo: SparseInfo, s new_n_dense = props.n_dense and len(shape) - min(broadcast_dimensions[-props.n_dense:]) new_n_sparse = len(shape) - new_n_batch - new_n_dense - if np.prod(spinfo.shape[props.n_batch: props.n_batch + props.n_sparse]) != np.prod(shape[new_n_batch:new_n_batch + new_n_sparse]): + if math.prod(spinfo.shape[props.n_batch: props.n_batch + props.n_sparse]) != math.prod(shape[new_n_batch:new_n_batch + new_n_sparse]): raise NotImplementedError("Adding sparse dimensions with lengths != 1") new_data, new_indices = data, indices @@ -1784,8 +1784,8 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[in batch_shape, sparse_shape, dense_shape = split_list(mat.shape, [mat.n_batch, mat.n_sparse]) batch_perm, sparse_perm, dense_perm = _validate_permutation( mat.data, mat.indices, dimensions or tuple(range(mat.ndim)), mat.shape) - batch_size = np.prod(batch_shape, dtype=int) - sparse_size = np.prod(sparse_shape, dtype=int) + batch_size = math.prod(batch_shape) + sparse_size = math.prod(sparse_shape) cuml_shape = np.cumprod(new_sizes) if batch_size != 1 and batch_size not in cuml_shape: @@ -1797,7 +1797,7 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[in i1 = cuml_shape.searchsorted(batch_size, side='right') i2 = cuml_shape.searchsorted(batch_size * sparse_size, side='right') - new_batch_shape, new_sparse_shape, new_dense_shape = split_list(new_sizes, [i1, i2]) + new_batch_shape, new_sparse_shape, new_dense_shape = split_list(new_sizes, [int(i1), int(i2)]) # Reshape batch & dense dimensions: this is accomplished via a standard reshape. 
data = lax.reshape( @@ -1949,7 +1949,7 @@ def bcoo_slice(mat: BCOO, *, start_indices: Sequence[int], limit_indices: Sequen keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense)) new_data = jnp.where(keep_data, new_data, 0) - new_nse = int(np.prod(new_shape_sparse)) + new_nse = math.prod(new_shape_sparse) if mat.nse > new_nse: new_data, new_indices = _bcoo_sum_duplicates( new_data, new_indices, spinfo=SparseInfo(shape=new_shape), nse=new_nse) @@ -2029,8 +2029,8 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense)) new_data = jnp.where(keep_data, new_data, 0) - if mat.nse > np.prod(size_sparse): - new_nse = int(np.prod(size_sparse)) + if mat.nse > math.prod(size_sparse): + new_nse = math.prod(size_sparse) new_data, new_indices = _bcoo_sum_duplicates( new_data, new_indices, spinfo=SparseInfo(shape=new_shape), nse=new_nse) @@ -2099,7 +2099,7 @@ def _bcoo_reduce_sum(data: Array, indices: Array, *, spinfo: SparseInfo, axes: S new_batch_dims = tuple(sorted(set(range(n_batch)) - batch_axes)) new_batch_shape = tuple(data.shape[i] for i in new_batch_dims) - new_nse = int(nse * np.prod([data.shape[i] for i in batch_axes])) + new_nse = nse * math.prod([data.shape[i] for i in batch_axes]) data = lax.reshape(data, (*new_batch_shape, new_nse, *data.shape[n_batch + 1:]), diff --git a/tests/array_test.py b/tests/array_test.py index a373198a9..e4717aac9 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -908,7 +908,7 @@ class ShardingTest(jtu.JaxTestCase): self.skipTest('Test needs >= 4 devices.') ps = jax.sharding.PmapSharding.default(shape, sharded_dim) - inp = jnp.arange(np.prod(shape)).reshape(shape) + inp = jnp.arange(math.prod(shape)).reshape(shape) compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile() pmap_in_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index a9783f716..7be693a65 100644 --- a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -13,6 +13,7 @@ # limitations under the License. from absl.testing import absltest +import math import unittest import numpy as np @@ -206,7 +207,7 @@ mlir.register_lowering( def make_sparse_array(rng, shape, dtype, nnz=0.2): mat = rng(shape, dtype) - size = int(np.prod(shape)) + size = math.prod(shape) if 0 < nnz < 1: nnz = nnz * size nnz = int(nnz) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 58d55c240..3023e6164 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3805,16 +3805,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis): rng = jtu.rand_default(self.rng()) - i_shape = np.array(i_shape) + i_shape = list(i_shape) if axis is None: - i_shape = [np.prod(i_shape, dtype=np.int64)] + i_shape = [math.prod(i_shape)] else: # Test the case where the size of the axis doesn't necessarily broadcast. 
i_shape[axis] *= 3 - i_shape = list(i_shape) def args_maker(): x = rng(x_shape, dtype) - n = np.prod(x_shape, dtype=np.int32) if axis is None else x_shape[axis] + n = math.prod(x_shape) if axis is None else x_shape[axis] if np.issubdtype(index_dtype, np.unsignedinteger): index_rng = jtu.rand_int(self.rng(), 0, n) else: diff --git a/tests/lax_test.py b/tests/lax_test.py index 25c6b68ea..2fbfdddf0 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -641,7 +641,7 @@ class LaxTest(jtu.JaxTestCase): if c == 'N': self.assertEqual(out_c, patch_c) elif c == 'C': - self.assertEqual(out_c * np.prod(filter_shape), patch_c) + self.assertEqual(out_c * math.prod(filter_shape), patch_c) else: self.assertEqual(out_c, patch_c * filter_shape[filter_spec.index(c)]) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 7becda639..7e31f5f70 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -226,8 +226,8 @@ class PJitTest(jtu.BufferDonationTestCase): x_shape = (8, 6, 4) y_shape = (4, 2) - x = jnp.arange(np.prod(x_shape)).reshape(x_shape) - y = jnp.arange(np.prod(y_shape)).reshape(y_shape) + x = jnp.arange(math.prod(x_shape)).reshape(x_shape) + y = jnp.arange(math.prod(y_shape)).reshape(y_shape) actual = f(x, y) expected = x @ y self.assertAllClose(actual, expected, check_dtypes=False) @@ -285,7 +285,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) actual = f(x, x + 1) expected = x @ (x + 1) self.assertAllClose(actual, expected, check_dtypes=False) @@ -753,7 +753,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x + y + z + w - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) y = x * 2. z = x * 3. w = x * 4. 
@@ -822,7 +822,7 @@ class PJitTest(jtu.BufferDonationTestCase): token = lax.outfeed(token, x, partitions=(P(1, nr_devices),)) return x - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) def _dispatch(): with jax.sharding.Mesh(devices, ['d']): @@ -863,7 +863,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) expected = x @ (x + 1) lowered = f.lower(x, x + 1) @@ -896,7 +896,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) exe = f.lower(x, x + 1, a=1, b=2).compile() out = exe(x, x + 1, a=1, b=2) self.assertArraysEqual(out, x @ (x + 1)) @@ -910,7 +910,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) exe = f.lower(x, x + 1).compile() self.assertRaisesRegex( @@ -926,7 +926,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) x_f32 = x.astype(jnp.float32) x_i32 = x.astype(jnp.int32) exe = f.lower(x_f32, x_f32).compile() @@ -947,7 +947,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) f = f.lower(x, x + 1) self.assertIsInstance(f.as_text(), str) self.assertIsInstance(f.as_text(dialect='hlo'), str) @@ -963,7 +963,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) f = f.lower(x, x + 1) self.assertIsNotNone(f.compiler_ir()) self.assertIsNotNone(f.compiler_ir(dialect='hlo')) @@ -995,7 +995,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) f = f.lower(x, x + 1).compile() self.assertIsNotNone(f.compiler_ir()) @@ -1008,7 +1008,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) f = f.lower(x, x + 1).compile() self.assertIsInstance(f.as_text(), (str, type(None))) @@ -1022,7 +1022,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) f = f.lower(x, x + 1) f.cost_analysis() # doesn't raise @@ -1036,7 +1036,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) f = f.lower(x, x + 1).compile() f.cost_analysis() # doesn't raise @@ -1050,7 +1050,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) f = f.lower(x, x + 1).compile() f.memory_analysis() # doesn't raise @@ -1063,7 +1063,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x @ y shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) f = f.lower(x, x + 1).compile() 
self.assertIsNotNone(f.runtime_executable()) @@ -1088,7 +1088,7 @@ class PJitTest(jtu.BufferDonationTestCase): shape = (8, 8) aval = core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64)) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) exe = f.lower(aval, x).compile() self.assertIsInstance(exe, stages.Compiled) self.assertArraysEqual(exe(x, x), x @ x) @@ -2158,7 +2158,7 @@ class ArrayPjitTest(jtu.JaxTestCase): return x if c == 0 else x + 1 shape = (8, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) exe = f.lower(1, x).compile() self.assertAllClose(exe(x), x + 1, check_dtypes=False) @@ -3205,7 +3205,7 @@ class PJitErrorTest(jtu.JaxTestCase): def testNonDivisibleArgs(self, mesh, resources): x = jnp.ones((3, 2)) spec = P(resources, None) - mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64)) + mesh_size = str(math.prod([dim[1] for dim in mesh])) error = re.compile( r"One of pjit arguments.*" + spec_regex(spec) + r".*" r"implies that the global size of its dimension 0 should be " @@ -3218,7 +3218,7 @@ class PJitErrorTest(jtu.JaxTestCase): def testNonDivisibleOuts(self, mesh, resources): x = jnp.ones((3, 2)) spec = P(resources, None) - mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64)) + mesh_size = str(math.prod([dim[1] for dim in mesh])) error = re.compile( r"One of pjit outputs.*" + spec_regex(spec) + r".*" r"implies that the global size of its dimension 0 should be " @@ -3456,7 +3456,7 @@ class PJitErrorTest(jtu.JaxTestCase): return x return h(x) xshape = (2, 5, 6) - x = jnp.arange(np.prod(xshape)).reshape(xshape) + x = jnp.arange(math.prod(xshape)).reshape(xshape) with self.assertRaisesRegex(RuntimeError, "Changing the physical mesh is not allowed.*"): f(x) @@ -3506,7 +3506,7 @@ class UtilTest(jtu.JaxTestCase): FakeDevice = namedtuple('FakeDevice', ['id']) mesh_named_shape = OrderedDict([('a', 2), ('b', 3), ('c', 4), ('d', 7), ('e', 4)]) mesh_axes, mesh_shape = unzip2(mesh_named_shape.items()) - devices = [FakeDevice(i) for i in range(np.prod(list(mesh_shape)))] + devices = [FakeDevice(i) for i in range(math.prod(mesh_shape))] mesh = pxla.Mesh(np.array(devices).reshape(*mesh_shape), tuple(mesh_axes)) dims = 5 diff --git a/tests/pmap_test.py b/tests/pmap_test.py index c7256b4fd..8bd77713c 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -581,7 +581,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testAllToAll(self, split_axis, concat_axis): pmap_in_axis = 0 shape = (jax.device_count(),) * 3 - x = np.arange(np.prod(shape)).reshape(shape) + x = np.arange(math.prod(shape)).reshape(shape) @partial(self.pmap, axis_name='i') def f(x): @@ -602,7 +602,7 @@ class PythonPmapTest(jtu.JaxTestCase): raise SkipTest("test requires at least four devices") pmap_in_axis = 0 shape = (4, 4, 4) - x = np.arange(np.prod(shape)).reshape(shape) + x = np.arange(math.prod(shape)).reshape(shape) @partial(self.pmap, axis_name='i') @partial(self.pmap, axis_name='j') @@ -2400,7 +2400,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase): verify_ref() shape = (jax.device_count(),) * 5 - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) self.assertAllClose(pmap(vmap(f, in_axes=vmap_axis), axis_name='i')(x), reference(x, split_axis, concat_axis, vmap_axis)) @@ -2413,7 +2413,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase): return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis) shape = (jax.device_count(),) * 
4 - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) self.assertAllClose(pmap(f, axis_name='i')(x), vmap(f, axis_name='i')(x)) @@ -2431,7 +2431,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase): return lax.all_to_all(x, axes, split_axis=split_axis, concat_axis=concat_axis) shape = (2, 2, 4, 4, 4) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) self.assertAllClose(pmap(pmap(f, axis_name='j'), axis_name='i')(x), vmap(vmap(f, axis_name='j'), axis_name='i')(x)) @@ -2456,7 +2456,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase): if jax.device_count() < 4: raise SkipTest("test requires at least four devices") shape = (4, 4, 8) - x = jnp.arange(np.prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) f = partial(prim, axis_name='i', tiled=tiled) self.assertAllClose(vmap(f, axis_name='i')(x), pmap(f, axis_name='i')(x)) diff --git a/tests/random_test.py b/tests/random_test.py index 29f9749f6..e5615d471 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -14,6 +14,7 @@ from functools import partial +import math from unittest import SkipTest, skipIf from typing import Any, Tuple, NamedTuple, Optional import zlib @@ -689,7 +690,7 @@ class LaxRandomTest(jtu.JaxTestCase): for ndim in [1 if is_range else len(input_range_or_shape)] for axis in range(-ndim, ndim or 1) for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]] - if replace or np.prod(shape) <= ninputs + if replace or math.prod(shape) <= ninputs ], dtype=jtu.dtypes.floating + jtu.dtypes.integer, weighted=[True, False], @@ -702,7 +703,7 @@ class LaxRandomTest(jtu.JaxTestCase): key = self.seed_prng(0) is_range = type(input_range_or_shape) is int x = (input_range_or_shape if is_range else - self.rng().permutation(np.arange(np.prod( + self.rng().permutation(np.arange(math.prod( input_range_or_shape), dtype=dtype)).reshape(input_range_or_shape)) N = x if is_range else x.shape[axis] if weighted: @@ -718,7 +719,7 @@ class LaxRandomTest(jtu.JaxTestCase): self.assertEqual(np_shape, sample.shape) if not replace and shape: def lsort(x): - if not np.prod(x.shape): return x + if not math.prod(x.shape): return x ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis]))) return jnp.take(x, ind, axis) self.assertArraysEqual(lsort(sample), lsort(np.unique(sample, axis=axis))) @@ -741,7 +742,7 @@ class LaxRandomTest(jtu.JaxTestCase): is_range = type(range_or_shape) is int x = (range_or_shape if is_range else self.rng().permutation(np.arange( - np.prod(range_or_shape), dtype=dtype)).reshape(range_or_shape)) + math.prod(range_or_shape), dtype=dtype)).reshape(range_or_shape)) shape = ((range_or_shape,) if is_range else range_or_shape) x_ = np.copy(x) rand = lambda key, x: random.permutation(key, x, axis, independent=independent) @@ -750,7 +751,7 @@ class LaxRandomTest(jtu.JaxTestCase): self.assertFalse(np.all(perm == x)) # seems unlikely! 
arr = np.arange(x) if is_range else x def lsort(x): - if not np.prod(x.shape): return x + if not math.prod(x.shape): return x ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis]))) return jnp.take(x, ind, axis) if not independent: @@ -1602,7 +1603,7 @@ class KeyArrayTest(jtu.JaxTestCase): def make_keys(self, *shape, seed=None): if seed is None: seed = 28 - seeds = seed + jnp.arange(np.prod(shape), dtype=jnp.uint32) + seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32) make_key = partial(prng.seed_with_impl, prng.threefry_prng_impl) return jnp.reshape(jax.vmap(make_key)(seeds), shape) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 3014a952f..f03cc43de 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -15,6 +15,7 @@ import contextlib from functools import partial import itertools +import math import operator import random import unittest @@ -174,7 +175,7 @@ all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default): def _rand_sparse(shape, dtype, nse=nse): rand = rand_method(rng) - size = np.prod(shape).astype(int) + size = math.prod(shape) if 0 <= nse < 1: nse = nse * size nse = min(size, int(nse)) @@ -1554,7 +1555,7 @@ class BCOOTest(sptu.SparseTestCase): self.assertEqual(mat.n_dense, out.n_dense) # Unnecessary padding eliminated - max_nse = np.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse]) + max_nse = math.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse]) self.assertLessEqual(out.nse, max_nse) @jtu.sample_product( @@ -1589,7 +1590,7 @@ class BCOOTest(sptu.SparseTestCase): self.assertEqual(mat.n_dense, out.n_dense) # Unnecessary padding eliminated - max_nse = np.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse]) + max_nse = math.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse]) self.assertLessEqual(out.nse, max_nse) @jtu.sample_product( @@ -1644,7 +1645,7 @@ class BCOOTest(sptu.SparseTestCase): [dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, nse=nse) for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)] for layout in iter_sparse_layouts(shape) - for nse in [None, np.prod(shape) - 1] + for nse in [None, math.prod(shape) - 1] ], dtype=jtu.dtypes.floating + jtu.dtypes.complex, remove_zeros=[True, False], @@ -2106,7 +2107,7 @@ class BCOOTest(sptu.SparseTestCase): M = rng(shape, dtype) def make_bcoo(M): - return sparse_bcoo._bcoo_fromdense(M, nse=np.prod(M.shape[:-1], dtype=int), n_dense=1) + return sparse_bcoo._bcoo_fromdense(M, nse=math.prod(M.shape[:-1]), n_dense=1) todense = partial(sparse_bcoo._bcoo_todense, spinfo=sparse_util.SparseInfo(shape)) @@ -2586,7 +2587,7 @@ class SparseObjectTest(sptu.SparseTestCase): self.assertIsInstance(M, Obj) self.assertEqual(M.shape, shape) - self.assertEqual(M.size, np.prod(shape)) + self.assertEqual(M.size, math.prod(shape)) self.assertEqual(M.ndim, len(shape)) self.assertEqual(M.dtype, dtype) self.assertEqual(M.nse, (M.todense() != 0).sum()) @@ -2757,8 +2758,8 @@ class SparseRandomTest(sptu.SparseTestCase): batch_shape, sparse_shape, dense_shape = split_list(shape, [n_batch, n_sparse]) approx_expected_num_nonzero = ( - np.ceil(0.2 * np.prod(sparse_shape)) - * np.prod(batch_shape) * np.prod(dense_shape)) + np.ceil(0.2 * math.prod(sparse_shape)) + * math.prod(batch_shape) * math.prod(dense_shape)) num_nonzero = (mat_dense != 0).sum() self.assertAlmostEqual(int(num_nonzero), approx_expected_num_nonzero, delta=2) diff --git a/tests/sparsify_test.py 
b/tests/sparsify_test.py index ff506f04a..7bfcba038 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import partial +import math import operator from absl.testing import absltest @@ -35,7 +36,7 @@ config.parse_flags_with_absl() def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default): def _rand_sparse(shape, dtype, nse=nse): rand = rand_method(rng) - size = np.prod(shape).astype(int) + size = math.prod(shape) if 0 <= nse < 1: nse = nse * size nse = min(size, int(nse)) diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 9f9c6a7ab..9b6a62812 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -278,9 +278,9 @@ class XMapTest(XMapTestCase): out_axes=({0: 'a', 1: 'b'}, ['c', ...]), axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) ashape = (16, 8, 5) - a = jnp.arange(np.prod(ashape)).reshape(ashape) + a = jnp.arange(math.prod(ashape)).reshape(ashape) bshape = (2, 7) - b = jnp.arange(np.prod(bshape)).reshape(bshape) + b = jnp.arange(math.prod(bshape)).reshape(bshape) c, d = fm(a, b) self.assertAllClose(c, a * 2) self.assertAllClose(d, b * 4) @@ -292,9 +292,9 @@ class XMapTest(XMapTestCase): out_axes=(['b', ...], {0: 'c'}), axis_resources={'a': 'x', 'b': 'y', 'c': 'x'}) ashape = (16, 8, 5) - a = jnp.arange(np.prod(ashape)).reshape(ashape) + a = jnp.arange(math.prod(ashape)).reshape(ashape) bshape = (2, 7) - b = jnp.arange(np.prod(bshape)).reshape(bshape) + b = jnp.arange(math.prod(bshape)).reshape(bshape) c, d = fm(a, b) self.assertAllClose(c, (a * 2).sum(0)) self.assertAllClose(d, b * 4) @@ -330,7 +330,7 @@ class XMapTest(XMapTestCase): fm = xmap(f, in_axes=['a', ...], out_axes=({}, {1: 'a'}), axis_resources={'a': ('x', 'y')}) vshape = (4, 5) - v = jnp.arange(np.prod(vshape)).reshape(vshape) + v = jnp.arange(math.prod(vshape)).reshape(vshape) ans, ans2 = fm(v) self.assertAllClose(ans, (v * 2).sum(0)) self.assertAllClose(ans2, v.T * 4) @@ -344,7 +344,7 @@ class XMapTest(XMapTestCase): fyx = xmap(f, in_axes=['a', ...], out_axes={1: 'a'}, axis_resources={'a': ('y', 'x')}) vshape = (4, 5) - v = jnp.arange(np.prod(vshape)).reshape(vshape) + v = jnp.arange(math.prod(vshape)).reshape(vshape) zxy = fxy(v) zxy_op_sharding = zxy.sharding._to_xla_op_sharding(zxy.ndim) self.assertListEqual(zxy_op_sharding.tile_assignment_dimensions, [1, 4]) @@ -419,7 +419,7 @@ class XMapTest(XMapTestCase): return h(y) xshape = (4, 2, 5) - x = jnp.arange(np.prod(xshape), dtype=float).reshape(xshape) + x = jnp.arange(math.prod(xshape), dtype=float).reshape(xshape) y = f(x) self.assertAllClose( y, ((jnp.sin(x * 2) * @@ -825,7 +825,7 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest): in_axes=[None, 'a', 'b', ...], out_axes=(['a', 'b', ...], {}), axis_resources={'a': 'x', 'b': 'y'}) xshape = (8, 2, 4, 5) - x = jnp.arange(np.prod(xshape), dtype=float).reshape(xshape) + x = jnp.arange(math.prod(xshape), dtype=float).reshape(xshape) hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text() match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo) self.assertIsNot(match, None) @@ -1649,7 +1649,7 @@ class XMapErrorTest(jtu.JaxTestCase): return x return h(x) xshape = (2, 5, 6) - x = jnp.arange(np.prod(xshape)).reshape(xshape) + x = jnp.arange(math.prod(xshape)).reshape(xshape) with self.assertRaisesRegex(RuntimeError, "Changing the physical mesh is not allowed.*"): f(x)
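
Note on the substitution itself: math.prod (Python 3.8+) accepts any iterable and returns an exact Python int, whereas np.prod converts its argument to a NumPy array, computes in a fixed-width dtype (so it can wrap silently on overflow), and returns the float 1.0 for an empty shape, which is why the patch also drops the surrounding int(...) and dtype=... casts. Below is a minimal, self-contained sketch of that difference; it uses plain Python and NumPy only and is not part of any of the modules touched above.

import math

import numpy as np

shape = (2**40, 2**40)

# np.prod computes in a fixed-width integer dtype and wraps around silently
# on overflow (2**80 mod 2**64 == 0), while math.prod returns the exact,
# arbitrary-precision Python int.
print(np.prod(shape))    # 0 on int64 overflow
print(math.prod(shape))  # 1208925819614629174706176 (== 2**80)

# Empty shapes: np.prod(()) is the float 1.0, math.prod(()) is the int 1,
# so no int(...) wrapper is needed around math.prod.
print(np.prod(()), math.prod(()))  # 1.0 1

# math.prod also consumes arbitrary iterables directly, e.g. dict values or
# map objects, as used for mesh axis sizes in the patch above.
print(math.prod({'x': 4, 'y': 2}.values()))  # 8

This is the behavior that the mechanical np.prod -> math.prod replacements above rely on for shape-size arithmetic.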