diff --git a/benchmarks/pmap_benchmark.py b/benchmarks/pmap_benchmark.py index e11319bd0..6ca555d3b 100644 --- a/benchmarks/pmap_benchmark.py +++ b/benchmarks/pmap_benchmark.py @@ -18,13 +18,15 @@ python3 pmap_benchmark.py To make it run faster, set env var TARGET_TOTAL_SECS to a low number (e.g. 2). """ + +import math + from absl import app import jax from jax import numpy as jnp from jax import pmap from jax.config import config -from jax._src.util import prod from benchmarks import benchmark @@ -118,7 +120,7 @@ def sharded_device_array_indexing_benchmark(): nshards = min(8, jax.local_device_count()) shape = (nshards, 8, 8) def benchmark_fn(): - arr = pmap(lambda x: x)(jnp.arange(prod(shape)).reshape(shape)) + arr = pmap(lambda x: x)(jnp.arange(math.prod(shape)).reshape(shape)) indices = indices_fn() for idx in indices: arr[idx] diff --git a/jax/_src/api.py b/jax/_src/api.py index eb5605c4f..236eec28f 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -26,6 +26,7 @@ import collections import functools from functools import partial import inspect +import math import typing from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union, overload) @@ -65,7 +66,7 @@ from jax._src.lib import pmap_lib from jax._src.sharding import PmapSharding from jax._src.traceback_util import api_boundary from jax._src.tree_util import broadcast_prefix, _generate_key_paths -from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list, +from jax._src.util import (unzip2, curry, safe_map, safe_zip, split_list, wrap_name, cache, wraps, HashableFunction, weakref_lru_cache) @@ -1034,7 +1035,7 @@ def xla_computation(fun: Callable, if axis_env is None: return xla.AxisEnv(nreps, (), ()) else: - nreps = nreps * prod(size for name, size in axis_env) + nreps = nreps * math.prod(size for name, size in axis_env) names, sizes = unzip2(axis_env) return xla.AxisEnv(nreps, names, sizes) @@ -3232,7 +3233,7 @@ class ShapeDtypeStruct: self.sharding = sharding self.named_shape = {} if named_shape is None else dict(named_shape) - size = property(lambda self: prod(self.shape)) + size = property(lambda self: math.prod(self.shape)) ndim = property(lambda self: len(self.shape)) def __len__(self): diff --git a/jax/_src/array.py b/jax/_src/array.py index 371f608f7..13bc11595 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -14,6 +14,7 @@ from __future__ import annotations +import math import operator as op import numpy as np import functools @@ -26,7 +27,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src.config import config -from jax._src.util import prod, safe_zip, use_cpp_class, use_cpp_method +from jax._src.util import safe_zip, use_cpp_class, use_cpp_method from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version from jax._src import api @@ -208,7 +209,7 @@ class ArrayImpl(basearray.Array): @property def size(self): - return prod(self.shape) + return math.prod(self.shape) @property def sharding(self): @@ -552,12 +553,13 @@ def make_array_from_callback( Example: + >>> import math >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> import numpy as np ... 
>>> input_shape = (8, 8) - >>> global_input_data = np.arange(prod(input_shape)).reshape(input_shape) + >>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape) >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y')) ... @@ -601,6 +603,7 @@ def make_array_from_single_device_arrays( Example: + >>> import math >>> from jax.sharding import Mesh >>> from jax.sharding import PartitionSpec as P >>> import numpy as np @@ -608,7 +611,7 @@ def make_array_from_single_device_arrays( >>> shape = (8, 8) >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) >>> sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y')) - >>> inp_data = np.arange(prod(shape)).reshape(shape) + >>> inp_data = np.arange(math.prod(shape)).reshape(shape) ... >>> arrays = [ ... jax.device_put(inp_data[index], d) diff --git a/jax/_src/core.py b/jax/_src/core.py index 313c3b2f4..4bc192779 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -22,6 +22,7 @@ from functools import partial, partialmethod, total_ordering import gc import inspect import itertools as it +import math import operator from operator import attrgetter import threading @@ -44,7 +45,7 @@ from jax.errors import (ConcretizationTypeError, TracerArrayConversionError, from jax._src import linear_util as lu from jax._src import source_info_util -from jax._src.util import (safe_zip, safe_map, curry, prod, tuple_insert, +from jax._src.util import (safe_zip, safe_map, curry, tuple_insert, tuple_delete, as_hashable_function, HashableFunction, HashableWrapper, weakref_lru_cache) import jax._src.pretty_printer as pp @@ -1481,7 +1482,7 @@ class ShapedArray(UnshapedArray): return ShapedArray(shape, dtype, weak_type, named_shape) ndim = property(lambda self: len(self.shape)) - size = property(lambda self: prod(self.shape)) + size = property(lambda self: math.prod(self.shape)) broadcast: ClassVar[Optional[aval_method]] = None transpose: ClassVar[Optional[aval_method]] = None @@ -1628,7 +1629,7 @@ class DShapedArray(UnshapedArray): self.weak_type = weak_type ndim = property(lambda self: len(self.shape)) - size = property(lambda self: prod(self.shape)) + size = property(lambda self: math.prod(self.shape)) def str_short(self, short_dtypes=False) -> str: del short_dtypes # ignored diff --git a/jax/_src/global_device_array.py b/jax/_src/global_device_array.py index a858a735d..110adc45d 100644 --- a/jax/_src/global_device_array.py +++ b/jax/_src/global_device_array.py @@ -15,6 +15,7 @@ from collections import Counter import dataclasses import functools +import math import numpy as np from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple @@ -26,7 +27,7 @@ from jax._src.lib import xla_client as xc from jax._src.config import config from jax._src.interpreters import pxla from jax.interpreters import xla, mlir -from jax._src.util import prod, safe_zip +from jax._src.util import safe_zip from jax._src.interpreters.pxla import PartitionSpec Shape = Tuple[int, ...] 
@@ -87,7 +88,7 @@ def _get_shard_indices_replica_ids_uncached( out[device] = (index, replica_id) shard_shape = get_shard_shape(global_shape, global_mesh, mesh_axes) - expected_unique_shards = prod( + expected_unique_shards = math.prod( [g // s for g, s in safe_zip(global_shape, shard_shape) if g != 0 or s != 0]) if expected_unique_shards != unique_shards: raise RuntimeError( @@ -104,7 +105,7 @@ def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape: if not mesh_axis: chunk_size.append(size) elif isinstance(mesh_axis, tuple): - m = prod([global_mesh.shape[ma] for ma in mesh_axis]) + m = math.prod([global_mesh.shape[ma] for ma in mesh_axis]) chunk_size.append(size // m) else: chunk_size.append(size // global_mesh.shape[mesh_axis]) @@ -323,7 +324,7 @@ class GlobalDeviceArray: @property def size(self): - return prod(self.shape) + return math.prod(self.shape) @property def mesh(self): @@ -455,7 +456,7 @@ class GlobalDeviceArray: >>> global_input_shape = (8, 8) >>> mesh_axes = P('x', 'y') >>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) - >>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape) + >>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape) ... >>> def cb(index): ... return global_input_data[index] @@ -505,7 +506,7 @@ class GlobalDeviceArray: >>> global_input_shape = (8, 2) >>> mesh_axes = P('x') >>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y')) - >>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape) + >>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape) ... >>> def batched_cb(indices): ... assert len(indices) == len(global_mesh.local_devices) @@ -555,7 +556,7 @@ class GlobalDeviceArray: >>> global_input_shape = (8, 2) >>> mesh_axes = P(('x', 'y')) >>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y')) - >>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape) + >>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape) ... >>> def cb(cb_inp): ... 
dbs = [] diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index e3c454432..28cf433cf 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -37,6 +37,7 @@ import dataclasses from functools import partial, lru_cache, cached_property import itertools as it import logging +import math import operator as op import sys import threading @@ -78,7 +79,7 @@ from jax._src.lib import xla_extension_version from jax._src.lib import pmap_lib from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list, +from jax._src.util import (unzip3, safe_map, safe_zip, partition_list, wrap_name, assert_unreachable, tuple_insert, tuple_delete, distributed_debug_log, unzip2, HashableFunction) @@ -284,7 +285,8 @@ def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray: if not has_unstacked: op_sharding_proto = sharding_spec_sharding_proto(self) return _op_sharding_to_numpy_indices( - op_sharding_proto, shape, prod(self.mesh_shape)).reshape(self.mesh_shape) + op_sharding_proto, shape, math.prod(self.mesh_shape) + ).reshape(self.mesh_shape) axis_indices: List[Sequence[Index]] = [] shard_indices_shape = [] @@ -312,7 +314,7 @@ def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray: # with each dimension having size equal to the number of shards across the corresponding # logical array dimension, and each element containing the multi-dimensional index that # is used to extract the corresponding shard of the logical array. - shard_indices = np.empty([prod(shard_indices_shape)], dtype=np.object_) + shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_) for i, idxs in enumerate(it.product(*axis_indices)): shard_indices[i] = idxs shard_indices = shard_indices.reshape(shard_indices_shape) @@ -766,7 +768,7 @@ class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore @property def size(self): - return prod(self.aval.shape) + return math.prod(self.aval.shape) @property def ndim(self): @@ -2252,7 +2254,7 @@ core.axis_substitution_rules[xla_pmap_p] = _pmap_axis_subst def _unravel_index_hlo(axis_env): div = mlir.ir_constant( - np.array(axis_env.nreps // util.prod(axis_env.sizes), np.uint32)) + np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32)) mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32)) return hlo.RemOp( hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result @@ -4035,7 +4037,7 @@ class DynamicAxisEnv(list): @property def nreps(self): - return prod(frame.hard_size for frame in self) + return math.prod(frame.hard_size for frame in self) class _ThreadLocalState(threading.local): def __init__(self): diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index ef2af220f..b15b17f51 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -19,6 +19,7 @@ import dataclasses import functools from functools import partial import itertools as it +import math import operator import re from typing import (Any, Callable, Dict, NamedTuple, Optional, Protocol, @@ -37,7 +38,7 @@ from jax._src import source_info_util from jax._src.abstract_arrays import numpy_scalar_types from jax._src.core import ConcreteArray, ShapedArray from jax._src.interpreters import ad -from jax._src.util import (prod, safe_zip, safe_map, partition_list) +from jax._src.util import (safe_zip, safe_map, partition_list) from jax._src.typing import Shape @@ -311,7 +312,7 @@ def axis_groups(axis_env: AxisEnv, 
name) -> Tuple[Tuple[int, ...]]: if not isinstance(name, (list, tuple)): name = (name,) mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name)) - trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes)) + trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes)) assert not ragged mesh_spec = axis_env.sizes + (trailing_size,) return _axis_groups(mesh_spec, mesh_axes) @@ -326,10 +327,10 @@ def _axis_groups(mesh_spec, mesh_axes): Returns: A tuple of replica groups (i.e. tuples containing replica ids). """ - iota = np.arange(prod(mesh_spec)).reshape(mesh_spec) + iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec) groups = np.reshape( np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))), - (prod(np.take(mesh_spec, mesh_axes)), -1)) + (math.prod(np.take(mesh_spec, mesh_axes)), -1)) return tuple(unsafe_map(tuple, groups.T)) diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index b199adacb..4fdac7e1d 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -14,6 +14,7 @@ from functools import partial +import math from typing import Union, Sequence import numpy as np @@ -30,7 +31,6 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client from jax._src.lib import ducc_fft from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact -from jax._src.util import prod __all__ = [ "fft", @@ -150,7 +150,7 @@ def _irfft_transpose(t, fft_lengths): full(2.0, shape=(n - 2 + is_odd,)), full(1.0, shape=(1 - is_odd,))], dimension=0) - scale = 1 / prod(fft_lengths) + scale = 1 / math.prod(fft_lengths) out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype) # Use JAX's convention for complex gradients diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 1be99f42a..45be1ebd1 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -17,6 +17,7 @@ import enum import functools from functools import partial import itertools +import math import operator from typing import (Any, Callable, Optional, Sequence, Tuple, List, TypeVar, Union, cast as type_cast, overload) @@ -70,7 +71,7 @@ from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo from jax._src.sharding import PmapSharding from jax._src.typing import Array, ArrayLike, DTypeLike, Shape -from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis, +from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis, split_list) xb = xla_bridge @@ -1422,7 +1423,7 @@ def collapse(operand: Array, start_dimension: int, collapsed (raveled) into a single dimension. 
""" lo, hi = start_dimension, stop_dimension - size = prod(operand.shape[lo:hi]) + size = math.prod(operand.shape[lo:hi]) new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:] return reshape(operand, new_shape) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 6139d26d1..f0d94f7bc 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -15,6 +15,7 @@ import inspect import functools from functools import partial +import math from typing import cast, Any, Callable, List, Literal, Optional, Tuple, TypeVar, Union, overload import warnings @@ -50,7 +51,6 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.numpy import lax_numpy as jnp from jax._src.numpy.vectorize import vectorize from jax._src.typing import Array, ArrayLike -from jax._src.util import prod xops = xla_client.ops @@ -1313,7 +1313,7 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a): raise NotImplementedError("Shape polymorphism for custom call is not implemented (geqrf); b/261671778") a_aval, taus_aval = ctx.avals_out *batch_dims, m, n = a_aval.shape - batch = prod(batch_dims) + batch = math.prod(batch_dims) if batch == 0 or m == 0 or n == 0: return mlir.full_like_aval(ctx, 0, a_aval), mlir.full_like_aval(ctx, 0, taus_aval) diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index 679e01762..cc8b2fa4c 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import math from typing import Any, Optional, Sequence, Tuple, Union, cast as type_cast + import jax from jax._src.numpy import lax_numpy as jnp -from jax._src.util import prod from jax._src.lax import lax from jax._src.lax import convolution @@ -92,7 +92,7 @@ def conv_general_dilated_patches( lhs_spec, rhs_spec, out_spec = dimension_numbers - spatial_size = prod(filter_shape) + spatial_size = math.prod(filter_shape) n_channels = lhs_array.shape[lhs_spec[1]] # Move separate `lhs` spatial locations into separate `rhs` channels. diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index d54cebaf5..7f5d6d84f 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -17,6 +17,7 @@ Parallelization primitives. 
from functools import partial import itertools +import math import string from typing import Sequence, Union import warnings @@ -40,7 +41,7 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.numpy import lax_numpy from jax._src.util import ( - unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis) + unzip2, canonicalize_axis, safe_map, safe_zip, moveaxis) unsafe_map, map = map, safe_map # type: ignore @@ -828,7 +829,7 @@ def psum_bind(*args, axes, axis_index_groups): assert not pos_axes size = len(axis_index_groups[0]) else: - size = prod([core.axis_frame(name).size for name in named_axes]) # type: ignore + size = math.prod([core.axis_frame(name).size for name in named_axes]) # type: ignore return tuple(lax._const(x, size) * pos_reduce(x) for x in args) return core.AxisPrimitive.bind( psum_p, *args, axes=axes, axis_index_groups=axis_index_groups) @@ -1563,9 +1564,12 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): '`axis_index` translation rule does not support multiple axis names.') axis_name, = axis_name axis_pos = list(axis_env.names).index(axis_name) - nreplicas = axis_env.nreps // prod(axis_env.sizes) - div = mlir.ir_constant(np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]), - dtype=np.uint32)) + nreplicas = axis_env.nreps // math.prod(axis_env.sizes) + div = mlir.ir_constant( + np.array( + nreplicas * math.prod(axis_env.sizes[axis_pos + 1 :]), dtype=np.uint32 + ) + ) mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32)) axis_context = ctx.module_context.axis_context is_spmd = isinstance(axis_context, diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 728d2a4de..a7e10c70c 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -17,7 +17,7 @@ Common neural network layer initializers, consistent with definitions used in Keras and Sonnet. """ - +import math from typing import Any, Literal, Protocol, Sequence, Tuple, Union import numpy as np @@ -27,7 +27,6 @@ from jax import lax from jax import random from jax._src import core from jax._src import dtypes -from jax._src.util import prod KeyArray = random.KeyArray Array = Any @@ -549,7 +548,7 @@ def orthogonal(scale: RealNumeric = 1.0, dtype = dtypes.canonicalize_dtype(dtype) if len(shape) < 2: raise ValueError("orthogonal initializer requires at least a 2D shape") - n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis] + n_rows, n_cols = math.prod(shape) // shape[column_axis], shape[column_axis] matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols) A = random.normal(key, matrix_shape, dtype) Q, R = jnp.linalg.qr(A) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f6ee697cf..c806bb131 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -27,6 +27,7 @@ rules for the underlying :code:`lax` primitives. 
import builtins import collections from functools import partial +import math import operator import types from typing import ( @@ -82,7 +83,7 @@ from jax._src.numpy.util import ( # noqa: F401 from jax._src.numpy.vectorize import vectorize from jax._src.ops import scatter from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape -from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, +from jax._src.util import (unzip2, subvals, safe_zip, ceil_of_ratio, partition_list, canonicalize_axis as _canonicalize_axis) from jax._src.array import ArrayImpl @@ -534,7 +535,7 @@ def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10, dedges = [diff(bin_edges) for bin_edges in bin_edges_by_dim] xy = ravel_multi_index(tuple(bin_idx_by_dim), nbins, mode='clip') - hist = bincount(xy, weights, length=_prod(nbins)) + hist = bincount(xy, weights, length=math.prod(nbins)) hist = reshape(hist, nbins) core = D*(slice(1, -1),) hist = hist[core] @@ -914,7 +915,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array: arr = ravel(a) - new_size = _prod(new_shape) + new_size = math.prod(new_shape) if arr.size == 0 or new_size == 0: return zeros_like(arr, shape=new_shape) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 81f047faf..61d5be671 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -14,6 +14,7 @@ import builtins from functools import partial +import math import operator from typing import overload, Any, Callable, Literal, Optional, Sequence, Tuple, Union import warnings @@ -31,7 +32,7 @@ from jax._src.numpy.util import ( from jax._src.lax import lax as lax_internal from jax._src.typing import Array, ArrayLike, DType, DTypeLike from jax._src.util import ( - canonicalize_axis as _canonicalize_axis, maybe_named_axis, prod as _prod) + canonicalize_axis as _canonicalize_axis, maybe_named_axis) _all = builtins.all @@ -371,7 +372,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = Non if axis is None: weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype) elif isinstance(axis, tuple): - weights_sum = lax.full_like(avg, _prod(core.dimension_as_value(a.shape[d]) for d in axis)) + weights_sum = lax.full_like(avg, math.prod(core.dimension_as_value(a.shape[d]) for d in axis)) else: weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index] else: diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 1926b99e6..2949415d3 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import partial +import math import operator from textwrap import dedent as _dedent from typing import Optional, Tuple, Union @@ -31,7 +32,6 @@ from jax._src.numpy.lax_numpy import ( sort, where, zeros) from jax._src.numpy.util import _check_arraylike, _wraps from jax._src.typing import Array, ArrayLike -from jax._src.util import prod as _prod _lax_const = lax_internal._const @@ -230,11 +230,11 @@ def _unique_sorted_mask(ar: Array, axis: int) -> Tuple[Array, Array, Array]: # is fixed to match numpy. 
aux = where(isnan(aux), _lax_const(aux, np.nan), aux) size, *out_shape = aux.shape - if _prod(out_shape) == 0: + if math.prod(out_shape) == 0: size = 1 perm = zeros(1, dtype=int) else: - perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1]) + perm = lexsort(aux.reshape(size, math.prod(out_shape)).T[::-1]) aux = aux[perm] if aux.size: if dtypes.issubdtype(aux.dtype, np.inexact): diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 86a98f561..fbe7e93ee 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -15,6 +15,7 @@ import abc from functools import partial, reduce +import math import operator as op from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence @@ -44,7 +45,7 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.numpy import lax_numpy from jax._src.sharding import ( NamedSharding, PmapSharding, GSPMDSharding) -from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip +from jax._src.util import canonicalize_axis, safe_map, safe_zip map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -1141,7 +1142,7 @@ def threefry_random_bits(key: jax.Array, bit_width, shape): return _threefry_random_bits_original(key, bit_width, shape) def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape): - if all(core.is_constant_dim(d) for d in shape) and prod(shape) > 2 ** 64: + if all(core.is_constant_dim(d) for d in shape) and math.prod(shape) > 2 ** 64: raise NotImplementedError('random bits array of size exceeding 2 ** 64') k1, k2 = key @@ -1160,7 +1161,7 @@ def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape): @partial(jit, static_argnums=(1, 2), inline=True) def _threefry_random_bits_original(key: jax.Array, bit_width, shape): - size = prod(shape) + size = math.prod(shape) # Compute ceil(bit_width * size / 32) in a way that is friendly to shape # polymorphism max_count, r = divmod(bit_width * size, 32) diff --git a/jax/_src/random.py b/jax/_src/random.py index b7dcaed97..6ac9f852b 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -14,8 +14,9 @@ from functools import partial -from typing import Optional, Sequence, Union +import math from operator import index +from typing import Optional, Sequence, Union import warnings import numpy as np @@ -40,7 +41,7 @@ from jax._src.numpy.lax_numpy import ( _arraylike, _check_arraylike, _convert_and_clip_integer, _promote_dtypes_inexact) from jax._src.typing import Array, ArrayLike, DTypeLike -from jax._src.util import prod, canonicalize_axis +from jax._src.util import canonicalize_axis RealArray = ArrayLike @@ -507,7 +508,7 @@ def choice(key: KeyArray, else: axis = canonicalize_axis(axis, arr.ndim) n_inputs = arr.shape[axis] - n_draws = prod(shape) + n_draws = math.prod(shape) if n_draws == 0: return jnp.zeros(shape, dtype=arr.dtype) if n_inputs <= 0: @@ -1025,7 +1026,7 @@ def _gamma_grad(sample, a, *, log_space): def _gamma_impl(key, a, *, log_space, use_vmap=False): # split key to match the shape of a a_shape = jnp.shape(a) - split_count = prod(a_shape[key.ndim:]) + split_count = math.prod(a_shape[key.ndim:]) keys = key.flatten() keys = vmap(_split, in_axes=(0, None))(keys, split_count) keys = keys.flatten() diff --git a/jax/_src/scipy/stats/_core.py b/jax/_src/scipy/stats/_core.py index 440147975..81925fb4c 100644 --- a/jax/_src/scipy/stats/_core.py +++ b/jax/_src/scipy/stats/_core.py @@ -14,6 +14,7 @@ from collections import namedtuple from functools import partial +import math from typing import Optional, Tuple import jax @@ -24,7 +25,7 @@ from 
jax._src.api import vmap from jax._src.numpy.lax_numpy import _check_arraylike from jax._src.numpy.util import _wraps from jax._src.typing import ArrayLike, Array -from jax._src.util import canonicalize_axis, prod +from jax._src.util import canonicalize_axis import scipy @@ -80,7 +81,7 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k axis = canonicalize_axis(axis, x.ndim) x = jnp.moveaxis(x, axis, 0) - x = x.reshape(x.shape[0], prod(x.shape[1:])) + x = x.reshape(x.shape[0], math.prod(x.shape[1:])) vals, counts = vmap(_mode_helper, in_axes=1)(x) return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape)) diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 68c01c65c..29d3ccce1 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -14,6 +14,7 @@ """Module for state types.""" from __future__ import annotations +import math from typing import Any, Generic, List, Sequence, Set, Tuple, TypeVar, Union from jax._src import core @@ -21,7 +22,7 @@ from jax._src import effects from jax._src import pretty_printer as pp from jax._src import xla_bridge from jax._src.lib import xla_client -from jax._src.util import safe_map, safe_zip, prod +from jax._src.util import safe_map, safe_zip xc = xla_client xb = xla_bridge @@ -89,7 +90,7 @@ class AbstractRef(core.AbstractValue, Generic[Aval]): return AbstractRef(self.inner_aval.join(other.inner_aval)) ndim = property(lambda self: len(self.shape)) - size = property(lambda self: prod(self.shape)) + size = property(lambda self: math.prod(self.shape)) @property def shape(self): diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 7dea03d02..c204fe6f1 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -17,6 +17,7 @@ import inspect import io import functools from functools import partial +import math import re import os import tempfile @@ -46,7 +47,7 @@ from jax._src.config import (flags, bool_env, config, raise_persistent_cache_errors, persistent_cache_min_compile_time_secs) from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact -from jax._src.util import prod, unzip2 +from jax._src.util import unzip2 from jax._src.public_test_util import ( # noqa: F401 _assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads, check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance) @@ -668,7 +669,7 @@ def rand_int(rng, low=0, high=None): def rand_unique_int(rng, high=None): def fn(shape, dtype): - return rng.choice(np.arange(high or prod(shape), dtype=dtype), + return rng.choice(np.arange(high or math.prod(shape), dtype=dtype), size=shape, replace=False) return fn @@ -755,7 +756,7 @@ def sample_product_testcases(*args, **kw): """Non-decorator form of sample_product.""" args = [list(arg) for arg in args] kw = [(k, list(v)) for k, v in kw.items()] - n = prod(len(a) for a in args) * prod(len(v) for _, v in kw) + n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw) testcases = [] for i in _choice(n, min(n, FLAGS.jax_num_generated_cases)): testcase = {} @@ -1054,7 +1055,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]: """Test utility for setting up meshes given mesh data from `schedules`.""" # This is similar to the `with_mesh` function above, but isn't a decorator. 
axis_names, shape = unzip2(named_shape) - size = prod(shape) + size = math.prod(shape) local_devices = list(api.local_devices()) if len(local_devices) < size: raise unittest.SkipTest(f"Test requires {size} local devices") @@ -1094,7 +1095,7 @@ def restore_spmd_manual_lowering_flag(): config.update('experimental_xmap_spmd_lowering_manual', old_spmd_manual_lowering_flag) def create_global_mesh(mesh_shape, axis_names): - size = prod(mesh_shape) + size = math.prod(mesh_shape) if len(api.devices()) < size: raise unittest.SkipTest(f"Test requires {size} global devices.") devices = sorted(api.devices(), key=lambda d: d.id) diff --git a/jax/_src/util.py b/jax/_src/util.py index efc31eb1a..33e95499d 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -269,12 +269,6 @@ def weakref_lru_cache(call: Callable, maxsize=2048): """ return xc.weakref_lru_cache(config._trace_context, call, maxsize) -def prod(xs): - out = 1 - for x in xs: - out *= x - return out - class Unhashable: __slots__ = ["val"] diff --git a/jax/experimental/rnn.py b/jax/experimental/rnn.py index cbf1a66dc..e6261486d 100644 --- a/jax/experimental/rnn.py +++ b/jax/experimental/rnn.py @@ -83,6 +83,7 @@ TODO: - Support RNNs other than LSTM. """ from functools import partial +import math from typing import Any, Dict, List, Tuple import jax @@ -92,7 +93,6 @@ from jax.interpreters import mlir from jax.interpreters import xla from jax._src.custom_derivatives import custom_vjp from jax._src.typing import Array, Shape -from jax._src.util import prod import jax.numpy as jnp try: from jax._src.lib import gpu_rnn @@ -161,7 +161,7 @@ def get_num_params_in_lstm(input_size: int, hidden_size: int, num_layers: int, """Get param count in LSTM.""" layer_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers, bidirectional) - param_count = sum([prod(shape) for shape in layer_shapes]) + param_count = sum([math.prod(shape) for shape in layer_shapes]) return param_count @@ -201,7 +201,7 @@ def unpack_lstm_weights( for w_kind in [W_ih, W_hh]: shape = flat_shapes[flat_shapes_offset] flat_shapes_offset += 1 - num_elems = prod(shape) + num_elems = math.prod(shape) w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape) w_offsets += num_elems @@ -211,7 +211,7 @@ def unpack_lstm_weights( for w_kind in [b_ih, b_hh]: shape = flat_shapes[flat_shapes_offset] flat_shapes_offset += 1 - num_elems = prod(shape) + num_elems = math.prod(shape) w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape) w_offsets += num_elems return W_ih, W_hh, b_ih, b_hh diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index ec961bbd5..3d503a9f4 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -17,6 +17,7 @@ import enum from functools import partial, lru_cache import inspect import itertools as it +import math import operator as op from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, Protocol) @@ -37,7 +38,7 @@ from jax._src import util from jax._src.core import Tracer from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, fft, linalg) -from jax._src.util import (prod, HashableFunction, unzip2, as_hashable_function, +from jax._src.util import (HashableFunction, unzip2, as_hashable_function, memoize, partition_list, merge_lists) from jax.api_util import flatten_fun_nokwargs, shaped_abstractify from jax.interpreters import batching @@ -150,7 +151,7 @@ def _check_specs_vs_args( msg = 
_spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) raise ValueError(msg) in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) - fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns) + fail = [a if any(a.shape[d] % math.prod(mesh.shape[n] for n in ns) for d, ns in names.items()) else no_fail for a, names in zip(in_avals, in_names_flat)] if any(f is not no_fail for f in fail): @@ -201,10 +202,10 @@ def _spec_divisibility_error( f"parameter '{list(ba.arguments.keys())[arg_key.key]}',") names = _canonicalize_spec(spec) for d, ns in names.items(): - if aval.shape[d] % prod(mesh.shape[n] for n in ns): + if aval.shape[d] % math.prod(mesh.shape[n] for n in ns): axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" total = 'total ' if len(ns) > 1 else '' - sz = prod(mesh.shape[n] for n in ns) + sz = math.prod(mesh.shape[n] for n in ns) msgs.append( f"args{fail_key.pprint()} of shape {aval.str_short()}{extra} " f"corresponds to in_specs{spec_key.pprint()} of value {spec}, " @@ -382,7 +383,7 @@ pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue ) -> core.AbstractValue: if isinstance(aval, core.ShapedArray): - return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) + return aval.update(tuple(sz // math.prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape))) else: raise NotImplementedError # TODO(mattjj): add table with handlers @@ -390,7 +391,7 @@ def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue ) -> core.AbstractValue: if isinstance(aval, core.ShapedArray): - return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) + return aval.update(tuple(sz * math.prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)), named_shape={k: v for k, v in aval.named_shape.items() if k not in mesh.shape}) @@ -880,7 +881,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, check_rep): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero - else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns)))) + else mb_div(x, math.prod(map(mesh.shape.get, _unmentioned(mesh, ns)))) for ns, x in zip(out_names, out_cts)] args = [x if type(x) is not ad.UndefinedPrimal else ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval)) diff --git a/jax/experimental/sparse/_base.py b/jax/experimental/sparse/_base.py index d233ab921..c246c9194 100644 --- a/jax/experimental/sparse/_base.py +++ b/jax/experimental/sparse/_base.py @@ -14,11 +14,11 @@ """Base JAX Sparse object.""" import abc +import math from typing import Sequence, Tuple import jax from jax._src import core -from jax._src import util from jax._src.typing import Array @@ -37,7 +37,7 @@ class JAXSparse(abc.ABC): @property def size(self) -> int: - return util.prod(self.shape) + return math.prod(self.shape) @property def ndim(self) -> int: diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index df4d42a78..601e7d402 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -171,7 +171,6 @@ from jax._src.lax.lax import ( population_count_p as population_count_p, pow as pow, pow_p as pow_p, - prod as prod, random_gamma_grad as random_gamma_grad, random_gamma_grad_p as random_gamma_grad_p, real as real, @@ -369,3 +368,5 @@ from jax.lax import linalg as linalg from jax._src.pjit import 
with_sharding_constraint from jax._src.dispatch import device_put_p + +from math import prod # TODO(phawkins): remove this accidental export diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index ac03f2f9e..14db64b50 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -13,9 +13,8 @@ # limitations under the License. -import functools from functools import partial -import operator +import math import jaxlib.mlir.ir as ir import jaxlib.mlir.dialects.stablehlo as hlo @@ -60,8 +59,6 @@ def _real_type(dtype): """Returns the real equivalent of 'dtype'.""" return np.finfo(dtype).dtype -_prod = lambda xs: functools.reduce(operator.mul, xs, 1) - def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a): """LU decomposition.""" @@ -71,7 +68,7 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a): m, n = dims[-2:] batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) - batch = _prod(batch_dims) + batch = math.prod(batch_dims) if batch > 1 and m == n and m // batch <= 128: lwork, opaque = gpu_blas.build_getrf_batched_descriptor( @@ -118,7 +115,7 @@ def _geqrf_hlo(platform, gpu_solver, dtype, a): m, n = dims[-2:] batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) - batch = _prod(batch_dims) + batch = math.prod(batch_dims) lwork, opaque = gpu_solver.build_geqrf_descriptor( np.dtype(dtype), batch, m, n) @@ -156,7 +153,7 @@ def _geqrf_batched_hlo(platform, gpu_blas, dtype, a): m, n = dims[-2:] batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) - batch = _prod(batch_dims) + batch = math.prod(batch_dims) lwork, opaque = gpu_blas.build_geqrf_batched_descriptor( np.dtype(dtype), batch, m, n) @@ -220,7 +217,7 @@ def _orgqr_hlo(platform, gpu_solver, dtype, a, tau): m, n = dims[-2:] batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) - batch = _prod(batch_dims) + batch = math.prod(batch_dims) tau_dims = ir.RankedTensorType(tau.type).shape assert tau_dims[:-1] == dims[:-2] @@ -266,7 +263,7 @@ def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, assert m == n batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) - batch = _prod(batch_dims) + batch = math.prod(batch_dims) layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) if have_jacobi_solver and n <= 32: @@ -317,7 +314,7 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, m, n = dims[-2:] batch_dims = tuple(dims[:-2]) num_bd = len(batch_dims) - b = _prod(batch_dims) + b = math.prod(batch_dims) if ir.ComplexType.isinstance(a_type.element_type): singular_vals_type = ir.ComplexType(a_type.element_type).element_type else: diff --git a/tests/ann_test.py b/tests/ann_test.py index 77ba493c9..359a9a0e4 100644 --- a/tests/ann_test.py +++ b/tests/ann_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
from functools import partial +import math from absl.testing import absltest @@ -21,7 +22,6 @@ import numpy as np import jax from jax import lax from jax._src import test_util as jtu -from jax._src.util import prod from jax.config import config @@ -104,7 +104,7 @@ class AnnTest(jtu.JaxTestCase): is_max_k=[True, False], ) def test_autodiff(self, shape, dtype, k, is_max_k): - vals = np.arange(prod(shape), dtype=dtype) + vals = np.arange(math.prod(shape), dtype=dtype) vals = self.rng().permutation(vals).reshape(shape) if is_max_k: fn = lambda vs: lax.approx_max_k(vs, k=k)[0] diff --git a/tests/array_test.py b/tests/array_test.py index d163dd324..ae196c60f 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -14,6 +14,7 @@ """Tests for GlobalDeviceArray.""" import contextlib +import math import os import unittest from absl.testing import absltest @@ -27,7 +28,7 @@ from jax._src import dispatch from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc -from jax._src.util import prod, safe_zip +from jax._src.util import safe_zip from jax.interpreters import pxla from jax.experimental.pjit import pjit from jax.experimental.serialize_executable import ( @@ -72,7 +73,7 @@ def tearDownModule(): def create_array(shape, sharding, global_data=None): if global_data is None: - global_data = np.arange(prod(shape)).reshape(shape) + global_data = np.arange(math.prod(shape)).reshape(shape) return array.make_array_from_callback( shape, sharding, lambda idx: global_data[idx]), global_data @@ -310,7 +311,7 @@ class JaxArrayTest(jtu.JaxTestCase): shape = (8, 2) mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) s = sharding.NamedSharding(mesh, P('x', 'y')) - inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape) + inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) di_map = s.devices_indices_map(shape) bufs = [jax.device_put(inp_data[di_map[d]], d) for d in jax.local_devices()] @@ -333,7 +334,7 @@ class JaxArrayTest(jtu.JaxTestCase): mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) # sharding device ids = {0, 1} s = sharding.NamedSharding(mesh, P('x')) - inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape) + inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) # _arrays device ids = {2, 3} bufs = [jax.device_put(inp_data, d) for d in jax.devices()[2:4]] with self.assertRaisesRegex( @@ -349,7 +350,7 @@ class JaxArrayTest(jtu.JaxTestCase): mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) # Sharding device ids = {0, 1} s = sharding.NamedSharding(mesh, P('x')) - inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape) + inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) # _arrays device ids = {0, 0} bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)] with self.assertRaisesRegex( @@ -366,7 +367,7 @@ class JaxArrayTest(jtu.JaxTestCase): mesh = jax.sharding.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x')) # sharding device ids = {1, 2} s = sharding.NamedSharding(mesh, P('x')) - inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape) + inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) # _arrays device ids = {0, 1} bufs = [jax.device_put(inp_data, d) for d in jax.devices()[:2]] with self.assertRaisesRegex( @@ -389,7 +390,7 @@ class JaxArrayTest(jtu.JaxTestCase): shape = (8, 4) mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) mps = sharding.NamedSharding(mesh, pspec) - inp_data = 
np.arange(prod(shape)).reshape(shape) + inp_data = np.arange(math.prod(shape)).reshape(shape) str_expected_shard_shape = str(expected_shard_shape).replace( r"(", r"\(").replace(r")", r"\)") @@ -403,7 +404,7 @@ class JaxArrayTest(jtu.JaxTestCase): shape = (8, 2) mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) s = sharding.NamedSharding(mesh, P('x', 'y')) - inp_data = np.arange(prod(shape), dtype=np.int32).reshape(shape) + inp_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape) indices = s.devices_indices_map(shape) bufs = [jax.device_put(inp_data[indices[d]], d) for d in mesh.local_devices] with self.assertRaisesRegex( @@ -725,7 +726,7 @@ class ShardingTest(jtu.JaxTestCase): self.skipTest('Test needs >= 2 devices.') shape = (2, 2) - num_elements = prod(shape) + num_elements = math.prod(shape) inp_data = np.arange(num_elements).reshape(shape) out = jax.pmap(lambda x: x)(inp_data) self.assertIsInstance(out.sharding, sharding.PmapSharding) @@ -954,7 +955,7 @@ class RngShardingTest(jtu.JaxTestCase): mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y')) s = sharding.NamedSharding(mesh, pspec) - n = prod(global_shape) + n = math.prod(global_shape) global_x = jnp.arange(n).astype('uint32').reshape(global_shape) x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i]) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 0f6f4be50..783cb14ab 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -14,6 +14,7 @@ from functools import partial import hashlib +import math import os import random import sys @@ -29,7 +30,6 @@ from jax.experimental.maps import xmap from jax.experimental.pjit import pjit import jax from jax import jit, lax, pmap -from jax._src.util import prod import jax._src.test_util as jtu from jax._src import xla_bridge from jax._src.lib import xla_client @@ -284,11 +284,11 @@ class CompilationCacheTest(jtu.JaxTestCase): return x + y shape = (8, 8) - x = np.arange(prod(shape), dtype=np.int64).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape) f(x, x + 1) files_in_directory = len(os.listdir(tmpdir)) self.assertEqual(files_in_directory, 1) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f(x, x + 1) files_in_directory = len(os.listdir(tmpdir)) self.assertEqual(files_in_directory, 2) diff --git a/tests/global_device_array_test.py b/tests/global_device_array_test.py index edfaaffad..4c76bc854 100644 --- a/tests/global_device_array_test.py +++ b/tests/global_device_array_test.py @@ -13,15 +13,17 @@ # limitations under the License. 
"""Tests for GlobalDeviceArray.""" -from absl.testing import absltest -from absl.testing import parameterized +import math import unittest import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + import jax from jax._src import core from jax._src import test_util as jtu -from jax._src.util import prod, safe_zip +from jax._src.util import safe_zip from jax.sharding import PartitionSpec as P from jax.sharding import Mesh @@ -34,7 +36,7 @@ config.parse_flags_with_absl() def create_gda(global_shape, global_mesh, mesh_axes, global_data=None): if global_data is None: - global_data = np.arange(prod(global_shape)).reshape(global_shape) + global_data = np.arange(math.prod(global_shape)).reshape(global_shape) return GlobalDeviceArray.from_callback( global_shape, global_mesh, mesh_axes, lambda idx: global_data[idx]), global_data @@ -75,7 +77,7 @@ class GDATest(jtu.JaxTestCase): global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) global_input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return global_input_data[index] @@ -128,7 +130,7 @@ class GDATest(jtu.JaxTestCase): global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) global_input_shape = (8, 4, 2) global_input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return global_input_data[index] @@ -163,7 +165,7 @@ class GDATest(jtu.JaxTestCase): expected_replica_ids): global_mesh = jtu.create_global_mesh((8,), ('x')) global_input_shape = (16,) - global_input_data = np.arange(prod(global_input_shape)).reshape(-1) + global_input_data = np.arange(math.prod(global_input_shape)).reshape(-1) def cb(index): return global_input_data[index] @@ -214,7 +216,7 @@ class GDATest(jtu.JaxTestCase): global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) global_input_shape = (8, 2) global_input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return global_input_data[index] @@ -240,7 +242,7 @@ class GDATest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P(('x', 'y')) global_input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(indices): self.assertEqual(len(indices), len(global_mesh.local_devices)) @@ -260,7 +262,7 @@ class GDATest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P('x') global_input_data = np.arange( - prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) def cb(cb_inp): self.assertLen(cb_inp, 4) @@ -288,7 +290,7 @@ class GDATest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P(('x', 'y')) global_input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return global_input_data[index] gda = GlobalDeviceArray.from_callback( @@ -307,7 +309,7 @@ class GDATest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P(None,) global_input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return global_input_data[index] input_gda = 
GlobalDeviceArray.from_callback( @@ -340,7 +342,7 @@ class GDATest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P('x', 'y') global_input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes) dbs = [ @@ -358,7 +360,7 @@ class GDATest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P(('x', 'y')) global_input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return global_input_data[index] diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 07ad0849c..76cf17b1b 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -16,6 +16,7 @@ import collections from functools import partial import itertools +import math from unittest import SkipTest from absl.testing import absltest @@ -28,7 +29,6 @@ from jax import dtypes from jax import lax from jax._src import test_util as jtu from jax.test_util import check_grads -from jax._src.util import prod from jax.config import config config.parse_flags_with_absl() @@ -832,7 +832,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): # too, since we don't guarantee the same ordering of values with equal keys. # To avoid that case, we generate unique keys (globally in the key array). def args_maker(): - flat_keys = np.arange(prod(shape), dtype=key_dtype) + flat_keys = np.arange(math.prod(shape), dtype=key_dtype) keys = self.rng().permutation(flat_keys).reshape(shape) values = rng(shape, val_dtype) return keys, values @@ -847,7 +847,7 @@ class LaxAutodiffTest(jtu.JaxTestCase): k=[1, 3], ) def testTopKGrad(self, shape, dtype, k): - flat_values = np.arange(prod(shape), dtype=dtype) + flat_values = np.arange(math.prod(shape), dtype=dtype) values = self.rng().permutation(flat_values).reshape(shape) fun = lambda vs: lax.top_k(vs, k=k)[0] check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 7e69469c8..8ea1cae6e 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -19,6 +19,7 @@ from functools import partial import inspect import io import itertools +import math from typing import cast, Iterator, Optional, List, Tuple import unittest from unittest import SkipTest @@ -46,7 +47,7 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax import lax as lax_internal from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps -from jax._src.util import prod, safe_zip +from jax._src.util import safe_zip from jax._src import array from jax.config import config @@ -1358,7 +1359,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): if shape in scalar_shapes or len(shape) == 0: cond_shape = (0,) elif axis is None: - cond_shape = (prod(shape),) + cond_shape = (math.prod(shape),) else: cond_shape = (shape[axis],) @@ -1394,7 +1395,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): if shape in scalar_shapes or len(shape) == 0: cond_shape = (0,) elif axis is None: - cond_shape = (prod(shape),) + cond_shape = (math.prod(shape),) else: cond_shape = (shape[axis],) @@ -1488,7 +1489,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): [dict(shape=shape, axis=axis, idx=idx) for shape in nonempty_nonscalar_array_shapes for axis in [None] + list(range(-len(shape), len(shape))) - for idx in (range(-prod(shape), prod(shape)) + for idx in (range(-math.prod(shape), 
math.prod(shape)) if axis is None else range(-shape[axis], shape[axis]))], dtype=all_dtypes, @@ -3372,7 +3373,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): idx_shape=all_shapes, ) def testUnravelIndex(self, shape, idx_shape, dtype): - size = prod(shape) + size = math.prod(shape) rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3) def np_fun(index, shape): diff --git a/tests/lax_test.py b/tests/lax_test.py index 39434a978..4bd7e7a1a 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -15,6 +15,7 @@ from __future__ import annotations from functools import partial import itertools +import math import operator import types import unittest @@ -44,7 +45,6 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import test_util as jtu from jax._src import lax_reference -from jax._src.util import prod from jax._src.lax import lax as lax_internal from jax._src.internal_test_util import lax_test_util @@ -1998,7 +1998,7 @@ class LaxTest(jtu.JaxTestCase): # too, since we don't guarantee the same ordering of values with equal keys. # To avoid that case, we generate unique keys (globally in the key array). def args_maker(): - flat_keys = np.arange(prod(shape), dtype=key_dtype) + flat_keys = np.arange(math.prod(shape), dtype=key_dtype) keys = self.rng().permutation(flat_keys).reshape(shape) values = rng(shape, val_dtype) return keys, values @@ -2035,7 +2035,7 @@ class LaxTest(jtu.JaxTestCase): # too, since we don't guarantee the same ordering of values with equal keys. # To avoid that case, we generate unique keys (globally in the key array). def args_maker(): - flat_keys = np.arange(prod(shape), dtype=key_dtype) + flat_keys = np.arange(math.prod(shape), dtype=key_dtype) keys = self.rng().permutation(flat_keys).reshape(shape) values = rng(shape, val_dtype) return keys, values @@ -2051,7 +2051,7 @@ class LaxTest(jtu.JaxTestCase): ) def testTopK(self, shape, dtype, k): def args_maker(): - flat_values = np.arange(prod(shape), dtype=dtype) + flat_values = np.arange(math.prod(shape), dtype=dtype) values = self.rng().permutation(flat_values).reshape(shape) return [values] def reference_top_k(x): diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index b1fdb36e9..58e32b674 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -15,6 +15,7 @@ from functools import partial import itertools +import math from typing import Optional, cast import unittest @@ -32,7 +33,7 @@ from jax._src import test_util as jtu from jax._src.internal_test_util import lax_test_util from jax._src.lax import windowed_reductions as lax_windowed_reductions from jax._src.lib import xla_client -from jax._src.util import prod, safe_map, safe_zip +from jax._src.util import safe_map, safe_zip from jax.config import config config.parse_flags_with_absl() @@ -667,7 +668,7 @@ class LaxVmapTest(jtu.JaxTestCase): # Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of # values a bfloat16 can represent exactly to avoid ties. def testTopK(self, shape, dtype, k, bdims): - rng = jtu.rand_int(self.rng(), high=prod(shape)) + rng = jtu.rand_int(self.rng(), high=math.prod(shape)) # _CheckBatching doesn't work with tuple outputs, so test outputs separately. 
op1 = lambda x: lax.top_k(x, k=k)[0] self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index f05eeb496..55712f60c 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -16,6 +16,7 @@ import os import re from functools import partial, lru_cache import logging +import math import threading import unittest from collections import OrderedDict, namedtuple @@ -51,7 +52,7 @@ from jax._src.interpreters import pxla from jax.interpreters import mlir from jax._src import xla_bridge from jax._src.lib import xla_client as xc -from jax._src.util import prod, curry, unzip2, safe_zip +from jax._src.util import curry, unzip2, safe_zip from jax.config import config config.parse_flags_with_absl() @@ -85,7 +86,7 @@ def create_gda(global_shape, global_mesh, mesh_axes, global_data=None, dtype=np.float32): if global_data is None: global_data = np.arange( - prod(global_shape), dtype=dtype).reshape(global_shape) + math.prod(global_shape), dtype=dtype).reshape(global_shape) if isinstance(mesh_axes, Sharding): mesh_axes = mesh_axes.spec @@ -98,7 +99,7 @@ def create_array(global_shape, global_mesh, mesh_axes, global_data=None, dtype=np.float32): if global_data is None: global_data = np.arange( - prod(global_shape), dtype=dtype).reshape(global_shape) + math.prod(global_shape), dtype=dtype).reshape(global_shape) if isinstance(mesh_axes, Sharding): sharding = mesh_axes @@ -144,7 +145,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x shape = (2, 2) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) actual = f(x) expected = x self.assertAllClose(actual, expected, check_dtypes=False) @@ -164,7 +165,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x + y shape = (8, 8) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) actual = f(x, x + 1) expected = x + (x + 1) self.assertAllClose(actual, expected, check_dtypes=False) @@ -182,7 +183,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x + y shape = (8, 8) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) if config.jax_array: out = jax.jit(f)(x, x + 1) self.assertArraysEqual(out, x + x + 1) @@ -205,7 +206,7 @@ class PJitTest(jtu.BufferDonationTestCase): return jnp.pad(out, [[0, 1]]) shape = (4,) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) actual = f(x, x + 1) expected = x + (x + 1) self.assertAllClose(actual[:3], expected[:3], check_dtypes=False) @@ -222,7 +223,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x + y shape = (8, 8) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) with jtu.create_global_mesh((2,), ('x')) as mesh: actual = f(x, x + 1) expected = x + (x + 1) @@ -281,7 +282,7 @@ class PJitTest(jtu.BufferDonationTestCase): def testMeshDecorator(self): x = jnp.arange(8) mesh_shape = (2, 2) - size = prod(mesh_shape) + size = math.prod(mesh_shape) if len(jax.devices()) < size: raise unittest.SkipTest(f"Test requires {size} global devices.") mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape) @@ -347,7 +348,7 @@ class PJitTest(jtu.BufferDonationTestCase): return y * 2 shape = (8, 8) - x = np.arange(prod(shape)).reshape(shape) + x = 
np.arange(math.prod(shape)).reshape(shape) expected = (x + 1) * 2 actual = f(x) self.assertAllClose(actual, expected, check_dtypes=False) @@ -374,7 +375,7 @@ class PJitTest(jtu.BufferDonationTestCase): return y * 2 shape = (8, 8) - x = np.arange(prod(shape)).reshape(shape) + x = np.arange(math.prod(shape)).reshape(shape) expected = (x + 1) * 2 actual = f(x) self.assertAllClose(actual, expected, check_dtypes=False) @@ -402,7 +403,7 @@ class PJitTest(jtu.BufferDonationTestCase): y = with_sharding_constraint(y, ops) return y * 2 - x = np.arange(prod(shape)).reshape(shape) + x = np.arange(math.prod(shape)).reshape(shape) expected = (x + 1) * 2 actual = f(x) self.assertAllClose(actual, expected, check_dtypes=False) @@ -426,7 +427,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x shape = (8, 8) - v = np.arange(prod(shape)).reshape(shape) + v = np.arange(math.prod(shape)).reshape(shape) x = [{"a": v, "b": v * 2}, v * 3] actual = f(x) @@ -458,7 +459,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x shape = (8, 8) - v = np.arange(prod(shape)).reshape(shape) + v = np.arange(math.prod(shape)).reshape(shape) x = [{"a": v, "b": v * 2}, v * 3] actual = f(x) @@ -487,7 +488,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x shape = (2, 8, 8) - v = np.arange(prod(shape)).reshape(shape) + v = np.arange(math.prod(shape)).reshape(shape) x = [{'a': v, 'b': v * 2}, v * 3] actual = f(x) @@ -513,7 +514,7 @@ class PJitTest(jtu.BufferDonationTestCase): return x shape = (2, 8, 8) - v = np.arange(prod(shape)).reshape(shape) + v = np.arange(math.prod(shape)).reshape(shape) x = [{'a': v, 'b': v * 2}, v * 3] mlir_str = str(f.lower(x).compiler_ir()) @@ -1084,7 +1085,7 @@ class PJitTest(jtu.BufferDonationTestCase): input_shape = (8, 4) mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) seeds = jnp.arange( - prod(input_shape), dtype=np.uint32).reshape(input_shape) + math.prod(input_shape), dtype=np.uint32).reshape(input_shape) with mesh: def make_keys(seeds): @@ -1179,7 +1180,7 @@ class GDAPjitTest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P('x', 'y') input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return input_data[index] @@ -1214,7 +1215,7 @@ class GDAPjitTest(jtu.JaxTestCase): global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return input_data[index] @@ -1288,7 +1289,7 @@ class GDAPjitTest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P('x', 'y') input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return input_data[index] @@ -1321,7 +1322,7 @@ class GDAPjitTest(jtu.JaxTestCase): @jtu.with_mesh([('x', 4), ('y', 2)]) def test_pjit_gda_non_gda_inputs(self): input_shape = (8, 2) - input_data = np.arange(prod(input_shape)).reshape(input_shape) + input_data = np.arange(math.prod(input_shape)).reshape(input_shape) with parallel_functions_output_gda(True): @partial(pjit, @@ -1353,7 +1354,8 @@ class GDAPjitTest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P('x', 'y') global_input_data = np.arange( - prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32 + ).reshape(global_input_shape) def cb(index): return 
global_input_data[index] @@ -1374,7 +1376,8 @@ class GDAPjitTest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P('x') global_input_data = np.arange( - prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32 + ).reshape(global_input_shape) def cb(index): return global_input_data[index] @@ -1400,7 +1403,7 @@ class GDAPjitTest(jtu.JaxTestCase): input_shape = (8, 2) mesh_axes = P('x', 'y') input_data = np.arange( - prod(input_shape), dtype=np.float32).reshape(input_shape) + math.prod(input_shape), dtype=np.float32).reshape(input_shape) def cb(index): return input_data[index] @@ -1439,7 +1442,8 @@ class GDAPjitTest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P(None) global_input_data = np.arange( - prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32 + ).reshape(global_input_shape) def cb(index): return global_input_data[index] @@ -1500,7 +1504,7 @@ class GDAPjitTest(jtu.JaxTestCase): global_input_shape = (8, 2) mesh_axes = P(None) global_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) with parallel_functions_output_gda(True): f = pjit(lambda x: x, in_shardings=mesh_axes, out_shardings=mesh_axes) @@ -1584,7 +1588,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase): raise unittest.SkipTest('GDA and Array cannot be together.') global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) input_data = np.arange( - prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) with parallel_functions_output_gda(True): with global_mesh: @@ -1610,7 +1614,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase): raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.') global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) input_data = np.arange( - prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) with jax_array(True): with global_mesh: @@ -1634,7 +1638,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase): global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) global_input_shape = (4, 2) input_data = np.arange( - prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) with ctx(True): with global_mesh: @@ -1662,7 +1666,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase): global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) global_input_shape = (4, 2) input_data = np.arange( - prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) with global_mesh: f = pjit(lambda x, y, z: (x, y, z), in_shardings=AUTO, out_shardings=AUTO) @@ -1685,7 +1689,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase): global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) global_input_shape = (8, 4) input_data = np.arange( - prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) in_resource = pspec @@ -1718,7 +1722,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase): global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) global_input_shape = (8, 4) input_data = np.arange( - 
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) in_resource = NamedSharding(global_mesh, pspec) @@ -1748,7 +1752,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase): global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) input_data = np.arange( - prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) + math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) with jax_array(True): with global_mesh: @@ -1835,7 +1839,7 @@ class ArrayPjitTest(jtu.JaxTestCase): input_shape = (8, 2) global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) input_data = np.arange( - prod(input_shape), dtype=np.float32).reshape(input_shape) + math.prod(input_shape), dtype=np.float32).reshape(input_shape) with jax_array(True): with global_mesh: f = pjit(lambda x: x, @@ -1854,7 +1858,7 @@ class ArrayPjitTest(jtu.JaxTestCase): input_shape = (8, 2) global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) input_data = np.arange( - prod(input_shape), dtype=np.float32).reshape(input_shape) + math.prod(input_shape), dtype=np.float32).reshape(input_shape) with jax_array(True): with global_mesh: f = pjit( @@ -2026,7 +2030,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def test_globally_sharded_key_array_result_8x4_single_device(self): input_shape = (8, 4) seeds = jnp.arange( - prod(input_shape), dtype=np.uint32).reshape(input_shape) + math.prod(input_shape), dtype=np.uint32).reshape(input_shape) @pjit def make_keys(seeds): @@ -2080,7 +2084,7 @@ class ArrayPjitTest(jtu.JaxTestCase): m2 = jtu.create_global_mesh((2, 2), ('x', 'y')) spec = P('x', 'y') - a1 = jnp.arange(prod(input_shape)).reshape(input_shape) + a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape) with jax_array(True): with m1: @@ -2096,7 +2100,7 @@ class ArrayPjitTest(jtu.JaxTestCase): m2 = jtu.create_global_mesh((2, 2), ('x', 'y')) spec = P('x', 'y') - a1 = jnp.arange(prod(input_shape)).reshape(input_shape) + a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape) with jax_array(True): with m1: @@ -2250,7 +2254,7 @@ class ArrayPjitTest(jtu.JaxTestCase): shape = (8, 2) mesh = jax.sharding.Mesh(mesh_devices, ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) - inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape) + inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) # Explicitly put on the ordering of devices which does not match the mesh # ordering to make sure we reorder them in the constructor and the output @@ -2271,7 +2275,7 @@ class ArrayPjitTest(jtu.JaxTestCase): @jax_array(True) def test_not_xlacompatible_sharding_error(self): shape = (8, 2) - inp_data = np.arange(prod(shape)).reshape(shape) + inp_data = np.arange(math.prod(shape)).reshape(shape) ts = TempSharding(jax.devices()) arr = array.make_array_from_callback( shape, ts, lambda idx: inp_data[idx]) @@ -2333,7 +2337,7 @@ class ArrayPjitTest(jtu.JaxTestCase): @jax_array(True) def test_pjit_uncommitted_array_and_committed_array(self): shape = (8, 2) - uarr = jnp.arange(prod(shape), dtype=np.float32).reshape(shape) + uarr = jnp.arange(math.prod(shape), dtype=np.float32).reshape(shape) mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) carr, inp_data = create_array(shape, mesh, P('x', 'y')) with mesh: @@ -2357,7 +2361,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def test_pjit_uncommitted_array_multi_devices(self): shape = (8, 2) mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) - inp = 
np.arange(prod(shape), dtype=np.int32).reshape(shape) + inp = np.arange(math.prod(shape), dtype=np.int32).reshape(shape) arr = array.ArrayImpl( core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)), [jax.device_put(inp, d) for d in mesh.devices.flat], committed=False) @@ -2602,7 +2606,7 @@ class ArrayPjitTest(jtu.JaxTestCase): shape = (8, 2) mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) - inp_data = np.arange(prod(shape)).reshape(shape) + inp_data = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(inp_data, s) out = pjit(lambda x: x)(arr) self.assertArraysEqual(out, inp_data) @@ -2640,7 +2644,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def test_multi_device_pjit_mul(self): shape = (8, 2) mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) - inp_data = np.arange(prod(shape)).reshape(shape) + inp_data = np.arange(math.prod(shape)).reshape(shape) arr1 = jax.device_put(inp_data, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(inp_data, NamedSharding(mesh, P(None, 'y'))) @@ -2655,7 +2659,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def test_single_device_pjit_cpp_dispatch(self): shape = (8, 2) mesh = jtu.create_global_mesh((1,), ('x',)) - inp_data = np.arange(prod(shape)).reshape(shape) + inp_data = np.arange(math.prod(shape)).reshape(shape) f = pjit(lambda x: x @ x.T, in_shardings=None, out_shardings=None) with jtu.count_pjit_cache_miss() as count: @@ -3379,7 +3383,7 @@ class ArrayPjitTest(jtu.JaxTestCase): def test_with_sharding_constraint_spmd_axis_name(self): mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl')) shape = (8, 4, 2, 2) - x = jnp.arange(prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) def f(inp): return with_sharding_constraint(inp, P('data', None, None)) @@ -3723,7 +3727,7 @@ class PJitErrorTest(jtu.JaxTestCase): def test_pjit_with_deleted_input_at_first_call(self, committed): shape = (8,) mesh = jtu.create_global_mesh((1,), ('x',)) - inp_data = np.arange(prod(shape)).reshape(shape) + inp_data = np.arange(math.prod(shape)).reshape(shape) if committed: s = NamedSharding(mesh, P('x',)) x = jax.device_put(inp_data, s) @@ -3742,7 +3746,7 @@ class PJitErrorTest(jtu.JaxTestCase): def test_pjit_with_deleted_input_at_subsequent_call(self, committed): shape = (8,) mesh = jtu.create_global_mesh((1,), ('x',)) - inp_data = np.arange(prod(shape)).reshape(shape) + inp_data = np.arange(math.prod(shape)).reshape(shape) if committed: s = NamedSharding(mesh, P('x',)) x = jax.device_put(inp_data, s) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index c01b5768c..ebddf3ae2 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -17,6 +17,7 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial import itertools as it import gc +import math import os from random import shuffle from typing import Optional, cast @@ -44,7 +45,7 @@ from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr, from jax._src import config as jax_config from jax._src import device_array from jax._src import xla_bridge -from jax._src.util import prod, safe_map, safe_zip +from jax._src.util import safe_map, safe_zip from jax.interpreters import pxla from jax.interpreters import xla from jax._src import array @@ -119,7 +120,7 @@ def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None, aval = ShapedArray(input_shape, dtype) if input_data is None: - input_data = np.arange(prod(input_shape)).reshape(input_shape) + input_data = 
np.arange(math.prod(input_shape)).reshape(input_shape) sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes, sharded_dim_size) @@ -167,9 +168,9 @@ class PythonPmapTest(jtu.JaxTestCase): msg = "device mesh shape {} not compatible with device count {}" raise SkipTest(msg.format(device_mesh_shape, device_count)) from err else: - if device_count % prod(device_mesh_shape): + if device_count % math.prod(device_mesh_shape): msg = "device mesh size {} does not divide available device count {}" - raise SkipTest(msg.format(prod(device_mesh_shape), device_count)) + raise SkipTest(msg.format(math.prod(device_mesh_shape), device_count)) else: return device_mesh_shape @@ -177,7 +178,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = x - np.sum(x, 0) ans = f(x) @@ -193,7 +194,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testLowerCompile(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = f(x) lowered = f.lower(x) compiled = lowered.compile() @@ -210,7 +211,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testLowerCompileInTreeMismatch(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f_exe = f.lower(x).compile() self.assertRaisesRegex( TypeError, "function compiled for .*, called with .*", @@ -235,7 +236,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testLowerCompileArgTypeMismatch(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=int).reshape(shape) + x = np.arange(math.prod(shape), dtype=int).reshape(shape) x_f32 = x.astype(jnp.float32) x_i32 = x.astype(jnp.int32) f_exe = f.lower(x_f32).compile() @@ -250,7 +251,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testLowerCompileMultiArg(self): f = self.pmap(lambda x, y: x - lax.pmean(y, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = y = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = f(x, y) f_exe = f.lower(x, y).compile() ans = f_exe(x, y) @@ -267,7 +268,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testLowerAsText(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f = f.lower(x) self.assertIsInstance(f.as_text(), str) self.assertIsInstance(f.as_text(dialect='hlo'), str) @@ -277,7 +278,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testLowerCompilerIR(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f = f.lower(x) self.assertIsNotNone(f.compiler_ir()) self.assertIsNotNone(f.compiler_ir(dialect='hlo')) @@ -289,14 +290,14 @@ class PythonPmapTest(jtu.JaxTestCase): # TODO(frostig): remove (deprecated) f = self.pmap(lambda 
x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f = f.lower(x).compile() self.assertIsNotNone(f.compiler_ir()) def testLowerCompileAsText(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f = f.lower(x).compile() self.assertIsInstance(f.as_text(), (str, type(None))) @@ -304,7 +305,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testLowerCostAnalysis(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f = f.lower(x) f.cost_analysis() # doesn't raise @@ -312,7 +313,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testLowerCompileCostAnalysis(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f = f.lower(x).compile() f.cost_analysis() # doesn't raise @@ -320,21 +321,21 @@ class PythonPmapTest(jtu.JaxTestCase): def testLowerCompileMemoryAnalysis(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f = f.lower(x).compile() f.memory_analysis() # doesn't raise def testLowerCompileExecutable(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f = f.lower(x).compile() self.assertIsNotNone(f.runtime_executable()) def testLowerShapedArray(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) x_shape = core.ShapedArray(x.shape, x.dtype) self.assertAllClose(f.lower(x_shape).compile()(x), f(x)) @@ -342,7 +343,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = x - np.broadcast_to(np.mean(x, 0), x.shape) ans = f(x) @@ -352,7 +353,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = np.array([x] * jax.device_count()) ans = f(x) self.assertAllClose(ans, expected, check_dtypes=False) @@ -361,7 +362,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) x = (x % 2).astype(np.bool_) expected = np.array([x] * jax.device_count()) ans = f(x) @@ -371,7 +372,7 @@ 
class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(lambda x: lax.all_gather(x, 'i', axis=-1), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = np.array([x.T] * jax.device_count()) ans = f(x) self.assertAllClose(ans, expected, check_dtypes=False) @@ -381,7 +382,7 @@ class PythonPmapTest(jtu.JaxTestCase): device_count = jax.device_count() shape = (device_count, 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = np.array([x] * device_count).reshape(device_count, -1) ans = f(x) self.assertAllClose(ans, expected, check_dtypes=False) @@ -392,7 +393,7 @@ class PythonPmapTest(jtu.JaxTestCase): device_count = jax.device_count() shape = (device_count, 4, 3) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = np.array([x.transpose(1, 0, 2).reshape(4, -1)] * device_count) ans = f(x) self.assertAllClose(ans, expected, check_dtypes=False) @@ -406,7 +407,7 @@ class PythonPmapTest(jtu.JaxTestCase): device_count = jax.device_count() shape = (4, device_count, device_count) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) self.assertAllClose(vmap(f)(x), jnp.stack([f(xs) for xs in x], axis=0)) def testReduceScatter(self): @@ -414,7 +415,7 @@ class PythonPmapTest(jtu.JaxTestCase): device_count = jax.device_count() shape = (device_count, device_count) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = np.sum(x, axis=0) ans = f(x) for i, actual in enumerate(ans): @@ -425,7 +426,7 @@ class PythonPmapTest(jtu.JaxTestCase): device_count = jax.device_count() shape = (device_count, 4 * device_count) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = np.sum(x, axis=0) ans = f(x) scatter_len = len(expected) // device_count @@ -444,7 +445,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(f, axis_name='i') shape = (replicas, 4 * replicas) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) ans = f(x) group_1_result = np.sum(x[0::2,:], axis=0) @@ -504,7 +505,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i') shape = (jax.device_count(), 4 * 2) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape).view(np.complex64) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape).view(np.complex64) expected = x - np.sum(x, 0) ans = f(x) @@ -566,7 +567,7 @@ class PythonPmapTest(jtu.JaxTestCase): return np.repeat(np.sum(x, axis, keepdims=True), x.shape[axis], axis) shape = (jax.device_count(), 1, 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) ans = f(x) expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1) @@ -591,7 +592,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(self.pmap(f, 'i'), 'j') shape = mesh_shape + (4,) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) ans = f(x) expected = x @@ -605,7 +606,7 @@ class 
PythonPmapTest(jtu.JaxTestCase): mesh_shape = (jax.device_count(),) shape = mesh_shape + (4,) x = np.array(3., dtype=np.float32) - y = np.arange(prod(shape), dtype=np.float32).reshape(shape) + y = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) f_expected = np.broadcast_to(x, mesh_shape) f_ans = f(x, y) @@ -657,7 +658,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(f, axis_name='j', in_axes=(None, 0)) x = 3. - y = np.arange(prod(mesh_shape), dtype=np.float32).reshape(mesh_shape) + y = np.arange(math.prod(mesh_shape), dtype=np.float32).reshape(mesh_shape) expected = np.broadcast_to(x - np.sum(y, 1, keepdims=True), mesh_shape) ans = f(x, y) @@ -673,7 +674,7 @@ class PythonPmapTest(jtu.JaxTestCase): return jvp(jnp.ones_like(x)) shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = np.cos(x) ans = splitjvp(x) @@ -687,7 +688,7 @@ class PythonPmapTest(jtu.JaxTestCase): return jnp.sin(x) shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x) expected = grad(lambda x: jnp.sum(f(x)))(x) @@ -699,7 +700,7 @@ class PythonPmapTest(jtu.JaxTestCase): return lax.psum(x, axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.) def testGradOfJvp(self): @@ -714,7 +715,7 @@ class PythonPmapTest(jtu.JaxTestCase): fun = lambda x: jnp.sum(jvp(jnp.sin, (x,), (jnp.ones_like(x),))[1]) shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) ans = grad(lambda x: jnp.sum(splitjvp(x)))(x) expected = grad(fun)(x) @@ -730,7 +731,7 @@ class PythonPmapTest(jtu.JaxTestCase): return tot * jnp.ones_like(x) # broadcast to map like pjit does shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) y = 4 + x ans = grad(lambda x, y: jnp.sum(g(x, y)))(x, y) expected = grad(lambda x, y: jnp.sum(g(x, y)))(x, y) @@ -764,7 +765,7 @@ class PythonPmapTest(jtu.JaxTestCase): return grad(lambda w: jnp.sum(g(w)))(x) shape = mesh_shape + (4,) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) ans = grad(lambda x: jnp.sum(test_fun(x)))(x) expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x) @@ -775,7 +776,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(f, axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) # test that we can pass in and out ShardedDeviceArrays y = f(x) @@ -840,7 +841,7 @@ class PythonPmapTest(jtu.JaxTestCase): if jax.device_count() < max(in_shape[:1] + out_shape[:1]): raise SkipTest("not enough devices") - x = np.arange(prod(in_shape)).reshape(in_shape) + x = np.arange(math.prod(in_shape)).reshape(in_shape) sharded_x = self.pmap(lambda x: x)(x) self.assertAllClose(sharded_x.reshape(out_shape), x.reshape(out_shape), check_dtypes=False) @@ -858,7 +859,7 @@ class PythonPmapTest(jtu.JaxTestCase): shape = (num_pairs, 2, 4) else: shape = 
(device_count, 1, 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) ans = f(x) expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1) @@ -874,7 +875,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(f, 'i') shape = (replicas, 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected_psum = 2. * replicas // 2 expected = x - expected_psum @@ -891,7 +892,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(f, 'i') shape = (replicas, 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) def sum_helper(a): return np.broadcast_to(a.sum(0, keepdims=True), (len(a), x.shape[1])) @@ -913,7 +914,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(f, 'i') shape = (replicas, 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) def sum_helper(a): return np.broadcast_to(a.sum(0, keepdims=True), (replicas // 2, x.shape[1])) @@ -938,7 +939,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(f, 'i') shape = (replicas, 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) ans = f(x) @@ -963,7 +964,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(f, 'i') shape = (replicas, 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) ans = f(x) @@ -999,7 +1000,7 @@ class PythonPmapTest(jtu.JaxTestCase): axis_index_groups=axis_index_groups) shape = (len(devices), 2 if axis_index_groups else jax.device_count()) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.) 
def testNestedPmapReplicaGroups(self): @@ -1014,7 +1015,7 @@ class PythonPmapTest(jtu.JaxTestCase): f3 = self.pmap(self.pmap(f, 'j'), 'i') shape = (2, replicas // 2, 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) def sum_helper_f1(a): return np.broadcast_to(a.sum(1, keepdims=True), (shape[0], shape[1] // 2, shape[2])) @@ -1030,7 +1031,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertAllClose(ans, expected) shape = (replicas // 2, 2, 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) def sum_helper_f3(a): return np.broadcast_to(a.sum(0, keepdims=True), (shape[0] // 2, shape[1], shape[2])) @@ -1197,7 +1198,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(lambda x: x - lax.pmax(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = x - np.max(x, 0) ans = f(x) @@ -1207,7 +1208,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(lambda x: x - lax.pmin(x, 'i'), axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = x - np.min(x, 0) ans = f(x) @@ -1291,7 +1292,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(self.pmap(lambda x: 3)) shape = (2, jax.device_count() // 2, 3) - x = jnp.arange(prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 ans = f(x) # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants @@ -1330,7 +1331,7 @@ class PythonPmapTest(jtu.JaxTestCase): shuffle(devices) f = self.pmap(self.pmap(lambda x: 3), devices=devices) shape = (2, len(devices) // 2, 3) - x = jnp.arange(prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 ans = f(x) # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants @@ -1352,7 +1353,7 @@ class PythonPmapTest(jtu.JaxTestCase): raise SkipTest("error test doesn't apply with disable_jit") f = self.pmap(self.pmap(lambda x: 3)) shape = (2, jax.device_count() // 2 + 1, 3) - x = jnp.arange(prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) self.assertRaisesRegex( ValueError, (r"compiling computation that requires \d+ logical devices, " @@ -1363,7 +1364,7 @@ class PythonPmapTest(jtu.JaxTestCase): # if jax.device_count() > 1: # f = pmap(pmap(lambda x: 3), devices=jax.devices()[:-1]) # shape = (2, jax.device_count() // 2, 3) - # x = jnp.arange(prod(shape)).reshape(shape) + # x = jnp.arange(math.prod(shape)).reshape(shape) # self.assertRaisesRegex( # ValueError, # (r"compiling computation that requires \d+ replicas, " @@ -1392,7 +1393,7 @@ class PythonPmapTest(jtu.JaxTestCase): return g(x) shape = (device_count, 1, 4) - x = jnp.arange(prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) a, b, c = f(x) self.assertEqual(a.shape, shape[:-1]) @@ -1525,7 +1526,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testPswapaxes(self): device_count = jax.device_count() shape = (device_count, 3, device_count, 5) - x = np.arange(prod(shape)).reshape(shape) + x = np.arange(math.prod(shape)).reshape(shape) ans = self.pmap(lambda x: lax.pswapaxes(x, 'i', 1), 
axis_name='i')(x) expected = np.swapaxes(x, 0, 2) @@ -1535,7 +1536,7 @@ class PythonPmapTest(jtu.JaxTestCase): def testGradOfPswapaxes(self): device_count = jax.device_count() shape = (device_count, 1, device_count) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) w = np.arange(device_count, dtype=np.float32) @partial(self.pmap, axis_name='i') @@ -1561,7 +1562,7 @@ class PythonPmapTest(jtu.JaxTestCase): if device_count % 2 != 0: raise SkipTest('test requires an even number of devices') shape = (device_count, device_count // 2) - x = np.arange(prod(shape)).reshape(shape) + x = np.arange(math.prod(shape)).reshape(shape) axis_index_groups = np.arange(device_count, dtype=np.int32) axis_index_groups = axis_index_groups.reshape((device_count // 2, 2)).T @@ -1582,7 +1583,7 @@ class PythonPmapTest(jtu.JaxTestCase): if device_count % 2 != 0: raise SkipTest('test requires an even number of devices') shape = (device_count, device_count // 2, 1) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) w = np.arange(device_count, dtype=np.float32) axis_index_groups = np.arange(device_count, dtype=np.int32) @@ -1610,7 +1611,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = lambda x: x - lax.psum(x, 'i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = x - np.sum(x, 0) ans = jit(self.pmap(f, 'i'))(x) @@ -1656,7 +1657,7 @@ class PythonPmapTest(jtu.JaxTestCase): f = self.pmap(f, axis_name='i') shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) y = f(x) self.assertIsInstance(y, jax.Array) @@ -2363,7 +2364,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase): f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=jax.devices()) shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) expected = x - np.sum(x, 0) ans = f(x) self.assertAllClose(ans, expected) @@ -2387,7 +2388,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase): def testNoDevicesError(self): f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=[]) shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) with self.assertRaisesRegex( ValueError, "'devices' argument to pmap must be non-empty, or None."): f(x) @@ -2508,7 +2509,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase): return jnp.sin(x) shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x) expected = grad(lambda x: jnp.sum(f(x)))(x) @@ -2519,7 +2520,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase): def f(x, y): return jnp.sin(x + y()) shape = (jax.device_count(), 4) - x = np.arange(prod(shape), dtype=np.float32).reshape(shape) + x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) y = lambda: 3. 
ans = f(x, y) @@ -2531,9 +2532,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase): def f(x, y): return jnp.sin(x + y) xshape = (2, jax.device_count(), 4) - x = np.arange(prod(xshape)).reshape(xshape) + x = np.arange(math.prod(xshape)).reshape(xshape) yshape = (2, 4, jax.device_count()) - y = np.arange(prod(yshape)).reshape(yshape) + y = np.arange(math.prod(yshape)).reshape(yshape) self.assertAllClose(f(x, y), jnp.sin(x.transpose((1, 0, 2)) + y.transpose((2, 0, 1)))) @@ -2544,11 +2545,11 @@ class PmapWithDevicesTest(jtu.JaxTestCase): fp = pmap(f, in_axes=(1, 2, None)) fv = vmap(f, in_axes=(1, 2, None)) xshape = (5, jax.device_count(), 7) - x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape) + x = np.arange(math.prod(xshape), dtype=np.float32).reshape(xshape) yshape = (5, 7, jax.device_count()) - y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape) + y = np.arange(math.prod(yshape), dtype=np.float32).reshape(yshape) zshape = (5, 7) - z = np.arange(prod(zshape), dtype=np.float32).reshape(zshape) + z = np.arange(math.prod(zshape), dtype=np.float32).reshape(zshape) dx, dy, dz = jax.grad(lambda args: fp(*args).sum())((x, y, z)) assert dx.shape == xshape @@ -2563,9 +2564,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase): def f(x, y): return jnp.sin(x + y), y * 2 xshape = (2, jax.device_count(), 4) - x = np.arange(prod(xshape)).reshape(xshape) + x = np.arange(math.prod(xshape)).reshape(xshape) yshape = (2, 4) - y = np.arange(prod(yshape)).reshape(yshape) + y = np.arange(math.prod(yshape)).reshape(yshape) self.assertAllClose(f(x, y), (jnp.sin(x.transpose((1, 0, 2)) + y).transpose((1, 2, 0)), y * 2)) @@ -2606,9 +2607,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase): return f xshape = (5, 7) - x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape) + x = np.arange(math.prod(xshape), dtype=np.float32).reshape(xshape) yshape = (5, jax.device_count(), 7) - y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape) + y = np.arange(math.prod(yshape), dtype=np.float32).reshape(yshape) self.assertAllClose(jax.grad(mk_case(pmap))(x, y), jax.grad(mk_case(vmap))(x, y)) @@ -2624,7 +2625,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase): if jax.device_count() < shape[0]: raise SkipTest(f"requires {shape[0]} devices") - x = jnp.arange(prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) sharded_x = pmap(lambda x: x)(x) num_threads = 10 @@ -2650,7 +2651,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase): if jax.device_count() < shape[0]: raise SkipTest(f"requires {shape[0]} devices") - x = jnp.arange(prod(shape)).reshape(shape) + x = jnp.arange(math.prod(shape)).reshape(shape) sharded_x = pmap(lambda x: x)(x) self.assertIsNone(sharded_x._npy_value) @@ -2941,7 +2942,7 @@ class ShardArgsTest(jtu.JaxTestCase): nshards = len(indices) if jax.device_count() < nshards: raise SkipTest - x = np.arange(prod(shape)).reshape(shape) + x = np.arange(math.prod(shape)).reshape(shape) arg = make_arg(x) bufs = pxla.shard_args(jax.devices()[:nshards], [indices], [arg]) self.assertEqual(len(bufs), 1) diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py index a659f0f9b..3737bf718 100644 --- a/tests/scipy_ndimage_test.py +++ b/tests/scipy_ndimage_test.py @@ -14,6 +14,7 @@ from functools import partial +import math import numpy as np @@ -24,7 +25,6 @@ from jax import grad from jax._src import test_util as jtu from jax import dtypes from jax.scipy import ndimage as lsp_ndimage -from jax._src.util import prod from jax.config import config config.parse_flags_with_absl() @@ 
-80,7 +80,7 @@ class NdimageTest(jtu.JaxTestCase): mode, cval, impl, round_, rng_factory): def args_maker(): - x = np.arange(prod(shape), dtype=dtype).reshape(shape) + x = np.arange(math.prod(shape), dtype=dtype).reshape(shape) coords = [(size - 1) * rng(coords_shape, coords_dtype) for size in shape] if round_: coords = [c.round().astype(int) for c in coords] diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index f9111d0fb..1459d32ca 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from functools import partial import itertools as it +import math import os from types import SimpleNamespace from typing import (Any, Sequence, Set, Iterable, Iterator, NamedTuple, @@ -31,7 +33,7 @@ from jax.sharding import PartitionSpec as P from jax._src import core from jax._src import test_util as jtu from jax._src import xla_bridge -from jax._src.util import safe_zip, safe_map, prod, partition_list, merge_lists +from jax._src.util import safe_zip, safe_map, partition_list, merge_lists import jax.numpy as jnp from jax.experimental.shard_map import shard_map @@ -552,13 +554,13 @@ def shmap_reference( def make_indexer(mesh: Mesh, spec: P, x: Any ) -> Callable[[Tuple[int, ...]], Tuple[slice, ...]]: - block_shape = [d // prod(mesh.shape[ax] for ax in (elt or ())) + block_shape = [d // math.prod(mesh.shape[ax] for ax in (elt or ())) for d, elt in zip(x.shape, spec)] def indexer(idx): starts = [0 if el is None else idx[list(mesh.shape).index(el)] if type(el) is not tuple else sum(idx[list(mesh.shape).index(el[i])] - * prod(mesh.shape[e] for e in el[i+1:]) for i in range(len(el))) + * math.prod(mesh.shape[e] for e in el[i+1:]) for i in range(len(el))) for el in spec] return tuple(slice(start * size, (start + 1) * size) for start, size in zip(starts, block_shape)) @@ -647,7 +649,7 @@ def make_in_spec(mesh: Mesh, in_type_base: ShapeDtypeDuck) -> Chooser: return new_type, partition_spec def dilate(mesh: Mesh, spec: P, shape: ShapeDtypeDuck) -> ShapeDtypeDuck: - new_shape = tuple(d * prod(mesh.shape[ax] for ax in (elt or ())) + new_shape = tuple(d * math.prod(mesh.shape[ax] for ax in (elt or ())) for d, elt in zip(shape.shape, spec)) return jax.ShapeDtypeStruct(new_shape, shape.dtype) diff --git a/tests/xmap_test.py b/tests/xmap_test.py index 97d1c5379..64031b535 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -14,6 +14,7 @@ import functools import itertools as it +import math import os import re from itertools import product, permutations @@ -45,7 +46,7 @@ from jax._src import config as jax_config from jax._src.nn import initializers as nn_initializers from jax._src import xla_bridge from jax._src.lib import xla_client -from jax._src.util import unzip2, prod, safe_zip +from jax._src.util import unzip2, safe_zip from jax._src.lax import parallel as lax_parallel from jax._src.lax.parallel import pgather from jax.interpreters import batching, pxla @@ -80,7 +81,7 @@ def tearDownModule(): def create_array(global_shape, global_mesh, mesh_axes, global_data=None): if global_data is None: global_data = np.arange( - prod(global_shape), dtype=np.float32).reshape(global_shape) + math.prod(global_shape), dtype=np.float32).reshape(global_shape) sharding = NamedSharding(global_mesh, mesh_axes) @@ -1118,7 +1119,7 @@ class XMapGDATest(XMapTestCase): global_input_shape = (8, 2) mesh_axes = P('x', 
'y') input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return input_data[index] @@ -1146,7 +1147,7 @@ class XMapGDATest(XMapTestCase): global_input_shape = (8, 2) mesh_axes = P('x') input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return input_data[index] @@ -1185,7 +1186,7 @@ class XMapGDATest(XMapTestCase): global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return input_data[index] @@ -1224,7 +1225,7 @@ class XMapGDATest(XMapTestCase): global_input_shape = (8, 2) mesh_axes = P('x', 'y') input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) def cb(index): return input_data[index] @@ -1248,7 +1249,7 @@ class XMapGDATest(XMapTestCase): global_input_shape = (8, 2) mesh_axes = P('x', 'y') input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) + math.prod(global_input_shape)).reshape(global_input_shape) gda_obj = global_device_array.GlobalDeviceArray.from_callback( global_input_shape, global_mesh, mesh_axes, lambda idx: input_data[idx]) with jax_config.parallel_functions_output_gda(True):
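For reference, a minimal, self-contained sketch (illustrative only, not part of the patch) of the two patterns the updated call sites rely on: sizing flat test data with the product of a shape tuple, and the per-dimension sharding factor computed by the make_indexer/dilate helpers in tests/shard_map_test.py. The standard library's math.prod accepts any iterable, including generators, and returns 1 for an empty one, which makes it a drop-in for the previous jax._src.util.prod usage here. The mesh_shape dict and spec tuple below are hypothetical stand-ins for the Mesh and PartitionSpec objects the real tests construct.

import math
import numpy as np

# Pattern used throughout the updated tests: flat data sized by the product
# of a shape tuple, then reshaped.
shape = (8, 2)
data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
assert data.size == 16

# math.prod accepts generators and returns 1 for an empty iterable, which the
# generator-based call sites rely on.
assert math.prod(n for n in (4, 2)) == 8
assert math.prod(()) == 1

# Sketch of the per-dimension sharding factor from the make_indexer / dilate
# helpers in tests/shard_map_test.py.  The dict and tuple here are stand-ins
# (assumed for the example) for a Mesh and PartitionSpec.
mesh_shape = {'x': 4, 'y': 2}     # a 4x2 device mesh
spec = (('x',), None)             # dim 0 sharded over 'x', dim 1 replicated
global_shape = (8, 2)

block_shape = [d // math.prod(mesh_shape[ax] for ax in (elt or ()))
               for d, elt in zip(global_shape, spec)]
assert block_shape == [2, 2]      # each shard sees a (2, 2) block

# dilate() goes the other way: a per-shard block dilates back to the global shape.
dilated = tuple(d * math.prod(mesh_shape[ax] for ax in (elt or ()))
                for d, elt in zip(block_shape, spec))
assert dilated == (8, 2)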