Replace jax._src.util.prod with math.prod.

math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
Peter Hawkins 2023-02-28 12:40:30 -08:00 committed by jax authors
parent 4f48f94649
commit 8fb1fd318d
40 changed files with 311 additions and 283 deletions
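
For context, a minimal doctest-style sketch of the equivalence this change relies on, using illustrative values rather than anything taken from the diffs below: the hand-rolled prod helper removed from jax._src.util behaves the same as the standard-library math.prod, including on generator expressions.

>>> import math
>>> def prod(xs):  # the helper removed from jax._src.util in this commit
...   out = 1
...   for x in xs:
...     out *= x
...   return out
>>> shape = (8, 8)  # illustrative shape
>>> prod(shape) == math.prod(shape) == 64
True
>>> math.prod(size for size in (2, 4))  # generators work too, as with the old helper
8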

View File

@ -18,13 +18,15 @@ python3 pmap_benchmark.py
To make it run faster, set env var TARGET_TOTAL_SECS to a low number (e.g. 2).
"""
import math
from absl import app
import jax
from jax import numpy as jnp
from jax import pmap
from jax.config import config
from jax._src.util import prod
from benchmarks import benchmark
@ -118,7 +120,7 @@ def sharded_device_array_indexing_benchmark():
nshards = min(8, jax.local_device_count())
shape = (nshards, 8, 8)
def benchmark_fn():
arr = pmap(lambda x: x)(jnp.arange(prod(shape)).reshape(shape))
arr = pmap(lambda x: x)(jnp.arange(math.prod(shape)).reshape(shape))
indices = indices_fn()
for idx in indices:
arr[idx]

View File

@ -26,6 +26,7 @@ import collections
import functools
from functools import partial
import inspect
import math
import typing
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union, overload)
@ -65,7 +66,7 @@ from jax._src.lib import pmap_lib
from jax._src.sharding import PmapSharding
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix, _generate_key_paths
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
from jax._src.util import (unzip2, curry, safe_map, safe_zip, split_list,
wrap_name, cache, wraps, HashableFunction,
weakref_lru_cache)
@ -1034,7 +1035,7 @@ def xla_computation(fun: Callable,
if axis_env is None:
return xla.AxisEnv(nreps, (), ())
else:
nreps = nreps * prod(size for name, size in axis_env)
nreps = nreps * math.prod(size for name, size in axis_env)
names, sizes = unzip2(axis_env)
return xla.AxisEnv(nreps, names, sizes)
@ -3232,7 +3233,7 @@ class ShapeDtypeStruct:
self.sharding = sharding
self.named_shape = {} if named_shape is None else dict(named_shape)
size = property(lambda self: prod(self.shape))
size = property(lambda self: math.prod(self.shape))
ndim = property(lambda self: len(self.shape))
def __len__(self):

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import math
import operator as op
import numpy as np
import functools
@ -26,7 +27,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src.config import config
from jax._src.util import prod, safe_zip, use_cpp_class, use_cpp_method
from jax._src.util import safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import api
@ -208,7 +209,7 @@ class ArrayImpl(basearray.Array):
@property
def size(self):
return prod(self.shape)
return math.prod(self.shape)
@property
def sharding(self):
@ -552,12 +553,13 @@ def make_array_from_callback(
Example:
>>> import math
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> input_shape = (8, 8)
>>> global_input_data = np.arange(prod(input_shape)).reshape(input_shape)
>>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
...
@ -601,6 +603,7 @@ def make_array_from_single_device_arrays(
Example:
>>> import math
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
@ -608,7 +611,7 @@ def make_array_from_single_device_arrays(
>>> shape = (8, 8)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
>>> inp_data = np.arange(prod(shape)).reshape(shape)
>>> inp_data = np.arange(math.prod(shape)).reshape(shape)
...
>>> arrays = [
... jax.device_put(inp_data[index], d)

View File

@ -22,6 +22,7 @@ from functools import partial, partialmethod, total_ordering
import gc
import inspect
import itertools as it
import math
import operator
from operator import attrgetter
import threading
@ -44,7 +45,7 @@ from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
tuple_delete, as_hashable_function,
HashableFunction, HashableWrapper, weakref_lru_cache)
import jax._src.pretty_printer as pp
@ -1481,7 +1482,7 @@ class ShapedArray(UnshapedArray):
return ShapedArray(shape, dtype, weak_type, named_shape)
ndim = property(lambda self: len(self.shape))
size = property(lambda self: prod(self.shape))
size = property(lambda self: math.prod(self.shape))
broadcast: ClassVar[Optional[aval_method]] = None
transpose: ClassVar[Optional[aval_method]] = None
@ -1628,7 +1629,7 @@ class DShapedArray(UnshapedArray):
self.weak_type = weak_type
ndim = property(lambda self: len(self.shape))
size = property(lambda self: prod(self.shape))
size = property(lambda self: math.prod(self.shape))
def str_short(self, short_dtypes=False) -> str:
del short_dtypes # ignored

View File

@ -15,6 +15,7 @@
from collections import Counter
import dataclasses
import functools
import math
import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple
@ -26,7 +27,7 @@ from jax._src.lib import xla_client as xc
from jax._src.config import config
from jax._src.interpreters import pxla
from jax.interpreters import xla, mlir
from jax._src.util import prod, safe_zip
from jax._src.util import safe_zip
from jax._src.interpreters.pxla import PartitionSpec
Shape = Tuple[int, ...]
@ -87,7 +88,7 @@ def _get_shard_indices_replica_ids_uncached(
out[device] = (index, replica_id)
shard_shape = get_shard_shape(global_shape, global_mesh, mesh_axes)
expected_unique_shards = prod(
expected_unique_shards = math.prod(
[g // s for g, s in safe_zip(global_shape, shard_shape) if g != 0 or s != 0])
if expected_unique_shards != unique_shards:
raise RuntimeError(
@ -104,7 +105,7 @@ def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
if not mesh_axis:
chunk_size.append(size)
elif isinstance(mesh_axis, tuple):
m = prod([global_mesh.shape[ma] for ma in mesh_axis])
m = math.prod([global_mesh.shape[ma] for ma in mesh_axis])
chunk_size.append(size // m)
else:
chunk_size.append(size // global_mesh.shape[mesh_axis])
@ -323,7 +324,7 @@ class GlobalDeviceArray:
@property
def size(self):
return prod(self.shape)
return math.prod(self.shape)
@property
def mesh(self):
@ -455,7 +456,7 @@ class GlobalDeviceArray:
>>> global_input_shape = (8, 8)
>>> mesh_axes = P('x', 'y')
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def cb(index):
... return global_input_data[index]
@ -505,7 +506,7 @@ class GlobalDeviceArray:
>>> global_input_shape = (8, 2)
>>> mesh_axes = P('x')
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def batched_cb(indices):
... assert len(indices) == len(global_mesh.local_devices)
@ -555,7 +556,7 @@ class GlobalDeviceArray:
>>> global_input_shape = (8, 2)
>>> mesh_axes = P(('x', 'y'))
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def cb(cb_inp):
... dbs = []

View File

@ -37,6 +37,7 @@ import dataclasses
from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
import operator as op
import sys
import threading
@ -78,7 +79,7 @@ from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log,
unzip2, HashableFunction)
@ -284,7 +285,8 @@ def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
if not has_unstacked:
op_sharding_proto = sharding_spec_sharding_proto(self)
return _op_sharding_to_numpy_indices(
op_sharding_proto, shape, prod(self.mesh_shape)).reshape(self.mesh_shape)
op_sharding_proto, shape, math.prod(self.mesh_shape)
).reshape(self.mesh_shape)
axis_indices: List[Sequence[Index]] = []
shard_indices_shape = []
@ -312,7 +314,7 @@ def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
# with each dimension having size equal to the number of shards across the corresponding
# logical array dimension, and each element containing the multi-dimensional index that
# is used to extract the corresponding shard of the logical array.
shard_indices = np.empty([prod(shard_indices_shape)], dtype=np.object_)
shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
for i, idxs in enumerate(it.product(*axis_indices)):
shard_indices[i] = idxs
shard_indices = shard_indices.reshape(shard_indices_shape)
@ -766,7 +768,7 @@ class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
@property
def size(self):
return prod(self.aval.shape)
return math.prod(self.aval.shape)
@property
def ndim(self):
@ -2252,7 +2254,7 @@ core.axis_substitution_rules[xla_pmap_p] = _pmap_axis_subst
def _unravel_index_hlo(axis_env):
div = mlir.ir_constant(
np.array(axis_env.nreps // util.prod(axis_env.sizes), np.uint32))
np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32))
mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
return hlo.RemOp(
hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result
@ -4035,7 +4037,7 @@ class DynamicAxisEnv(list):
@property
def nreps(self):
return prod(frame.hard_size for frame in self)
return math.prod(frame.hard_size for frame in self)
class _ThreadLocalState(threading.local):
def __init__(self):

View File

@ -19,6 +19,7 @@ import dataclasses
import functools
from functools import partial
import itertools as it
import math
import operator
import re
from typing import (Any, Callable, Dict, NamedTuple, Optional, Protocol,
@ -37,7 +38,7 @@ from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.interpreters import ad
from jax._src.util import (prod, safe_zip, safe_map, partition_list)
from jax._src.util import (safe_zip, safe_map, partition_list)
from jax._src.typing import Shape
@ -311,7 +312,7 @@ def axis_groups(axis_env: AxisEnv, name) -> Tuple[Tuple[int, ...]]:
if not isinstance(name, (list, tuple)):
name = (name,)
mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes))
trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes))
assert not ragged
mesh_spec = axis_env.sizes + (trailing_size,)
return _axis_groups(mesh_spec, mesh_axes)
@ -326,10 +327,10 @@ def _axis_groups(mesh_spec, mesh_axes):
Returns:
A tuple of replica groups (i.e. tuples containing replica ids).
"""
iota = np.arange(prod(mesh_spec)).reshape(mesh_spec)
iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
groups = np.reshape(
np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
(prod(np.take(mesh_spec, mesh_axes)), -1))
(math.prod(np.take(mesh_spec, mesh_axes)), -1))
return tuple(unsafe_map(tuple, groups.T))

View File

@ -14,6 +14,7 @@
from functools import partial
import math
from typing import Union, Sequence
import numpy as np
@ -30,7 +31,6 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
from jax._src.util import prod
__all__ = [
"fft",
@ -150,7 +150,7 @@ def _irfft_transpose(t, fft_lengths):
full(2.0, shape=(n - 2 + is_odd,)),
full(1.0, shape=(1 - is_odd,))],
dimension=0)
scale = 1 / prod(fft_lengths)
scale = 1 / math.prod(fft_lengths)
out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
# Use JAX's convention for complex gradients

View File

@ -17,6 +17,7 @@ import enum
import functools
from functools import partial
import itertools
import math
import operator
from typing import (Any, Callable, Optional, Sequence, Tuple, List,
TypeVar, Union, cast as type_cast, overload)
@ -70,7 +71,7 @@ from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import PmapSharding
from jax._src.typing import Array, ArrayLike, DTypeLike, Shape
from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis,
from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis,
split_list)
xb = xla_bridge
@ -1422,7 +1423,7 @@ def collapse(operand: Array, start_dimension: int,
collapsed (raveled) into a single dimension.
"""
lo, hi = start_dimension, stop_dimension
size = prod(operand.shape[lo:hi])
size = math.prod(operand.shape[lo:hi])
new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:]
return reshape(operand, new_shape)

View File

@ -15,6 +15,7 @@
import inspect
import functools
from functools import partial
import math
from typing import cast, Any, Callable, List, Literal, Optional, Tuple, TypeVar, Union, overload
import warnings
@ -50,7 +51,6 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike
from jax._src.util import prod
xops = xla_client.ops
@ -1313,7 +1313,7 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (geqrf); b/261671778")
a_aval, taus_aval = ctx.avals_out
*batch_dims, m, n = a_aval.shape
batch = prod(batch_dims)
batch = math.prod(batch_dims)
if batch == 0 or m == 0 or n == 0:
return mlir.full_like_aval(ctx, 0, a_aval), mlir.full_like_aval(ctx, 0, taus_aval)

View File

@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional, Sequence, Tuple, Union, cast as type_cast
import jax
from jax._src.numpy import lax_numpy as jnp
from jax._src.util import prod
from jax._src.lax import lax
from jax._src.lax import convolution
@ -92,7 +92,7 @@ def conv_general_dilated_patches(
lhs_spec, rhs_spec, out_spec = dimension_numbers
spatial_size = prod(filter_shape)
spatial_size = math.prod(filter_shape)
n_channels = lhs_array.shape[lhs_spec[1]]
# Move separate `lhs` spatial locations into separate `rhs` channels.

View File

@ -17,6 +17,7 @@ Parallelization primitives.
from functools import partial
import itertools
import math
import string
from typing import Sequence, Union
import warnings
@ -40,7 +41,7 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
from jax._src.util import (
unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis)
unzip2, canonicalize_axis, safe_map, safe_zip, moveaxis)
unsafe_map, map = map, safe_map # type: ignore
@ -828,7 +829,7 @@ def psum_bind(*args, axes, axis_index_groups):
assert not pos_axes
size = len(axis_index_groups[0])
else:
size = prod([core.axis_frame(name).size for name in named_axes]) # type: ignore
size = math.prod([core.axis_frame(name).size for name in named_axes]) # type: ignore
return tuple(lax._const(x, size) * pos_reduce(x) for x in args)
return core.AxisPrimitive.bind(
psum_p, *args, axes=axes, axis_index_groups=axis_index_groups)
@ -1563,9 +1564,12 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
'`axis_index` translation rule does not support multiple axis names.')
axis_name, = axis_name
axis_pos = list(axis_env.names).index(axis_name)
nreplicas = axis_env.nreps // prod(axis_env.sizes)
div = mlir.ir_constant(np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
dtype=np.uint32))
nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
div = mlir.ir_constant(
np.array(
nreplicas * math.prod(axis_env.sizes[axis_pos + 1 :]), dtype=np.uint32
)
)
mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(axis_context,

View File

@ -17,7 +17,7 @@ Common neural network layer initializers, consistent with definitions
used in Keras and Sonnet.
"""
import math
from typing import Any, Literal, Protocol, Sequence, Tuple, Union
import numpy as np
@ -27,7 +27,6 @@ from jax import lax
from jax import random
from jax._src import core
from jax._src import dtypes
from jax._src.util import prod
KeyArray = random.KeyArray
Array = Any
@ -549,7 +548,7 @@ def orthogonal(scale: RealNumeric = 1.0,
dtype = dtypes.canonicalize_dtype(dtype)
if len(shape) < 2:
raise ValueError("orthogonal initializer requires at least a 2D shape")
n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
n_rows, n_cols = math.prod(shape) // shape[column_axis], shape[column_axis]
matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
A = random.normal(key, matrix_shape, dtype)
Q, R = jnp.linalg.qr(A)

View File

@ -27,6 +27,7 @@ rules for the underlying :code:`lax` primitives.
import builtins
import collections
from functools import partial
import math
import operator
import types
from typing import (
@ -82,7 +83,7 @@ from jax._src.numpy.util import ( # noqa: F401
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip,
from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl
@ -534,7 +535,7 @@ def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
dedges = [diff(bin_edges) for bin_edges in bin_edges_by_dim]
xy = ravel_multi_index(tuple(bin_idx_by_dim), nbins, mode='clip')
hist = bincount(xy, weights, length=_prod(nbins))
hist = bincount(xy, weights, length=math.prod(nbins))
hist = reshape(hist, nbins)
core = D*(slice(1, -1),)
hist = hist[core]
@ -914,7 +915,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array:
arr = ravel(a)
new_size = _prod(new_shape)
new_size = math.prod(new_shape)
if arr.size == 0 or new_size == 0:
return zeros_like(arr, shape=new_shape)

View File

@ -14,6 +14,7 @@
import builtins
from functools import partial
import math
import operator
from typing import overload, Any, Callable, Literal, Optional, Sequence, Tuple, Union
import warnings
@ -31,7 +32,7 @@ from jax._src.numpy.util import (
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis, prod as _prod)
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
_all = builtins.all
@ -371,7 +372,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = Non
if axis is None:
weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
elif isinstance(axis, tuple):
weights_sum = lax.full_like(avg, _prod(core.dimension_as_value(a.shape[d]) for d in axis))
weights_sum = lax.full_like(avg, math.prod(core.dimension_as_value(a.shape[d]) for d in axis))
else:
weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index]
else:

View File

@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
import math
import operator
from textwrap import dedent as _dedent
from typing import Optional, Tuple, Union
@ -31,7 +32,6 @@ from jax._src.numpy.lax_numpy import (
sort, where, zeros)
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.typing import Array, ArrayLike
from jax._src.util import prod as _prod
_lax_const = lax_internal._const
@ -230,11 +230,11 @@ def _unique_sorted_mask(ar: Array, axis: int) -> Tuple[Array, Array, Array]:
# is fixed to match numpy.
aux = where(isnan(aux), _lax_const(aux, np.nan), aux)
size, *out_shape = aux.shape
if _prod(out_shape) == 0:
if math.prod(out_shape) == 0:
size = 1
perm = zeros(1, dtype=int)
else:
perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
perm = lexsort(aux.reshape(size, math.prod(out_shape)).T[::-1])
aux = aux[perm]
if aux.size:
if dtypes.issubdtype(aux.dtype, np.inexact):

View File

@ -15,6 +15,7 @@
import abc
from functools import partial, reduce
import math
import operator as op
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence
@ -44,7 +45,7 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
from jax._src.sharding import (
NamedSharding, PmapSharding, GSPMDSharding)
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
from jax._src.util import canonicalize_axis, safe_map, safe_zip
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@ -1141,7 +1142,7 @@ def threefry_random_bits(key: jax.Array, bit_width, shape):
return _threefry_random_bits_original(key, bit_width, shape)
def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape):
if all(core.is_constant_dim(d) for d in shape) and prod(shape) > 2 ** 64:
if all(core.is_constant_dim(d) for d in shape) and math.prod(shape) > 2 ** 64:
raise NotImplementedError('random bits array of size exceeding 2 ** 64')
k1, k2 = key
@ -1160,7 +1161,7 @@ def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape):
@partial(jit, static_argnums=(1, 2), inline=True)
def _threefry_random_bits_original(key: jax.Array, bit_width, shape):
size = prod(shape)
size = math.prod(shape)
# Compute ceil(bit_width * size / 32) in a way that is friendly to shape
# polymorphism
max_count, r = divmod(bit_width * size, 32)

View File

@ -14,8 +14,9 @@
from functools import partial
from typing import Optional, Sequence, Union
import math
from operator import index
from typing import Optional, Sequence, Union
import warnings
import numpy as np
@ -40,7 +41,7 @@ from jax._src.numpy.lax_numpy import (
_arraylike, _check_arraylike, _convert_and_clip_integer,
_promote_dtypes_inexact)
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import prod, canonicalize_axis
from jax._src.util import canonicalize_axis
RealArray = ArrayLike
@ -507,7 +508,7 @@ def choice(key: KeyArray,
else:
axis = canonicalize_axis(axis, arr.ndim)
n_inputs = arr.shape[axis]
n_draws = prod(shape)
n_draws = math.prod(shape)
if n_draws == 0:
return jnp.zeros(shape, dtype=arr.dtype)
if n_inputs <= 0:
@ -1025,7 +1026,7 @@ def _gamma_grad(sample, a, *, log_space):
def _gamma_impl(key, a, *, log_space, use_vmap=False):
# split key to match the shape of a
a_shape = jnp.shape(a)
split_count = prod(a_shape[key.ndim:])
split_count = math.prod(a_shape[key.ndim:])
keys = key.flatten()
keys = vmap(_split, in_axes=(0, None))(keys, split_count)
keys = keys.flatten()

View File

@ -14,6 +14,7 @@
from collections import namedtuple
from functools import partial
import math
from typing import Optional, Tuple
import jax
@ -24,7 +25,7 @@ from jax._src.api import vmap
from jax._src.numpy.lax_numpy import _check_arraylike
from jax._src.numpy.util import _wraps
from jax._src.typing import ArrayLike, Array
from jax._src.util import canonicalize_axis, prod
from jax._src.util import canonicalize_axis
import scipy
@ -80,7 +81,7 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k
axis = canonicalize_axis(axis, x.ndim)
x = jnp.moveaxis(x, axis, 0)
x = x.reshape(x.shape[0], prod(x.shape[1:]))
x = x.reshape(x.shape[0], math.prod(x.shape[1:]))
vals, counts = vmap(_mode_helper, in_axes=1)(x)
return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape))

View File

@ -14,6 +14,7 @@
"""Module for state types."""
from __future__ import annotations
import math
from typing import Any, Generic, List, Sequence, Set, Tuple, TypeVar, Union
from jax._src import core
@ -21,7 +22,7 @@ from jax._src import effects
from jax._src import pretty_printer as pp
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.util import safe_map, safe_zip, prod
from jax._src.util import safe_map, safe_zip
xc = xla_client
xb = xla_bridge
@ -89,7 +90,7 @@ class AbstractRef(core.AbstractValue, Generic[Aval]):
return AbstractRef(self.inner_aval.join(other.inner_aval))
ndim = property(lambda self: len(self.shape))
size = property(lambda self: prod(self.shape))
size = property(lambda self: math.prod(self.shape))
@property
def shape(self):

View File

@ -17,6 +17,7 @@ import inspect
import io
import functools
from functools import partial
import math
import re
import os
import tempfile
@ -46,7 +47,7 @@ from jax._src.config import (flags, bool_env, config,
raise_persistent_cache_errors,
persistent_cache_min_compile_time_secs)
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
from jax._src.util import prod, unzip2
from jax._src.util import unzip2
from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)
@ -668,7 +669,7 @@ def rand_int(rng, low=0, high=None):
def rand_unique_int(rng, high=None):
def fn(shape, dtype):
return rng.choice(np.arange(high or prod(shape), dtype=dtype),
return rng.choice(np.arange(high or math.prod(shape), dtype=dtype),
size=shape, replace=False)
return fn
@ -755,7 +756,7 @@ def sample_product_testcases(*args, **kw):
"""Non-decorator form of sample_product."""
args = [list(arg) for arg in args]
kw = [(k, list(v)) for k, v in kw.items()]
n = prod(len(a) for a in args) * prod(len(v) for _, v in kw)
n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
testcases = []
for i in _choice(n, min(n, FLAGS.jax_num_generated_cases)):
testcase = {}
@ -1054,7 +1055,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
"""Test utility for setting up meshes given mesh data from `schedules`."""
# This is similar to the `with_mesh` function above, but isn't a decorator.
axis_names, shape = unzip2(named_shape)
size = prod(shape)
size = math.prod(shape)
local_devices = list(api.local_devices())
if len(local_devices) < size:
raise unittest.SkipTest(f"Test requires {size} local devices")
@ -1094,7 +1095,7 @@ def restore_spmd_manual_lowering_flag():
config.update('experimental_xmap_spmd_lowering_manual', old_spmd_manual_lowering_flag)
def create_global_mesh(mesh_shape, axis_names):
size = prod(mesh_shape)
size = math.prod(mesh_shape)
if len(api.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} global devices.")
devices = sorted(api.devices(), key=lambda d: d.id)

View File

@ -269,12 +269,6 @@ def weakref_lru_cache(call: Callable, maxsize=2048):
"""
return xc.weakref_lru_cache(config._trace_context, call, maxsize)
def prod(xs):
out = 1
for x in xs:
out *= x
return out
class Unhashable:
__slots__ = ["val"]

View File

@ -83,6 +83,7 @@ TODO:
- Support RNNs other than LSTM.
"""
from functools import partial
import math
from typing import Any, Dict, List, Tuple
import jax
@ -92,7 +93,6 @@ from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.custom_derivatives import custom_vjp
from jax._src.typing import Array, Shape
from jax._src.util import prod
import jax.numpy as jnp
try:
from jax._src.lib import gpu_rnn
@ -161,7 +161,7 @@ def get_num_params_in_lstm(input_size: int, hidden_size: int, num_layers: int,
"""Get param count in LSTM."""
layer_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers,
bidirectional)
param_count = sum([prod(shape) for shape in layer_shapes])
param_count = sum([math.prod(shape) for shape in layer_shapes])
return param_count
@ -201,7 +201,7 @@ def unpack_lstm_weights(
for w_kind in [W_ih, W_hh]:
shape = flat_shapes[flat_shapes_offset]
flat_shapes_offset += 1
num_elems = prod(shape)
num_elems = math.prod(shape)
w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape)
w_offsets += num_elems
@ -211,7 +211,7 @@ def unpack_lstm_weights(
for w_kind in [b_ih, b_hh]:
shape = flat_shapes[flat_shapes_offset]
flat_shapes_offset += 1
num_elems = prod(shape)
num_elems = math.prod(shape)
w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape)
w_offsets += num_elems
return W_ih, W_hh, b_ih, b_hh

View File

@ -17,6 +17,7 @@ import enum
from functools import partial, lru_cache
import inspect
import itertools as it
import math
import operator as op
from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence,
Set, Tuple, TypeVar, Union, Protocol)
@ -37,7 +38,7 @@ from jax._src import util
from jax._src.core import Tracer
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, fft, linalg)
from jax._src.util import (prod, HashableFunction, unzip2, as_hashable_function,
from jax._src.util import (HashableFunction, unzip2, as_hashable_function,
memoize, partition_list, merge_lists)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax.interpreters import batching
@ -150,7 +151,7 @@ def _check_specs_vs_args(
msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
raise ValueError(msg)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns)
fail = [a if any(a.shape[d] % math.prod(mesh.shape[n] for n in ns)
for d, ns in names.items()) else no_fail
for a, names in zip(in_avals, in_names_flat)]
if any(f is not no_fail for f in fail):
@ -201,10 +202,10 @@ def _spec_divisibility_error(
f"parameter '{list(ba.arguments.keys())[arg_key.key]}',")
names = _canonicalize_spec(spec)
for d, ns in names.items():
if aval.shape[d] % prod(mesh.shape[n] for n in ns):
if aval.shape[d] % math.prod(mesh.shape[n] for n in ns):
axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
total = 'total ' if len(ns) > 1 else ''
sz = prod(mesh.shape[n] for n in ns)
sz = math.prod(mesh.shape[n] for n in ns)
msgs.append(
f"args{fail_key.pprint()} of shape {aval.str_short()}{extra} "
f"corresponds to in_specs{spec_key.pprint()} of value {spec}, "
@ -382,7 +383,7 @@ pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):
return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
return aval.update(tuple(sz // math.prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)))
else:
raise NotImplementedError # TODO(mattjj): add table with handlers
@ -390,7 +391,7 @@ def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):
return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
return aval.update(tuple(sz * math.prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)),
named_shape={k: v for k, v in aval.named_shape.items()
if k not in mesh.shape})
@ -880,7 +881,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
check_rep):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
else mb_div(x, math.prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
for ns, x in zip(out_names, out_cts)]
args = [x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))

View File

@ -14,11 +14,11 @@
"""Base JAX Sparse object."""
import abc
import math
from typing import Sequence, Tuple
import jax
from jax._src import core
from jax._src import util
from jax._src.typing import Array
@ -37,7 +37,7 @@ class JAXSparse(abc.ABC):
@property
def size(self) -> int:
return util.prod(self.shape)
return math.prod(self.shape)
@property
def ndim(self) -> int:

View File

@ -171,7 +171,6 @@ from jax._src.lax.lax import (
population_count_p as population_count_p,
pow as pow,
pow_p as pow_p,
prod as prod,
random_gamma_grad as random_gamma_grad,
random_gamma_grad_p as random_gamma_grad_p,
real as real,
@ -369,3 +368,5 @@ from jax.lax import linalg as linalg
from jax._src.pjit import with_sharding_constraint
from jax._src.dispatch import device_put_p
from math import prod # TODO(phawkins): remove this accidental export

View File

@ -13,9 +13,8 @@
# limitations under the License.
import functools
from functools import partial
import operator
import math
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
@ -60,8 +59,6 @@ def _real_type(dtype):
"""Returns the real equivalent of 'dtype'."""
return np.finfo(dtype).dtype
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a):
"""LU decomposition."""
@ -71,7 +68,7 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a):
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
batch = math.prod(batch_dims)
if batch > 1 and m == n and m // batch <= 128:
lwork, opaque = gpu_blas.build_getrf_batched_descriptor(
@ -118,7 +115,7 @@ def _geqrf_hlo(platform, gpu_solver, dtype, a):
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
batch = math.prod(batch_dims)
lwork, opaque = gpu_solver.build_geqrf_descriptor(
np.dtype(dtype), batch, m, n)
@ -156,7 +153,7 @@ def _geqrf_batched_hlo(platform, gpu_blas, dtype, a):
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
batch = math.prod(batch_dims)
lwork, opaque = gpu_blas.build_geqrf_batched_descriptor(
np.dtype(dtype), batch, m, n)
@ -220,7 +217,7 @@ def _orgqr_hlo(platform, gpu_solver, dtype, a, tau):
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
batch = math.prod(batch_dims)
tau_dims = ir.RankedTensorType(tau.type).shape
assert tau_dims[:-1] == dims[:-2]
@ -266,7 +263,7 @@ def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
assert m == n
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
batch = math.prod(batch_dims)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
if have_jacobi_solver and n <= 32:
@ -317,7 +314,7 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = _prod(batch_dims)
b = math.prod(batch_dims)
if ir.ComplexType.isinstance(a_type.element_type):
singular_vals_type = ir.ComplexType(a_type.element_type).element_type
else:

View File

@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
import math
from absl.testing import absltest
@ -21,7 +22,6 @@ import numpy as np
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.util import prod
from jax.config import config
@ -104,7 +104,7 @@ class AnnTest(jtu.JaxTestCase):
is_max_k=[True, False],
)
def test_autodiff(self, shape, dtype, k, is_max_k):
vals = np.arange(prod(shape), dtype=dtype)
vals = np.arange(math.prod(shape), dtype=dtype)
vals = self.rng().permutation(vals).reshape(shape)
if is_max_k:
fn = lambda vs: lax.approx_max_k(vs, k=k)[0]

View File

@ -14,6 +14,7 @@
"""Tests for GlobalDeviceArray."""
import contextlib
import math
import os
import unittest
from absl.testing import absltest
@ -27,7 +28,7 @@ from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.util import prod, safe_zip
from jax._src.util import safe_zip
from jax.interpreters import pxla
from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import (
@ -72,7 +73,7 @@ def tearDownModule():
def create_array(shape, sharding, global_data=None):
if global_data is None:
global_data = np.arange(prod(shape)).reshape(shape)
global_data = np.arange(math.prod(shape)).reshape(shape)
return array.make_array_from_callback(
shape, sharding, lambda idx: global_data[idx]), global_data
@ -310,7 +311,7 @@ class JaxArrayTest(jtu.JaxTestCase):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
s = sharding.NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
di_map = s.devices_indices_map(shape)
bufs = [jax.device_put(inp_data[di_map[d]], d)
for d in jax.local_devices()]
@ -333,7 +334,7 @@ class JaxArrayTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
# sharding device ids = {0, 1}
s = sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {2, 3}
bufs = [jax.device_put(inp_data, d) for d in jax.devices()[2:4]]
with self.assertRaisesRegex(
@ -349,7 +350,7 @@ class JaxArrayTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
# Sharding device ids = {0, 1}
s = sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {0, 0}
bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
with self.assertRaisesRegex(
@ -366,7 +367,7 @@ class JaxArrayTest(jtu.JaxTestCase):
mesh = jax.sharding.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x'))
# sharding device ids = {1, 2}
s = sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {0, 1}
bufs = [jax.device_put(inp_data, d) for d in jax.devices()[:2]]
with self.assertRaisesRegex(
@ -389,7 +390,7 @@ class JaxArrayTest(jtu.JaxTestCase):
shape = (8, 4)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mps = sharding.NamedSharding(mesh, pspec)
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
str_expected_shard_shape = str(expected_shard_shape).replace(
r"(", r"\(").replace(r")", r"\)")
@ -403,7 +404,7 @@ class JaxArrayTest(jtu.JaxTestCase):
shape = (8, 2)
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
s = sharding.NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape), dtype=np.int32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
indices = s.devices_indices_map(shape)
bufs = [jax.device_put(inp_data[indices[d]], d) for d in mesh.local_devices]
with self.assertRaisesRegex(
@ -725,7 +726,7 @@ class ShardingTest(jtu.JaxTestCase):
self.skipTest('Test needs >= 2 devices.')
shape = (2, 2)
num_elements = prod(shape)
num_elements = math.prod(shape)
inp_data = np.arange(num_elements).reshape(shape)
out = jax.pmap(lambda x: x)(inp_data)
self.assertIsInstance(out.sharding, sharding.PmapSharding)
@ -954,7 +955,7 @@ class RngShardingTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
s = sharding.NamedSharding(mesh, pspec)
n = prod(global_shape)
n = math.prod(global_shape)
global_x = jnp.arange(n).astype('uint32').reshape(global_shape)
x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])

View File

@ -14,6 +14,7 @@
from functools import partial
import hashlib
import math
import os
import random
import sys
@ -29,7 +30,6 @@ from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit
import jax
from jax import jit, lax, pmap
from jax._src.util import prod
import jax._src.test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_client
@ -284,11 +284,11 @@ class CompilationCacheTest(jtu.JaxTestCase):
return x + y
shape = (8, 8)
x = np.arange(prod(shape), dtype=np.int64).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape)
f(x, x + 1)
files_in_directory = len(os.listdir(tmpdir))
self.assertEqual(files_in_directory, 1)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f(x, x + 1)
files_in_directory = len(os.listdir(tmpdir))
self.assertEqual(files_in_directory, 2)

View File

@ -13,15 +13,17 @@
# limitations under the License.
"""Tests for GlobalDeviceArray."""
from absl.testing import absltest
from absl.testing import parameterized
import math
import unittest
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import core
from jax._src import test_util as jtu
from jax._src.util import prod, safe_zip
from jax._src.util import safe_zip
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
@ -34,7 +36,7 @@ config.parse_flags_with_absl()
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
if global_data is None:
global_data = np.arange(prod(global_shape)).reshape(global_shape)
global_data = np.arange(math.prod(global_shape)).reshape(global_shape)
return GlobalDeviceArray.from_callback(
global_shape, global_mesh, mesh_axes, lambda idx: global_data[idx]), global_data
@ -75,7 +77,7 @@ class GDATest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
@ -128,7 +130,7 @@ class GDATest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
global_input_shape = (8, 4, 2)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
@ -163,7 +165,7 @@ class GDATest(jtu.JaxTestCase):
expected_replica_ids):
global_mesh = jtu.create_global_mesh((8,), ('x'))
global_input_shape = (16,)
global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
global_input_data = np.arange(math.prod(global_input_shape)).reshape(-1)
def cb(index):
return global_input_data[index]
@ -214,7 +216,7 @@ class GDATest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
@ -240,7 +242,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(('x', 'y'))
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(indices):
self.assertEqual(len(indices), len(global_mesh.local_devices))
@ -260,7 +262,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x')
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
def cb(cb_inp):
self.assertLen(cb_inp, 4)
@ -288,7 +290,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(('x', 'y'))
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
gda = GlobalDeviceArray.from_callback(
@ -307,7 +309,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(None,)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
input_gda = GlobalDeviceArray.from_callback(
@ -340,7 +342,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)
dbs = [
@ -358,7 +360,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(('x', 'y'))
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]

View File

@ -16,6 +16,7 @@
import collections
from functools import partial
import itertools
import math
from unittest import SkipTest
from absl.testing import absltest
@ -28,7 +29,6 @@ from jax import dtypes
from jax import lax
from jax._src import test_util as jtu
from jax.test_util import check_grads
from jax._src.util import prod
from jax.config import config
config.parse_flags_with_absl()
@ -832,7 +832,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = np.arange(prod(shape), dtype=key_dtype)
flat_keys = np.arange(math.prod(shape), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
@ -847,7 +847,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
k=[1, 3],
)
def testTopKGrad(self, shape, dtype, k):
flat_values = np.arange(prod(shape), dtype=dtype)
flat_values = np.arange(math.prod(shape), dtype=dtype)
values = self.rng().permutation(flat_values).reshape(shape)
fun = lambda vs: lax.top_k(vs, k=k)[0]
check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)

View File

@ -19,6 +19,7 @@ from functools import partial
import inspect
import io
import itertools
import math
from typing import cast, Iterator, Optional, List, Tuple
import unittest
from unittest import SkipTest
@ -46,7 +47,7 @@ from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.util import prod, safe_zip
from jax._src.util import safe_zip
from jax._src import array
from jax.config import config
@ -1358,7 +1359,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if shape in scalar_shapes or len(shape) == 0:
cond_shape = (0,)
elif axis is None:
cond_shape = (prod(shape),)
cond_shape = (math.prod(shape),)
else:
cond_shape = (shape[axis],)
@ -1394,7 +1395,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if shape in scalar_shapes or len(shape) == 0:
cond_shape = (0,)
elif axis is None:
cond_shape = (prod(shape),)
cond_shape = (math.prod(shape),)
else:
cond_shape = (shape[axis],)
@ -1488,7 +1489,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
[dict(shape=shape, axis=axis, idx=idx)
for shape in nonempty_nonscalar_array_shapes
for axis in [None] + list(range(-len(shape), len(shape)))
for idx in (range(-prod(shape), prod(shape))
for idx in (range(-math.prod(shape), math.prod(shape))
if axis is None else
range(-shape[axis], shape[axis]))],
dtype=all_dtypes,
@ -3372,7 +3373,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
idx_shape=all_shapes,
)
def testUnravelIndex(self, shape, idx_shape, dtype):
size = prod(shape)
size = math.prod(shape)
rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3)
def np_fun(index, shape):

View File

@ -15,6 +15,7 @@ from __future__ import annotations
from functools import partial
import itertools
import math
import operator
import types
import unittest
@ -44,7 +45,6 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src import lax_reference
from jax._src.util import prod
from jax._src.lax import lax as lax_internal
from jax._src.internal_test_util import lax_test_util
@ -1998,7 +1998,7 @@ class LaxTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = np.arange(prod(shape), dtype=key_dtype)
flat_keys = np.arange(math.prod(shape), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
@ -2035,7 +2035,7 @@ class LaxTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = np.arange(prod(shape), dtype=key_dtype)
flat_keys = np.arange(math.prod(shape), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
@ -2051,7 +2051,7 @@ class LaxTest(jtu.JaxTestCase):
)
def testTopK(self, shape, dtype, k):
def args_maker():
flat_values = np.arange(prod(shape), dtype=dtype)
flat_values = np.arange(math.prod(shape), dtype=dtype)
values = self.rng().permutation(flat_values).reshape(shape)
return [values]
def reference_top_k(x):

View File

@ -15,6 +15,7 @@
from functools import partial
import itertools
import math
from typing import Optional, cast
import unittest
@ -32,7 +33,7 @@ from jax._src import test_util as jtu
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.util import prod, safe_map, safe_zip
from jax._src.util import safe_map, safe_zip
from jax.config import config
config.parse_flags_with_absl()
@ -667,7 +668,7 @@ class LaxVmapTest(jtu.JaxTestCase):
# Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of
# values a bfloat16 can represent exactly to avoid ties.
def testTopK(self, shape, dtype, k, bdims):
rng = jtu.rand_int(self.rng(), high=prod(shape))
rng = jtu.rand_int(self.rng(), high=math.prod(shape))
# _CheckBatching doesn't work with tuple outputs, so test outputs separately.
op1 = lambda x: lax.top_k(x, k=k)[0]
self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)

View File

@ -16,6 +16,7 @@ import os
import re
from functools import partial, lru_cache
import logging
import math
import threading
import unittest
from collections import OrderedDict, namedtuple
@ -51,7 +52,7 @@ from jax._src.interpreters import pxla
from jax.interpreters import mlir
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.util import prod, curry, unzip2, safe_zip
from jax._src.util import curry, unzip2, safe_zip
from jax.config import config
config.parse_flags_with_absl()
@ -85,7 +86,7 @@ def create_gda(global_shape, global_mesh, mesh_axes, global_data=None,
dtype=np.float32):
if global_data is None:
global_data = np.arange(
prod(global_shape), dtype=dtype).reshape(global_shape)
math.prod(global_shape), dtype=dtype).reshape(global_shape)
if isinstance(mesh_axes, Sharding):
mesh_axes = mesh_axes.spec
@ -98,7 +99,7 @@ def create_array(global_shape, global_mesh, mesh_axes, global_data=None,
dtype=np.float32):
if global_data is None:
global_data = np.arange(
prod(global_shape), dtype=dtype).reshape(global_shape)
math.prod(global_shape), dtype=dtype).reshape(global_shape)
if isinstance(mesh_axes, Sharding):
sharding = mesh_axes
@ -144,7 +145,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x
shape = (2, 2)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
actual = f(x)
expected = x
self.assertAllClose(actual, expected, check_dtypes=False)
@ -164,7 +165,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x + y
shape = (8, 8)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
actual = f(x, x + 1)
expected = x + (x + 1)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -182,7 +183,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x + y
shape = (8, 8)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
if config.jax_array:
out = jax.jit(f)(x, x + 1)
self.assertArraysEqual(out, x + x + 1)
@ -205,7 +206,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return jnp.pad(out, [[0, 1]])
shape = (4,)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
actual = f(x, x + 1)
expected = x + (x + 1)
self.assertAllClose(actual[:3], expected[:3], check_dtypes=False)
@ -222,7 +223,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x + y
shape = (8, 8)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
with jtu.create_global_mesh((2,), ('x')) as mesh:
actual = f(x, x + 1)
expected = x + (x + 1)
@ -281,7 +282,7 @@ class PJitTest(jtu.BufferDonationTestCase):
def testMeshDecorator(self):
x = jnp.arange(8)
mesh_shape = (2, 2)
size = prod(mesh_shape)
size = math.prod(mesh_shape)
if len(jax.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} global devices.")
mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
@ -347,7 +348,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return y * 2
shape = (8, 8)
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -374,7 +375,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return y * 2
shape = (8, 8)
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -402,7 +403,7 @@ class PJitTest(jtu.BufferDonationTestCase):
y = with_sharding_constraint(y, ops)
return y * 2
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -426,7 +427,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x
shape = (8, 8)
v = np.arange(prod(shape)).reshape(shape)
v = np.arange(math.prod(shape)).reshape(shape)
x = [{"a": v, "b": v * 2}, v * 3]
actual = f(x)
@ -458,7 +459,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x
shape = (8, 8)
v = np.arange(prod(shape)).reshape(shape)
v = np.arange(math.prod(shape)).reshape(shape)
x = [{"a": v, "b": v * 2}, v * 3]
actual = f(x)
@ -487,7 +488,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x
shape = (2, 8, 8)
v = np.arange(prod(shape)).reshape(shape)
v = np.arange(math.prod(shape)).reshape(shape)
x = [{'a': v, 'b': v * 2}, v * 3]
actual = f(x)
@ -513,7 +514,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x
shape = (2, 8, 8)
v = np.arange(prod(shape)).reshape(shape)
v = np.arange(math.prod(shape)).reshape(shape)
x = [{'a': v, 'b': v * 2}, v * 3]
mlir_str = str(f.lower(x).compiler_ir())
@ -1084,7 +1085,7 @@ class PJitTest(jtu.BufferDonationTestCase):
input_shape = (8, 4)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
seeds = jnp.arange(
prod(input_shape), dtype=np.uint32).reshape(input_shape)
math.prod(input_shape), dtype=np.uint32).reshape(input_shape)
with mesh:
def make_keys(seeds):
@ -1179,7 +1180,7 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]
@ -1214,7 +1215,7 @@ class GDAPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]
@ -1288,7 +1289,7 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]
@ -1321,7 +1322,7 @@ class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_non_gda_inputs(self):
input_shape = (8, 2)
input_data = np.arange(prod(input_shape)).reshape(input_shape)
input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
with parallel_functions_output_gda(True):
@partial(pjit,
@ -1353,7 +1354,8 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32
).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
@ -1374,7 +1376,8 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x')
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32
).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
@ -1400,7 +1403,7 @@ class GDAPjitTest(jtu.JaxTestCase):
input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(input_shape), dtype=np.float32).reshape(input_shape)
math.prod(input_shape), dtype=np.float32).reshape(input_shape)
def cb(index):
return input_data[index]
@ -1439,7 +1442,8 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(None)
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32
).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
@ -1500,7 +1504,7 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(None)
global_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
with parallel_functions_output_gda(True):
f = pjit(lambda x: x, in_shardings=mesh_axes, out_shardings=mesh_axes)
@ -1584,7 +1588,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
raise unittest.SkipTest('GDA and Array cannot be together.')
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
with parallel_functions_output_gda(True):
with global_mesh:
@ -1610,7 +1614,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
with jax_array(True):
with global_mesh:
@ -1634,7 +1638,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (4, 2)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
with ctx(True):
with global_mesh:
@ -1662,7 +1666,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (4, 2)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
with global_mesh:
f = pjit(lambda x, y, z: (x, y, z), in_shardings=AUTO, out_shardings=AUTO)
@ -1685,7 +1689,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
global_input_shape = (8, 4)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
in_resource = pspec
@ -1718,7 +1722,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
global_input_shape = (8, 4)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
in_resource = NamedSharding(global_mesh, pspec)
@ -1748,7 +1752,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
with jax_array(True):
with global_mesh:
@ -1835,7 +1839,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
input_shape = (8, 2)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_data = np.arange(
prod(input_shape), dtype=np.float32).reshape(input_shape)
math.prod(input_shape), dtype=np.float32).reshape(input_shape)
with jax_array(True):
with global_mesh:
f = pjit(lambda x: x,
@ -1854,7 +1858,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
input_shape = (8, 2)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_data = np.arange(
prod(input_shape), dtype=np.float32).reshape(input_shape)
math.prod(input_shape), dtype=np.float32).reshape(input_shape)
with jax_array(True):
with global_mesh:
f = pjit(
@ -2026,7 +2030,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_globally_sharded_key_array_result_8x4_single_device(self):
input_shape = (8, 4)
seeds = jnp.arange(
prod(input_shape), dtype=np.uint32).reshape(input_shape)
math.prod(input_shape), dtype=np.uint32).reshape(input_shape)
@pjit
def make_keys(seeds):
@ -2080,7 +2084,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
m2 = jtu.create_global_mesh((2, 2), ('x', 'y'))
spec = P('x', 'y')
a1 = jnp.arange(prod(input_shape)).reshape(input_shape)
a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape)
with jax_array(True):
with m1:
@ -2096,7 +2100,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
m2 = jtu.create_global_mesh((2, 2), ('x', 'y'))
spec = P('x', 'y')
a1 = jnp.arange(prod(input_shape)).reshape(input_shape)
a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape)
with jax_array(True):
with m1:
@ -2250,7 +2254,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
shape = (8, 2)
mesh = jax.sharding.Mesh(mesh_devices, ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# Explicitly put on the ordering of devices which does not match the mesh
# ordering to make sure we reorder them in the constructor and the output
@ -2271,7 +2275,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax_array(True)
def test_not_xlacompatible_sharding_error(self):
shape = (8, 2)
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
ts = TempSharding(jax.devices())
arr = array.make_array_from_callback(
shape, ts, lambda idx: inp_data[idx])
@ -2333,7 +2337,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax_array(True)
def test_pjit_uncommitted_array_and_committed_array(self):
shape = (8, 2)
uarr = jnp.arange(prod(shape), dtype=np.float32).reshape(shape)
uarr = jnp.arange(math.prod(shape), dtype=np.float32).reshape(shape)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
carr, inp_data = create_array(shape, mesh, P('x', 'y'))
with mesh:
@ -2357,7 +2361,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_pjit_uncommitted_array_multi_devices(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp = np.arange(prod(shape), dtype=np.int32).reshape(shape)
inp = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
arr = array.ArrayImpl(
core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
[jax.device_put(inp, d) for d in mesh.devices.flat], committed=False)
@ -2602,7 +2606,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
shape = (8, 2)
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
arr = jax.device_put(inp_data, s)
out = pjit(lambda x: x)(arr)
self.assertArraysEqual(out, inp_data)
@ -2640,7 +2644,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_multi_device_pjit_mul(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
arr1 = jax.device_put(inp_data, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(inp_data, NamedSharding(mesh, P(None, 'y')))
@ -2655,7 +2659,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_single_device_pjit_cpp_dispatch(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((1,), ('x',))
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
f = pjit(lambda x: x @ x.T, in_shardings=None, out_shardings=None)
with jtu.count_pjit_cache_miss() as count:
@ -3379,7 +3383,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_with_sharding_constraint_spmd_axis_name(self):
mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
shape = (8, 4, 2, 2)
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
def f(inp):
return with_sharding_constraint(inp, P('data', None, None))
@ -3723,7 +3727,7 @@ class PJitErrorTest(jtu.JaxTestCase):
def test_pjit_with_deleted_input_at_first_call(self, committed):
shape = (8,)
mesh = jtu.create_global_mesh((1,), ('x',))
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
if committed:
s = NamedSharding(mesh, P('x',))
x = jax.device_put(inp_data, s)
@ -3742,7 +3746,7 @@ class PJitErrorTest(jtu.JaxTestCase):
def test_pjit_with_deleted_input_at_subsequent_call(self, committed):
shape = (8,)
mesh = jtu.create_global_mesh((1,), ('x',))
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
if committed:
s = NamedSharding(mesh, P('x',))
x = jax.device_put(inp_data, s)


@ -17,6 +17,7 @@ from concurrent.futures import ThreadPoolExecutor
from functools import partial
import itertools as it
import gc
import math
import os
from random import shuffle
from typing import Optional, cast
@ -44,7 +45,7 @@ from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
from jax._src import config as jax_config
from jax._src import device_array
from jax._src import xla_bridge
from jax._src.util import prod, safe_map, safe_zip
from jax._src.util import safe_map, safe_zip
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src import array
@ -119,7 +120,7 @@ def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
aval = ShapedArray(input_shape, dtype)
if input_data is None:
input_data = np.arange(prod(input_shape)).reshape(input_shape)
input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes, sharded_dim_size)
@ -167,9 +168,9 @@ class PythonPmapTest(jtu.JaxTestCase):
msg = "device mesh shape {} not compatible with device count {}"
raise SkipTest(msg.format(device_mesh_shape, device_count)) from err
else:
if device_count % prod(device_mesh_shape):
if device_count % math.prod(device_mesh_shape):
msg = "device mesh size {} does not divide available device count {}"
raise SkipTest(msg.format(prod(device_mesh_shape), device_count))
raise SkipTest(msg.format(math.prod(device_mesh_shape), device_count))
else:
return device_mesh_shape
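
A small sketch of the divisibility check above, with a hypothetical device count: a mesh shape is usable only when its total size divides the number of available devices.

import math

device_count = 8                              # hypothetical device count
for device_mesh_shape in [(2, 2), (2, 4), (3, 2)]:
    mesh_size = math.prod(device_mesh_shape)
    # (2, 2) and (2, 4) divide 8 and are usable; (3, 2) has size 6 and would be skipped.
    usable = device_count % mesh_size == 0
    print(device_mesh_shape, "usable" if usable else "skip")
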
@ -177,7 +178,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
ans = f(x)
@ -193,7 +194,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompile(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = f(x)
lowered = f.lower(x)
compiled = lowered.compile()
@ -210,7 +211,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompileInTreeMismatch(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f_exe = f.lower(x).compile()
self.assertRaisesRegex(
TypeError, "function compiled for .*, called with .*",
@ -235,7 +236,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompileArgTypeMismatch(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=int).reshape(shape)
x = np.arange(math.prod(shape), dtype=int).reshape(shape)
x_f32 = x.astype(jnp.float32)
x_i32 = x.astype(jnp.int32)
f_exe = f.lower(x_f32).compile()
@ -250,7 +251,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompileMultiArg(self):
f = self.pmap(lambda x, y: x - lax.pmean(y, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = y = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = f(x, y)
f_exe = f.lower(x, y).compile()
ans = f_exe(x, y)
@ -267,7 +268,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerAsText(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x)
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
@ -277,7 +278,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompilerIR(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
@ -289,14 +290,14 @@ class PythonPmapTest(jtu.JaxTestCase):
# TODO(frostig): remove (deprecated)
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
self.assertIsNotNone(f.compiler_ir())
def testLowerCompileAsText(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
self.assertIsInstance(f.as_text(), (str, type(None)))
@ -304,7 +305,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCostAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x)
f.cost_analysis() # doesn't raise
@ -312,7 +313,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompileCostAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
f.cost_analysis() # doesn't raise
@ -320,21 +321,21 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompileMemoryAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
f.memory_analysis() # doesn't raise
def testLowerCompileExecutable(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
self.assertIsNotNone(f.runtime_executable())
def testLowerShapedArray(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
x_shape = core.ShapedArray(x.shape, x.dtype)
self.assertAllClose(f.lower(x_shape).compile()(x), f(x))
@ -342,7 +343,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.broadcast_to(np.mean(x, 0), x.shape)
ans = f(x)
@ -352,7 +353,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x] * jax.device_count())
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -361,7 +362,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
x = (x % 2).astype(np.bool_)
expected = np.array([x] * jax.device_count())
ans = f(x)
@ -371,7 +372,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: lax.all_gather(x, 'i', axis=-1), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x.T] * jax.device_count())
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -381,7 +382,7 @@ class PythonPmapTest(jtu.JaxTestCase):
device_count = jax.device_count()
shape = (device_count, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x] * device_count).reshape(device_count, -1)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -392,7 +393,7 @@ class PythonPmapTest(jtu.JaxTestCase):
device_count = jax.device_count()
shape = (device_count, 4, 3)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x.transpose(1, 0, 2).reshape(4, -1)] * device_count)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -406,7 +407,7 @@ class PythonPmapTest(jtu.JaxTestCase):
device_count = jax.device_count()
shape = (4, device_count, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
self.assertAllClose(vmap(f)(x), jnp.stack([f(xs) for xs in x], axis=0))
def testReduceScatter(self):
@ -414,7 +415,7 @@ class PythonPmapTest(jtu.JaxTestCase):
device_count = jax.device_count()
shape = (device_count, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.sum(x, axis=0)
ans = f(x)
for i, actual in enumerate(ans):
@ -425,7 +426,7 @@ class PythonPmapTest(jtu.JaxTestCase):
device_count = jax.device_count()
shape = (device_count, 4 * device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.sum(x, axis=0)
ans = f(x)
scatter_len = len(expected) // device_count
@ -444,7 +445,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, axis_name='i')
shape = (replicas, 4 * replicas)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
group_1_result = np.sum(x[0::2,:], axis=0)
@ -504,7 +505,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4 * 2)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape).view(np.complex64)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape).view(np.complex64)
expected = x - np.sum(x, 0)
ans = f(x)
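
The float32-to-complex64 view above relies on a size relationship worth spelling out; a short sketch with an assumed device count of 2: each complex64 element consumes two consecutive float32 values, so a last dimension of 4 * 2 floats becomes 4 complex numbers.

import math
import numpy as np

n_dev = 2                                     # stand-in for jax.device_count()
shape = (n_dev, 4 * 2)                        # 8 float32 values per row
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape).view(np.complex64)
# Adjacent float32 pairs become (real, imag), so each row now holds 4 complex64 values.
assert x.shape == (n_dev, 4) and x.dtype == np.complex64
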
@ -566,7 +567,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return np.repeat(np.sum(x, axis, keepdims=True), x.shape[axis], axis)
shape = (jax.device_count(), 1, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
@ -591,7 +592,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(self.pmap(f, 'i'), 'j')
shape = mesh_shape + (4,)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
expected = x
@ -605,7 +606,7 @@ class PythonPmapTest(jtu.JaxTestCase):
mesh_shape = (jax.device_count(),)
shape = mesh_shape + (4,)
x = np.array(3., dtype=np.float32)
y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f_expected = np.broadcast_to(x, mesh_shape)
f_ans = f(x, y)
@ -657,7 +658,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, axis_name='j', in_axes=(None, 0))
x = 3.
y = np.arange(prod(mesh_shape), dtype=np.float32).reshape(mesh_shape)
y = np.arange(math.prod(mesh_shape), dtype=np.float32).reshape(mesh_shape)
expected = np.broadcast_to(x - np.sum(y, 1, keepdims=True), mesh_shape)
ans = f(x, y)
@ -673,7 +674,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return jvp(jnp.ones_like(x))
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.cos(x)
ans = splitjvp(x)
@ -687,7 +688,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return jnp.sin(x)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
expected = grad(lambda x: jnp.sum(f(x)))(x)
@ -699,7 +700,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return lax.psum(x, axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)
def testGradOfJvp(self):
@ -714,7 +715,7 @@ class PythonPmapTest(jtu.JaxTestCase):
fun = lambda x: jnp.sum(jvp(jnp.sin, (x,), (jnp.ones_like(x),))[1])
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(splitjvp(x)))(x)
expected = grad(fun)(x)
@ -730,7 +731,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return tot * jnp.ones_like(x) # broadcast to map like pjit does
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
y = 4 + x
ans = grad(lambda x, y: jnp.sum(g(x, y)))(x, y)
expected = grad(lambda x, y: jnp.sum(g(x, y)))(x, y)
@ -764,7 +765,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return grad(lambda w: jnp.sum(g(w)))(x)
shape = mesh_shape + (4,)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(test_fun(x)))(x)
expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
@ -775,7 +776,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# test that we can pass in and out ShardedDeviceArrays
y = f(x)
@ -840,7 +841,7 @@ class PythonPmapTest(jtu.JaxTestCase):
if jax.device_count() < max(in_shape[:1] + out_shape[:1]):
raise SkipTest("not enough devices")
x = np.arange(prod(in_shape)).reshape(in_shape)
x = np.arange(math.prod(in_shape)).reshape(in_shape)
sharded_x = self.pmap(lambda x: x)(x)
self.assertAllClose(sharded_x.reshape(out_shape), x.reshape(out_shape),
check_dtypes=False)
@ -858,7 +859,7 @@ class PythonPmapTest(jtu.JaxTestCase):
shape = (num_pairs, 2, 4)
else:
shape = (device_count, 1, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
@ -874,7 +875,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected_psum = 2. * replicas // 2
expected = x - expected_psum
@ -891,7 +892,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
def sum_helper(a):
return np.broadcast_to(a.sum(0, keepdims=True),
(len(a), x.shape[1]))
@ -913,7 +914,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
def sum_helper(a):
return np.broadcast_to(a.sum(0, keepdims=True),
(replicas // 2, x.shape[1]))
@ -938,7 +939,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
@ -963,7 +964,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, 'i')
shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)
@ -999,7 +1000,7 @@ class PythonPmapTest(jtu.JaxTestCase):
axis_index_groups=axis_index_groups)
shape = (len(devices), 2 if axis_index_groups else jax.device_count())
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)
def testNestedPmapReplicaGroups(self):
@ -1014,7 +1015,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f3 = self.pmap(self.pmap(f, 'j'), 'i')
shape = (2, replicas // 2, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
def sum_helper_f1(a):
return np.broadcast_to(a.sum(1, keepdims=True),
(shape[0], shape[1] // 2, shape[2]))
@ -1030,7 +1031,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected)
shape = (replicas // 2, 2, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
def sum_helper_f3(a):
return np.broadcast_to(a.sum(0, keepdims=True),
(shape[0] // 2, shape[1], shape[2]))
@ -1197,7 +1198,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.pmax(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.max(x, 0)
ans = f(x)
@ -1207,7 +1208,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.pmin(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.min(x, 0)
ans = f(x)
@ -1291,7 +1292,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2, 3)
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
@ -1330,7 +1331,7 @@ class PythonPmapTest(jtu.JaxTestCase):
shuffle(devices)
f = self.pmap(self.pmap(lambda x: 3), devices=devices)
shape = (2, len(devices) // 2, 3)
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
@ -1352,7 +1353,7 @@ class PythonPmapTest(jtu.JaxTestCase):
raise SkipTest("error test doesn't apply with disable_jit")
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2 + 1, 3)
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
self.assertRaisesRegex(
ValueError,
(r"compiling computation that requires \d+ logical devices, "
@ -1363,7 +1364,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# if jax.device_count() > 1:
# f = pmap(pmap(lambda x: 3), devices=jax.devices()[:-1])
# shape = (2, jax.device_count() // 2, 3)
# x = jnp.arange(prod(shape)).reshape(shape)
# x = jnp.arange(math.prod(shape)).reshape(shape)
# self.assertRaisesRegex(
# ValueError,
# (r"compiling computation that requires \d+ replicas, "
@ -1392,7 +1393,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return g(x)
shape = (device_count, 1, 4)
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
a, b, c = f(x)
self.assertEqual(a.shape, shape[:-1])
@ -1525,7 +1526,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testPswapaxes(self):
device_count = jax.device_count()
shape = (device_count, 3, device_count, 5)
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
ans = self.pmap(lambda x: lax.pswapaxes(x, 'i', 1), axis_name='i')(x)
expected = np.swapaxes(x, 0, 2)
@ -1535,7 +1536,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testGradOfPswapaxes(self):
device_count = jax.device_count()
shape = (device_count, 1, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
w = np.arange(device_count, dtype=np.float32)
@partial(self.pmap, axis_name='i')
@ -1561,7 +1562,7 @@ class PythonPmapTest(jtu.JaxTestCase):
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
shape = (device_count, device_count // 2)
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
axis_index_groups = np.arange(device_count, dtype=np.int32)
axis_index_groups = axis_index_groups.reshape((device_count // 2, 2)).T
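
A sketch of the grouping this reshape-and-transpose produces, under an assumed device count of 8: even-numbered and odd-numbered devices end up in separate groups.

import numpy as np

device_count = 8                              # hypothetical even device count
groups = np.arange(device_count, dtype=np.int32).reshape((device_count // 2, 2)).T
# groups.tolist() == [[0, 2, 4, 6], [1, 3, 5, 7]]:
# two groups of device_count // 2 devices each, split by parity of the device index.
print(groups.tolist())
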
@ -1582,7 +1583,7 @@ class PythonPmapTest(jtu.JaxTestCase):
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
shape = (device_count, device_count // 2, 1)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
w = np.arange(device_count, dtype=np.float32)
axis_index_groups = np.arange(device_count, dtype=np.int32)
@ -1610,7 +1611,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = lambda x: x - lax.psum(x, 'i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
ans = jit(self.pmap(f, 'i'))(x)
@ -1656,7 +1657,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
y = f(x)
self.assertIsInstance(y, jax.Array)
@ -2363,7 +2364,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i',
devices=jax.devices())
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
ans = f(x)
self.assertAllClose(ans, expected)
@ -2387,7 +2388,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testNoDevicesError(self):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=[])
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
with self.assertRaisesRegex(
ValueError, "'devices' argument to pmap must be non-empty, or None."):
f(x)
@ -2508,7 +2509,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
return jnp.sin(x)
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
expected = grad(lambda x: jnp.sum(f(x)))(x)
@ -2519,7 +2520,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def f(x, y):
return jnp.sin(x + y())
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
y = lambda: 3.
ans = f(x, y)
@ -2531,9 +2532,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def f(x, y):
return jnp.sin(x + y)
xshape = (2, jax.device_count(), 4)
x = np.arange(prod(xshape)).reshape(xshape)
x = np.arange(math.prod(xshape)).reshape(xshape)
yshape = (2, 4, jax.device_count())
y = np.arange(prod(yshape)).reshape(yshape)
y = np.arange(math.prod(yshape)).reshape(yshape)
self.assertAllClose(f(x, y),
jnp.sin(x.transpose((1, 0, 2)) + y.transpose((2, 0, 1))))
@ -2544,11 +2545,11 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
fp = pmap(f, in_axes=(1, 2, None))
fv = vmap(f, in_axes=(1, 2, None))
xshape = (5, jax.device_count(), 7)
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
x = np.arange(math.prod(xshape), dtype=np.float32).reshape(xshape)
yshape = (5, 7, jax.device_count())
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
y = np.arange(math.prod(yshape), dtype=np.float32).reshape(yshape)
zshape = (5, 7)
z = np.arange(prod(zshape), dtype=np.float32).reshape(zshape)
z = np.arange(math.prod(zshape), dtype=np.float32).reshape(zshape)
dx, dy, dz = jax.grad(lambda args: fp(*args).sum())((x, y, z))
assert dx.shape == xshape
@ -2563,9 +2564,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def f(x, y):
return jnp.sin(x + y), y * 2
xshape = (2, jax.device_count(), 4)
x = np.arange(prod(xshape)).reshape(xshape)
x = np.arange(math.prod(xshape)).reshape(xshape)
yshape = (2, 4)
y = np.arange(prod(yshape)).reshape(yshape)
y = np.arange(math.prod(yshape)).reshape(yshape)
self.assertAllClose(f(x, y),
(jnp.sin(x.transpose((1, 0, 2)) + y).transpose((1, 2, 0)), y * 2))
@ -2606,9 +2607,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
return f
xshape = (5, 7)
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
x = np.arange(math.prod(xshape), dtype=np.float32).reshape(xshape)
yshape = (5, jax.device_count(), 7)
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
y = np.arange(math.prod(yshape), dtype=np.float32).reshape(yshape)
self.assertAllClose(jax.grad(mk_case(pmap))(x, y),
jax.grad(mk_case(vmap))(x, y))
@ -2624,7 +2625,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
if jax.device_count() < shape[0]:
raise SkipTest(f"requires {shape[0]} devices")
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
sharded_x = pmap(lambda x: x)(x)
num_threads = 10
@ -2650,7 +2651,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
if jax.device_count() < shape[0]:
raise SkipTest(f"requires {shape[0]} devices")
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
sharded_x = pmap(lambda x: x)(x)
self.assertIsNone(sharded_x._npy_value)
@ -2941,7 +2942,7 @@ class ShardArgsTest(jtu.JaxTestCase):
nshards = len(indices)
if jax.device_count() < nshards:
raise SkipTest
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
arg = make_arg(x)
bufs = pxla.shard_args(jax.devices()[:nshards], [indices], [arg])
self.assertEqual(len(bufs), 1)


@ -14,6 +14,7 @@
from functools import partial
import math
import numpy as np
@ -24,7 +25,6 @@ from jax import grad
from jax._src import test_util as jtu
from jax import dtypes
from jax.scipy import ndimage as lsp_ndimage
from jax._src.util import prod
from jax.config import config
config.parse_flags_with_absl()
@ -80,7 +80,7 @@ class NdimageTest(jtu.JaxTestCase):
mode, cval, impl, round_, rng_factory):
def args_maker():
x = np.arange(prod(shape), dtype=dtype).reshape(shape)
x = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
coords = [(size - 1) * rng(coords_shape, coords_dtype) for size in shape]
if round_:
coords = [c.round().astype(int) for c in coords]
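
A standalone sketch of this coordinate construction, assuming a concrete shape and numpy's default generator in place of the test's rng factory: uniform samples in [0, 1) are scaled by size - 1 so every coordinate falls inside the image extent.

import math
import numpy as np

shape = (4, 5)
rng = np.random.default_rng(0)                # stand-in for the test's rng factory
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
coords = [(size - 1) * rng.random(3) for size in shape]
# Each coordinate array stays strictly below its axis length, so it indexes into x validly.
assert all(c.max() < size for c, size in zip(coords, shape))
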


@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import itertools as it
import math
import os
from types import SimpleNamespace
from typing import (Any, Sequence, Set, Iterable, Iterator, NamedTuple,
@ -31,7 +33,7 @@ from jax.sharding import PartitionSpec as P
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.util import safe_zip, safe_map, prod, partition_list, merge_lists
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
@ -552,13 +554,13 @@ def shmap_reference(
def make_indexer(mesh: Mesh, spec: P, x: Any
) -> Callable[[Tuple[int, ...]], Tuple[slice, ...]]:
block_shape = [d // prod(mesh.shape[ax] for ax in (elt or ()))
block_shape = [d // math.prod(mesh.shape[ax] for ax in (elt or ()))
for d, elt in zip(x.shape, spec)]
def indexer(idx):
starts = [0 if el is None else
idx[list(mesh.shape).index(el)] if type(el) is not tuple else
sum(idx[list(mesh.shape).index(el[i])]
* prod(mesh.shape[e] for e in el[i+1:]) for i in range(len(el)))
* math.prod(mesh.shape[e] for e in el[i+1:]) for i in range(len(el)))
for el in spec]
return tuple(slice(start * size, (start + 1) * size)
for start, size in zip(starts, block_shape))
@ -647,7 +649,7 @@ def make_in_spec(mesh: Mesh, in_type_base: ShapeDtypeDuck) -> Chooser:
return new_type, partition_spec
def dilate(mesh: Mesh, spec: P, shape: ShapeDtypeDuck) -> ShapeDtypeDuck:
new_shape = tuple(d * prod(mesh.shape[ax] for ax in (elt or ()))
new_shape = tuple(d * math.prod(mesh.shape[ax] for ax in (elt or ()))
for d, elt in zip(shape.shape, spec))
return jax.ShapeDtypeStruct(new_shape, shape.dtype)
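
A worked sketch of the block arithmetic in make_indexer and dilate, under an assumed 2x2 mesh: a global dimension of 8 sharded over ('x', 'y') yields per-shard blocks of 8 // math.prod((2, 2)) == 2, the shard at mesh position (1, 0) starts at linear index 1 * 2 + 0 == 2, i.e. slice(4, 6), and dilate simply multiplies the per-shard size back up.

import math

mesh_shape = {'x': 2, 'y': 2}                 # assumed axis sizes for a 2x2 mesh
spec_elt = ('x', 'y')                         # one dimension sharded over both axes
global_dim = 8

block = global_dim // math.prod(mesh_shape[ax] for ax in spec_elt)   # == 2
start = 1 * math.prod(mesh_shape[ax] for ax in spec_elt[1:]) + 0     # mesh index (1, 0) -> 2
assert (block, start) == (2, 2)
assert slice(start * block, (start + 1) * block) == slice(4, 6)
# dilate reverses the split: a per-shard dim of 2 grows back to 2 * 4 == 8 globally.
assert block * math.prod(mesh_shape[ax] for ax in spec_elt) == global_dim
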


@ -14,6 +14,7 @@
import functools
import itertools as it
import math
import os
import re
from itertools import product, permutations
@ -45,7 +46,7 @@ from jax._src import config as jax_config
from jax._src.nn import initializers as nn_initializers
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.util import unzip2, prod, safe_zip
from jax._src.util import unzip2, safe_zip
from jax._src.lax import parallel as lax_parallel
from jax._src.lax.parallel import pgather
from jax.interpreters import batching, pxla
@ -80,7 +81,7 @@ def tearDownModule():
def create_array(global_shape, global_mesh, mesh_axes, global_data=None):
if global_data is None:
global_data = np.arange(
prod(global_shape), dtype=np.float32).reshape(global_shape)
math.prod(global_shape), dtype=np.float32).reshape(global_shape)
sharding = NamedSharding(global_mesh, mesh_axes)
@ -1118,7 +1119,7 @@ class XMapGDATest(XMapTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]
@ -1146,7 +1147,7 @@ class XMapGDATest(XMapTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]
@ -1185,7 +1186,7 @@ class XMapGDATest(XMapTestCase):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]
@ -1224,7 +1225,7 @@ class XMapGDATest(XMapTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]
@ -1248,7 +1249,7 @@ class XMapGDATest(XMapTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, lambda idx: input_data[idx])
with jax_config.parallel_functions_output_gda(True):