Change np.prod->math.prod

Why? np.prod is generally used here for static operations on shapes, but it has an
unfortunate corner case: np.prod([]) returns a float. math.prod, available as of
Python 3.8, is a better solution here.
Jake VanderPlas 2023-04-13 11:48:11 -07:00
parent fb46d3d084
commit 5521423d92
32 changed files with 133 additions and 119 deletions
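
A minimal illustration of the corner case described above (not part of the commit
itself; this is plain Python/NumPy behavior):

import math
import numpy as np

# np.prod of an empty sequence returns a NumPy float64 (1.0), so results
# used as static sizes needed an explicit int(...) cast.
assert np.prod([]) == 1.0 and isinstance(np.prod([]), np.floating)

# math.prod (Python 3.8+) returns a plain Python int for the same input.
assert math.prod([]) == 1 and isinstance(math.prod([]), int)

# math.prod also consumes iterators directly (e.g. the map(...) calls
# switched over in this commit), with no dtype= argument needed.
named_shape = {'x': 4, 'y': 2}
assert math.prod(map(named_shape.get, ('x', 'y'))) == 8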

View File

@ -15,6 +15,7 @@
import enum
import functools
import math
import operator
import google_benchmark
@ -54,7 +55,7 @@ def required_devices(num_devices_required):
def create_mesh(shape, axis_names, state):
size = np.prod(shape)
size = math.prod(shape)
if len(jax.devices()) < size:
state.skip_with_error(f"Requires {size} devices")
return None
@ -421,7 +422,7 @@ def sda_index_8(state):
def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False):
shape = (2000, 2000)
nse = 10000
size = np.prod(shape)
size = math.prod(shape)
rng = np.random.RandomState(1701)
data = rng.randn(nse)
indices = np.unravel_index(
@ -460,7 +461,7 @@ def sparse_bcoo_fromdense_compile(state):
def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False):
shape = (2000, 2000)
nse = 10000
size = np.prod(shape)
size = math.prod(shape)
rng = np.random.RandomState(1701)
data = rng.randn(nse)
indices = np.unravel_index(
@ -600,7 +601,7 @@ def bench_addressable_shards_index(state):
if mesh is None:
return
shape = (8, 2)
inp = np.arange(np.prod(shape)).reshape(shape)
inp = np.arange(math.prod(shape)).reshape(shape)
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
arr = jax.device_put(inp, s)
@ -614,7 +615,7 @@ def bench_addressable_shards_replica_id(state):
if mesh is None:
return
shape = (64, 32)
inp = np.arange(np.prod(shape)).reshape(shape)
inp = np.arange(math.prod(shape)).reshape(shape)
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
arr = jax.device_put(inp, s)
@ -785,7 +786,7 @@ def pjit_aot_4000_device(state):
def host_local_array_to_global_array(state):
global_mesh = create_mesh((4, 2), ('x', 'y'), state)
input_shape = (8, 2)
input_data = np.arange(np.prod(input_shape)).reshape(input_shape)
input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
in_pspec = jax.sharding.PartitionSpec('x', 'y')
while state:

View File

@ -1869,8 +1869,8 @@ class DimensionHandler:
Raise InconclusiveDimensionOperation if there is no such integer for all
contexts,
"""
sz1 = int(np.prod(s1))
sz2 = int(np.prod(s2))
sz1 = math.prod(s1)
sz2 = math.prod(s2)
if sz1 == 0 and sz2 == 0:
return 1
if sz1 % sz2:

View File

@ -1624,10 +1624,10 @@ def manual_proto(
tad_perm = ([axis_order[a] for a in replicated_axes] +
[axis_order[a] for a in manual_axes])
tad_shape = [1] * aval.ndim
tad_shape.append(int(np.prod([named_mesh_shape[a] for a in replicated_axes], dtype=int)))
tad_shape.append(int(np.prod([named_mesh_shape[a] for a in manual_axes], dtype=int)))
tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes]))
tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes]))
raw_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = tad_shape

View File

@ -1256,7 +1256,7 @@ mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu')
@partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)')
def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array:
m = lu.shape[0]
x = jnp.reshape(b, (m, np.prod(b.shape[1:])))
x = jnp.reshape(b, (m, math.prod(b.shape[1:])))
if trans == 0:
x = x[permutation, :]
x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True)

View File

@ -14,6 +14,7 @@
import enum
from functools import partial
import math
from typing import Callable, List, NamedTuple, Optional, Sequence, Tuple, Union
import weakref
@ -1891,7 +1892,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
# on the IDs.
ids_shape = np.array(updates.shape, dtype=np.int64)
ids_shape[dnums.update_window_dims,] = 1
num_ids = np.prod(ids_shape)
num_ids = math.prod(ids_shape)
id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
update_ids = lax.add(lax.reshape(lax.iota(id_dtype, num_ids), ids_shape),
lax._ones(updates, dtype=id_dtype))

View File

@ -19,6 +19,7 @@ from collections import OrderedDict, abc
from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
NamedTuple, Union, Sequence, Mapping)
from functools import wraps, partial, partialmethod, lru_cache
import math
from jax import lax
from jax import numpy as jnp
@ -1573,9 +1574,9 @@ def _get_axis_resource_count(
if local_res_shape is None:
nlocal = None
else:
nlocal = int(np.prod(map(local_res_shape.get, resources), dtype=np.int64))
nlocal = math.prod(map(local_res_shape.get, resources))
resource_count_map[axis] = ResourceCount(
int(np.prod(map(global_res_shape.get, resources), dtype=np.int64)),
math.prod(map(global_res_shape.get, resources)),
nlocal, distributed)
return resource_count_map

View File

@ -18,6 +18,7 @@ from __future__ import annotations
import collections
import contextlib
import functools
import math
import threading
from typing import Any, Hashable, NamedTuple, Set, Sequence, Tuple, Union
@ -199,7 +200,7 @@ class Mesh(contextlib.ContextDecorator):
@property
def size(self):
return np.prod(list(self.shape.values()))
return math.prod(self.shape.values())
@property
def empty(self):

View File

@ -163,15 +163,15 @@ def _compute_fans(shape: core.NamedShape,
if isinstance(in_axis, int):
in_size = shape[in_axis]
else:
in_size = int(np.prod([shape[i] for i in in_axis]))
in_size = math.prod([shape[i] for i in in_axis])
if isinstance(out_axis, int):
out_size = shape[out_axis]
else:
out_size = int(np.prod([shape[i] for i in out_axis]))
out_size = math.prod([shape[i] for i in out_axis])
if isinstance(batch_axis, int):
batch_size = shape[batch_axis]
else:
batch_size = int(np.prod([shape[i] for i in batch_axis]))
batch_size = math.prod([shape[i] for i in batch_axis])
receptive_field_size = shape.total / in_size / out_size / batch_size
fan_in = in_size * receptive_field_size
fan_out = out_size * receptive_field_size

View File

@ -747,7 +747,7 @@ def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
do_not_touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx not in axis)
touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx in axis)
a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions)
a = lax.reshape(a, do_not_touch_shape + (math.prod(touch_shape),), dimensions)
axis = _canonicalize_axis(-1, a.ndim)
else:
axis = _canonicalize_axis(axis, a.ndim)

View File

@ -1043,7 +1043,7 @@ def iota_2x32_shape(shape):
Setting aside representation, this function essentially computes the
equivalent of::
jax.lax.iota(dtype=np.uint64, size=np.prod(shape)).reshape(shape)
jax.lax.iota(dtype=np.uint64, size=math.prod(shape)).reshape(shape)
However:
@ -1069,7 +1069,7 @@ def iota_2x32_shape(shape):
[ 8, 9, 10, 11]], dtype=uint32)]
>>> def reshaped_iota(shape):
... return lax.iota(size=np.prod(shape), dtype=np.uint32).reshape(shape)
... return lax.iota(size=math.prod(shape), dtype=np.uint32).reshape(shape)
...
>>> reshaped_iota((3, 4))
Array([[ 0, 1, 2, 3],

View File

@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
import math
import operator
from typing import Callable, Optional, Tuple, Union, Sequence
import warnings
@ -232,7 +233,7 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array],
else:
step = nperseg - noverlap
batch_shape = list(batch_shape)
x = x.reshape((int(np.prod(batch_shape)), signal_length))[..., np.newaxis]
x = x.reshape((math.prod(batch_shape), signal_length))[..., np.newaxis]
result = jax.lax.conv_general_dilated_patches(
x, (nperseg,), (step,),
'VALID',
@ -579,7 +580,7 @@ def _overlap_and_add(x: Array, step_size: int) -> Array:
raise ValueError('Input must have (..., frames, frame_length) shape.')
*batch_shape, nframes, segment_len = x.shape
flat_batchsize = np.prod(batch_shape, dtype=np.int64)
flat_batchsize = math.prod(batch_shape)
x = x.reshape((flat_batchsize, nframes, segment_len))
output_size = step_size * (nframes - 1) + segment_len
nstep_per_segment = 1 + (segment_len - 1) // step_size

View File

@ -75,7 +75,7 @@ def _sharding_spec_mesh_shape(self):
def get_logical_mesh_ids(mesh_shape):
return np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
_MeshAxisName = Any
@ -123,7 +123,7 @@ def sharding_spec_sharding_proto(
assert mesh_shape[maxis] == nchunks
mesh_permutation.append(maxis)
next_sharded_axis += 1
new_mesh_shape.append(int(np.prod(sharding.chunks)))
new_mesh_shape.append(math.prod(sharding.chunks))
elif isinstance(sharding, Unstacked):
raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
else:
@ -187,7 +187,7 @@ def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
axis_indices.append(range(axis_size))
shard_indices_shape.append(axis_size)
elif isinstance(sharding, Chunked):
total_chunks = int(np.prod(sharding.chunks))
total_chunks = math.prod(sharding.chunks)
shard_size, ragged = divmod(axis_size, total_chunks)
assert not ragged, (axis_size, total_chunks, dim)
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)

View File

@ -487,7 +487,7 @@ def rand_fullrange(rng, standardize_nans=False):
"""Random numbers that span the full range of available bits."""
def gen(shape, dtype, post=lambda x: x):
dtype = np.dtype(dtype)
size = dtype.itemsize * np.prod(_dims_of_shape(shape), dtype=int)
size = dtype.itemsize * math.prod(_dims_of_shape(shape))
vals = rng.randint(0, np.iinfo(np.uint8).max, size=size, dtype=np.uint8)
vals = post(vals).view(dtype)
if shape is PYTHON_SCALAR_SHAPE:

View File

@ -498,6 +498,7 @@ import atexit
import functools
import itertools
import logging
import math
import threading
import traceback
from typing import (Any, Callable, Dict, List, Optional, Sequence,
@ -1380,7 +1381,7 @@ def _add_transform(params: Dict, name: str, *transform_params) -> Dict:
def _aval_is_empty(aval) -> bool:
return np.prod(aval.shape) == 0
return math.prod(aval.shape) == 0
def _instantiate_zeros(tan, arg):
"""Turn special ad.zero tangents into arrays of 0s for sending to host.

View File

@ -15,6 +15,7 @@
import builtins
import dataclasses
from functools import partial, wraps
import math
import string
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
@ -665,7 +666,7 @@ def _reduce_monoid(operand, window_dimensions, window_strides, padding,
# TODO(marcvanzee): This may give very large deviations on TPU when using
# floats as inputs. Alternatively, we could implement this using a
# convolution with an all-1's kernel.
return tf.multiply(tf_pool(operand, "AVG"), np.prod(window_dimensions))
return tf.multiply(tf_pool(operand, "AVG"), math.prod(window_dimensions))
def _reduce_window(*args, jaxpr, consts, window_dimensions,
@ -951,7 +952,7 @@ def _gather_generate_indices(shape: Tuple[int, ...]):
"""
Returns the indices of the according to `shape`:
each element in the output is the index of an element of an array
of the provided shape. The result's shape is (np.prod(shape), len(shape))
of the provided shape. The result's shape is (math.prod(shape), len(shape))
For example, given shape (2,2) it returns (0,0),(0,1),(1,0),(1,1)
"""

View File

@ -14,6 +14,7 @@
"""Provides JAX and TensorFlow interoperation APIs."""
from functools import partial
import contextlib
import math
import operator
import os
import re
@ -3094,7 +3095,7 @@ def split_to_logical_devices(tensor: TfVal,
# of _shard_values.
if partition_dimensions is None:
return xla_sharding.replicate(tensor, use_sharding_op=True)
num_partition_splits = np.prod(partition_dimensions)
num_partition_splits = math.prod(partition_dimensions)
tile_assignment = np.arange(num_partition_splits).reshape(
partition_dimensions)
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)

View File

@ -304,7 +304,7 @@ class _DimMon(dict):
assert a_l <= a_u
bounds.append((a_l ** exp, a_u ** exp))
candidates = [np.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)]
candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)]
return (min(*candidates), max(*candidates)) # type: ignore
@ -696,8 +696,8 @@ class DimensionHandlerPoly(core.DimensionHandler):
return _ensure_poly(d1, "ge") >= d2
def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
sz1 = np.prod(s1)
sz2 = np.prod(s2)
sz1 = math.prod(s1)
sz2 = math.prod(s2)
if core.symbolic_equal_dim(sz1, sz2): # Takes care also of sz1 == sz2 == 0
return 1
err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}"

View File

@ -62,6 +62,7 @@ import dataclasses
import datetime
from functools import partial
import itertools
import math
import os
import re
import sys
@ -346,7 +347,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
@staticmethod
def eigh_input(shape, dtype):
# In order to keep inputs small, we construct the input programmatically
operand = jnp.reshape(jnp.arange(np.prod(shape), dtype=dtype), shape)
operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
# Make operand self-adjoint
operand = (operand + jnp.conj(jnp.swapaxes(operand, -1, -2))) / 2.
return operand
@ -417,7 +418,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
@staticmethod
def qr_harness(shape, dtype):
# In order to keep inputs small, we construct the input programmatically
operand = jnp.reshape(jnp.arange(np.prod(shape), dtype=dtype), shape)
operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
return lax.linalg.qr(operand, full_matrices=True)
@parameterized.named_parameters(
@ -458,7 +459,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
@staticmethod
def lu_harness(shape, dtype):
operand = jnp.reshape(jnp.arange(np.prod(shape), dtype=dtype), shape)
operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
return lax.linalg.lu(operand)
def test_tpu_Lu(self):

View File

@ -16,6 +16,7 @@
Specific JAX primitive conversion tests are in primitives_test."""
import collections
import contextlib
import math
import os
import re
from typing import Callable, Dict, Optional, Tuple
@ -1286,7 +1287,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
return carry
shape = (3, 2)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
jax_comp = jax.xla_computation(f_while)(x)
backend = jax._src.xla_bridge.get_backend()
@ -1723,7 +1724,7 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
return x + y
shape = (8, 10)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
in_axis_resources = (P("x"), P("x"))
out_axis_resources = None
res_jax = pjit.pjit(

View File

@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for the shape-polymorphic jax2tf conversion."""
import contextlib
import math
import unittest
from absl.testing import absltest, parameterized
@ -1280,7 +1281,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# https://github.com/google/jax/issues/7093
# Also issue #6975.
x_shape = (2, 3, 4)
xi = np.arange(np.prod(x_shape), dtype=np.int16).reshape(x_shape)
xi = np.arange(math.prod(x_shape), dtype=np.int16).reshape(x_shape)
yf = xi.astype(np.float32)
xi_yf = (xi, yf)
zb = np.array([True, False], dtype=np.bool_)
@ -1350,7 +1351,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
f_tf = tf.function(f_tf, autograph=False)
x_shape = (2, 3, 4)
x = np.arange(np.prod(x_shape), dtype=np.int32).reshape(x_shape)
x = np.arange(math.prod(x_shape), dtype=np.int32).reshape(x_shape)
# When saving the model with gradients, we trace the gradient function
# and we used to get an error when creating zeros_like_aval for a
@ -1378,7 +1379,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],)),
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
jax2tf.convert(lambda x: jnp.reshape(x, (np.prod(x.shape),)),
jax2tf.convert(lambda x: jnp.reshape(x, (math.prod(x.shape),)),
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
jax2tf.convert(lambda x: x + x.shape[0] + jnp.sin(x.shape[0]),

View File

@ -628,7 +628,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")
mesh = Mesh(self.devices, axis_names=('x'))
a = np.arange(np.prod(4 * 4), dtype=np.float32).reshape((4, 4))
a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))
@partial(pjit.pjit,
in_shardings=(P('x', None),), out_shardings=P(None, 'x'))
@ -680,7 +680,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
return lax.ppermute(b, 'x', perm=perm)
with mesh:
a = np.arange(np.prod(4 * 4)).reshape((4, 4))
a = np.arange(4 * 4).reshape((4, 4))
res_jax = f_jax(a)
b0, b1 = np.split(a, 2, axis=0) # The shard_map splits on axis 0
b0, b1 = b1, b0
@ -705,7 +705,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
if poly is not None:
raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement")
mesh = Mesh(self.devices, axis_names=('x'))
a = np.arange(np.prod(4 * 4), dtype=np.float32).reshape((4, 4))
a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))
@partial(pjit.pjit,
in_shardings=(P('x', None),), out_shardings=P('x', None))

View File

@ -1575,7 +1575,7 @@ def bcoo_update_layout(mat: BCOO, *, n_batch: Optional[int] = None, n_dense: Opt
n = current_n_dense - n_dense
@partial(nfold_vmap, N=current_n_batch + 1)
def _update(d, i):
new_d = d.reshape(np.prod(d.shape[:n]), *d.shape[n:])
new_d = d.reshape(math.prod(d.shape[:n]), *d.shape[n:])
meshes = jnp.meshgrid(*(jnp.arange(s, dtype=i.dtype) for s in d.shape[:n]),
indexing='ij')
new_i = jnp.column_stack([jnp.broadcast_to(i, (new_d.shape[0], i.size)),
@ -1583,10 +1583,10 @@ def bcoo_update_layout(mat: BCOO, *, n_batch: Optional[int] = None, n_dense: Opt
return new_d, new_i
new_data, new_indices = _update(new_data, new_indices)
new_data = new_data.reshape(*new_data.shape[:current_n_batch],
np.prod(new_data.shape[current_n_batch:current_n_batch + 2]),
math.prod(new_data.shape[current_n_batch:current_n_batch + 2]),
*new_data.shape[current_n_batch + 2:])
new_indices = new_indices.reshape(*new_indices.shape[:current_n_batch],
np.prod(new_indices.shape[current_n_batch: current_n_batch + 2]),
math.prod(new_indices.shape[current_n_batch: current_n_batch + 2]),
*new_indices.shape[current_n_batch + 2:])
current_n_dense = n_dense
@ -1595,10 +1595,10 @@ def bcoo_update_layout(mat: BCOO, *, n_batch: Optional[int] = None, n_dense: Opt
@partial(nfold_vmap, N=n_batch)
def _update(d, i):
nse = i.shape[-2]
new_d = d.reshape(np.prod(d.shape[:n + 1]), *d.shape[n + 1:])
new_d = d.reshape(math.prod(d.shape[:n + 1]), *d.shape[n + 1:])
meshes = jnp.meshgrid(*(jnp.arange(d, dtype=i.dtype) for d in (*i.shape[:n], nse)),
indexing='ij')
new_i = i.reshape(np.prod(i.shape[:n + 1]), *i.shape[n + 1:])
new_i = i.reshape(math.prod(i.shape[:n + 1]), *i.shape[n + 1:])
new_i = jnp.column_stack((*(m.ravel() for m in meshes[:-1]), new_i))
return new_d, new_i
new_data, new_indices = _update(new_data, new_indices)
@ -1669,7 +1669,7 @@ def _bcoo_broadcast_in_dim(data: Array, indices: Array, *, spinfo: SparseInfo, s
new_n_dense = props.n_dense and len(shape) - min(broadcast_dimensions[-props.n_dense:])
new_n_sparse = len(shape) - new_n_batch - new_n_dense
if np.prod(spinfo.shape[props.n_batch: props.n_batch + props.n_sparse]) != np.prod(shape[new_n_batch:new_n_batch + new_n_sparse]):
if math.prod(spinfo.shape[props.n_batch: props.n_batch + props.n_sparse]) != math.prod(shape[new_n_batch:new_n_batch + new_n_sparse]):
raise NotImplementedError("Adding sparse dimensions with lengths != 1")
new_data, new_indices = data, indices
@ -1784,8 +1784,8 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[in
batch_shape, sparse_shape, dense_shape = split_list(mat.shape, [mat.n_batch, mat.n_sparse])
batch_perm, sparse_perm, dense_perm = _validate_permutation(
mat.data, mat.indices, dimensions or tuple(range(mat.ndim)), mat.shape)
batch_size = np.prod(batch_shape, dtype=int)
sparse_size = np.prod(sparse_shape, dtype=int)
batch_size = math.prod(batch_shape)
sparse_size = math.prod(sparse_shape)
cuml_shape = np.cumprod(new_sizes)
if batch_size != 1 and batch_size not in cuml_shape:
@ -1797,7 +1797,7 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[in
i1 = cuml_shape.searchsorted(batch_size, side='right')
i2 = cuml_shape.searchsorted(batch_size * sparse_size, side='right')
new_batch_shape, new_sparse_shape, new_dense_shape = split_list(new_sizes, [i1, i2])
new_batch_shape, new_sparse_shape, new_dense_shape = split_list(new_sizes, [int(i1), int(i2)])
# Reshape batch & dense dimensions: this is accomplished via a standard reshape.
data = lax.reshape(
@ -1949,7 +1949,7 @@ def bcoo_slice(mat: BCOO, *, start_indices: Sequence[int], limit_indices: Sequen
keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense))
new_data = jnp.where(keep_data, new_data, 0)
new_nse = int(np.prod(new_shape_sparse))
new_nse = math.prod(new_shape_sparse)
if mat.nse > new_nse:
new_data, new_indices = _bcoo_sum_duplicates(
new_data, new_indices, spinfo=SparseInfo(shape=new_shape), nse=new_nse)
@ -2029,8 +2029,8 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense))
new_data = jnp.where(keep_data, new_data, 0)
if mat.nse > np.prod(size_sparse):
new_nse = int(np.prod(size_sparse))
if mat.nse > math.prod(size_sparse):
new_nse = math.prod(size_sparse)
new_data, new_indices = _bcoo_sum_duplicates(
new_data, new_indices, spinfo=SparseInfo(shape=new_shape), nse=new_nse)
@ -2099,7 +2099,7 @@ def _bcoo_reduce_sum(data: Array, indices: Array, *, spinfo: SparseInfo, axes: S
new_batch_dims = tuple(sorted(set(range(n_batch)) - batch_axes))
new_batch_shape = tuple(data.shape[i] for i in new_batch_dims)
new_nse = int(nse * np.prod([data.shape[i] for i in batch_axes]))
new_nse = nse * math.prod([data.shape[i] for i in batch_axes])
data = lax.reshape(data,
(*new_batch_shape, new_nse, *data.shape[n_batch + 1:]),

View File

@ -908,7 +908,7 @@ class ShardingTest(jtu.JaxTestCase):
self.skipTest('Test needs >= 4 devices.')
ps = jax.sharding.PmapSharding.default(shape, sharded_dim)
inp = jnp.arange(np.prod(shape)).reshape(shape)
inp = jnp.arange(math.prod(shape)).reshape(shape)
compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile()
pmap_in_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings

View File

@ -13,6 +13,7 @@
# limitations under the License.
from absl.testing import absltest
import math
import unittest
import numpy as np
@ -206,7 +207,7 @@ mlir.register_lowering(
def make_sparse_array(rng, shape, dtype, nnz=0.2):
mat = rng(shape, dtype)
size = int(np.prod(shape))
size = math.prod(shape)
if 0 < nnz < 1:
nnz = nnz * size
nnz = int(nnz)

View File

@ -3805,16 +3805,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis):
rng = jtu.rand_default(self.rng())
i_shape = np.array(i_shape)
i_shape = list(i_shape)
if axis is None:
i_shape = [np.prod(i_shape, dtype=np.int64)]
i_shape = [math.prod(i_shape)]
else:
# Test the case where the size of the axis doesn't necessarily broadcast.
i_shape[axis] *= 3
i_shape = list(i_shape)
def args_maker():
x = rng(x_shape, dtype)
n = np.prod(x_shape, dtype=np.int32) if axis is None else x_shape[axis]
n = math.prod(x_shape) if axis is None else x_shape[axis]
if np.issubdtype(index_dtype, np.unsignedinteger):
index_rng = jtu.rand_int(self.rng(), 0, n)
else:

View File

@ -641,7 +641,7 @@ class LaxTest(jtu.JaxTestCase):
if c == 'N':
self.assertEqual(out_c, patch_c)
elif c == 'C':
self.assertEqual(out_c * np.prod(filter_shape), patch_c)
self.assertEqual(out_c * math.prod(filter_shape), patch_c)
else:
self.assertEqual(out_c, patch_c * filter_shape[filter_spec.index(c)])

View File

@ -226,8 +226,8 @@ class PJitTest(jtu.BufferDonationTestCase):
x_shape = (8, 6, 4)
y_shape = (4, 2)
x = jnp.arange(np.prod(x_shape)).reshape(x_shape)
y = jnp.arange(np.prod(y_shape)).reshape(y_shape)
x = jnp.arange(math.prod(x_shape)).reshape(x_shape)
y = jnp.arange(math.prod(y_shape)).reshape(y_shape)
actual = f(x, y)
expected = x @ y
self.assertAllClose(actual, expected, check_dtypes=False)
@ -285,7 +285,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
actual = f(x, x + 1)
expected = x @ (x + 1)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -753,7 +753,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x + y + z + w
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
y = x * 2.
z = x * 3.
w = x * 4.
@ -822,7 +822,7 @@ class PJitTest(jtu.BufferDonationTestCase):
token = lax.outfeed(token, x, partitions=(P(1, nr_devices),))
return x
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
def _dispatch():
with jax.sharding.Mesh(devices, ['d']):
@ -863,7 +863,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
expected = x @ (x + 1)
lowered = f.lower(x, x + 1)
@ -896,7 +896,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
exe = f.lower(x, x + 1, a=1, b=2).compile()
out = exe(x, x + 1, a=1, b=2)
self.assertArraysEqual(out, x @ (x + 1))
@ -910,7 +910,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
exe = f.lower(x, x + 1).compile()
self.assertRaisesRegex(
@ -926,7 +926,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
x_f32 = x.astype(jnp.float32)
x_i32 = x.astype(jnp.int32)
exe = f.lower(x_f32, x_f32).compile()
@ -947,7 +947,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1)
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
@ -963,7 +963,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
@ -995,7 +995,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1).compile()
self.assertIsNotNone(f.compiler_ir())
@ -1008,7 +1008,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1).compile()
self.assertIsInstance(f.as_text(), (str, type(None)))
@ -1022,7 +1022,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1)
f.cost_analysis() # doesn't raise
@ -1036,7 +1036,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1).compile()
f.cost_analysis() # doesn't raise
@ -1050,7 +1050,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1).compile()
f.memory_analysis() # doesn't raise
@ -1063,7 +1063,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1).compile()
self.assertIsNotNone(f.runtime_executable())
@ -1088,7 +1088,7 @@ class PJitTest(jtu.BufferDonationTestCase):
shape = (8, 8)
aval = core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
exe = f.lower(aval, x).compile()
self.assertIsInstance(exe, stages.Compiled)
self.assertArraysEqual(exe(x, x), x @ x)
@ -2158,7 +2158,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
return x if c == 0 else x + 1
shape = (8, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
exe = f.lower(1, x).compile()
self.assertAllClose(exe(x), x + 1, check_dtypes=False)
@ -3205,7 +3205,7 @@ class PJitErrorTest(jtu.JaxTestCase):
def testNonDivisibleArgs(self, mesh, resources):
x = jnp.ones((3, 2))
spec = P(resources, None)
mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
mesh_size = str(math.prod([dim[1] for dim in mesh]))
error = re.compile(
r"One of pjit arguments.*" + spec_regex(spec) + r".*"
r"implies that the global size of its dimension 0 should be "
@ -3218,7 +3218,7 @@ class PJitErrorTest(jtu.JaxTestCase):
def testNonDivisibleOuts(self, mesh, resources):
x = jnp.ones((3, 2))
spec = P(resources, None)
mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
mesh_size = str(math.prod([dim[1] for dim in mesh]))
error = re.compile(
r"One of pjit outputs.*" + spec_regex(spec) + r".*"
r"implies that the global size of its dimension 0 should be "
@ -3456,7 +3456,7 @@ class PJitErrorTest(jtu.JaxTestCase):
return x
return h(x)
xshape = (2, 5, 6)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
x = jnp.arange(math.prod(xshape)).reshape(xshape)
with self.assertRaisesRegex(RuntimeError,
"Changing the physical mesh is not allowed.*"):
f(x)
@ -3506,7 +3506,7 @@ class UtilTest(jtu.JaxTestCase):
FakeDevice = namedtuple('FakeDevice', ['id'])
mesh_named_shape = OrderedDict([('a', 2), ('b', 3), ('c', 4), ('d', 7), ('e', 4)])
mesh_axes, mesh_shape = unzip2(mesh_named_shape.items())
devices = [FakeDevice(i) for i in range(np.prod(list(mesh_shape)))]
devices = [FakeDevice(i) for i in range(math.prod(mesh_shape))]
mesh = pxla.Mesh(np.array(devices).reshape(*mesh_shape), tuple(mesh_axes))
dims = 5

View File

@ -581,7 +581,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testAllToAll(self, split_axis, concat_axis):
pmap_in_axis = 0
shape = (jax.device_count(),) * 3
x = np.arange(np.prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
@partial(self.pmap, axis_name='i')
def f(x):
@ -602,7 +602,7 @@ class PythonPmapTest(jtu.JaxTestCase):
raise SkipTest("test requires at least four devices")
pmap_in_axis = 0
shape = (4, 4, 4)
x = np.arange(np.prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
@partial(self.pmap, axis_name='i')
@partial(self.pmap, axis_name='j')
@ -2400,7 +2400,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
verify_ref()
shape = (jax.device_count(),) * 5
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
self.assertAllClose(pmap(vmap(f, in_axes=vmap_axis), axis_name='i')(x),
reference(x, split_axis, concat_axis, vmap_axis))
@ -2413,7 +2413,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis)
shape = (jax.device_count(),) * 4
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
self.assertAllClose(pmap(f, axis_name='i')(x),
vmap(f, axis_name='i')(x))
@ -2431,7 +2431,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
return lax.all_to_all(x, axes, split_axis=split_axis, concat_axis=concat_axis)
shape = (2, 2, 4, 4, 4)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
self.assertAllClose(pmap(pmap(f, axis_name='j'), axis_name='i')(x),
vmap(vmap(f, axis_name='j'), axis_name='i')(x))
@ -2456,7 +2456,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
shape = (4, 4, 8)
x = jnp.arange(np.prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
f = partial(prim, axis_name='i', tiled=tiled)
self.assertAllClose(vmap(f, axis_name='i')(x), pmap(f, axis_name='i')(x))

View File

@ -14,6 +14,7 @@
from functools import partial
import math
from unittest import SkipTest, skipIf
from typing import Any, Tuple, NamedTuple, Optional
import zlib
@ -689,7 +690,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for ndim in [1 if is_range else len(input_range_or_shape)]
for axis in range(-ndim, ndim or 1)
for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]]
if replace or np.prod(shape) <= ninputs
if replace or math.prod(shape) <= ninputs
],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
weighted=[True, False],
@ -702,7 +703,7 @@ class LaxRandomTest(jtu.JaxTestCase):
key = self.seed_prng(0)
is_range = type(input_range_or_shape) is int
x = (input_range_or_shape if is_range else
self.rng().permutation(np.arange(np.prod(
self.rng().permutation(np.arange(math.prod(
input_range_or_shape), dtype=dtype)).reshape(input_range_or_shape))
N = x if is_range else x.shape[axis]
if weighted:
@ -718,7 +719,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertEqual(np_shape, sample.shape)
if not replace and shape:
def lsort(x):
if not np.prod(x.shape): return x
if not math.prod(x.shape): return x
ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
return jnp.take(x, ind, axis)
self.assertArraysEqual(lsort(sample), lsort(np.unique(sample, axis=axis)))
@ -741,7 +742,7 @@ class LaxRandomTest(jtu.JaxTestCase):
is_range = type(range_or_shape) is int
x = (range_or_shape if is_range else
self.rng().permutation(np.arange(
np.prod(range_or_shape), dtype=dtype)).reshape(range_or_shape))
math.prod(range_or_shape), dtype=dtype)).reshape(range_or_shape))
shape = ((range_or_shape,) if is_range else range_or_shape)
x_ = np.copy(x)
rand = lambda key, x: random.permutation(key, x, axis, independent=independent)
@ -750,7 +751,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertFalse(np.all(perm == x)) # seems unlikely!
arr = np.arange(x) if is_range else x
def lsort(x):
if not np.prod(x.shape): return x
if not math.prod(x.shape): return x
ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
return jnp.take(x, ind, axis)
if not independent:
@ -1602,7 +1603,7 @@ class KeyArrayTest(jtu.JaxTestCase):
def make_keys(self, *shape, seed=None):
if seed is None:
seed = 28
seeds = seed + jnp.arange(np.prod(shape), dtype=jnp.uint32)
seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32)
make_key = partial(prng.seed_with_impl, prng.threefry_prng_impl)
return jnp.reshape(jax.vmap(make_key)(seeds), shape)

View File

@ -15,6 +15,7 @@
import contextlib
from functools import partial
import itertools
import math
import operator
import random
import unittest
@ -174,7 +175,7 @@ all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
def _rand_sparse(shape, dtype, nse=nse):
rand = rand_method(rng)
size = np.prod(shape).astype(int)
size = math.prod(shape)
if 0 <= nse < 1:
nse = nse * size
nse = min(size, int(nse))
@ -1554,7 +1555,7 @@ class BCOOTest(sptu.SparseTestCase):
self.assertEqual(mat.n_dense, out.n_dense)
# Unnecessary padding eliminated
max_nse = np.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse])
max_nse = math.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse])
self.assertLessEqual(out.nse, max_nse)
@jtu.sample_product(
@ -1589,7 +1590,7 @@ class BCOOTest(sptu.SparseTestCase):
self.assertEqual(mat.n_dense, out.n_dense)
# Unnecessary padding eliminated
max_nse = np.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse])
max_nse = math.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse])
self.assertLessEqual(out.nse, max_nse)
@jtu.sample_product(
@ -1644,7 +1645,7 @@ class BCOOTest(sptu.SparseTestCase):
[dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, nse=nse)
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for layout in iter_sparse_layouts(shape)
for nse in [None, np.prod(shape) - 1]
for nse in [None, math.prod(shape) - 1]
],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
remove_zeros=[True, False],
@ -2106,7 +2107,7 @@ class BCOOTest(sptu.SparseTestCase):
M = rng(shape, dtype)
def make_bcoo(M):
return sparse_bcoo._bcoo_fromdense(M, nse=np.prod(M.shape[:-1], dtype=int), n_dense=1)
return sparse_bcoo._bcoo_fromdense(M, nse=math.prod(M.shape[:-1]), n_dense=1)
todense = partial(sparse_bcoo._bcoo_todense, spinfo=sparse_util.SparseInfo(shape))
@ -2586,7 +2587,7 @@ class SparseObjectTest(sptu.SparseTestCase):
self.assertIsInstance(M, Obj)
self.assertEqual(M.shape, shape)
self.assertEqual(M.size, np.prod(shape))
self.assertEqual(M.size, math.prod(shape))
self.assertEqual(M.ndim, len(shape))
self.assertEqual(M.dtype, dtype)
self.assertEqual(M.nse, (M.todense() != 0).sum())
@ -2757,8 +2758,8 @@ class SparseRandomTest(sptu.SparseTestCase):
batch_shape, sparse_shape, dense_shape = split_list(shape, [n_batch, n_sparse])
approx_expected_num_nonzero = (
np.ceil(0.2 * np.prod(sparse_shape))
* np.prod(batch_shape) * np.prod(dense_shape))
np.ceil(0.2 * math.prod(sparse_shape))
* math.prod(batch_shape) * math.prod(dense_shape))
num_nonzero = (mat_dense != 0).sum()
self.assertAlmostEqual(int(num_nonzero), approx_expected_num_nonzero, delta=2)

View File

@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
import math
import operator
from absl.testing import absltest
@ -35,7 +36,7 @@ config.parse_flags_with_absl()
def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
def _rand_sparse(shape, dtype, nse=nse):
rand = rand_method(rng)
size = np.prod(shape).astype(int)
size = math.prod(shape)
if 0 <= nse < 1:
nse = nse * size
nse = min(size, int(nse))

View File

@ -278,9 +278,9 @@ class XMapTest(XMapTestCase):
out_axes=({0: 'a', 1: 'b'}, ['c', ...]),
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
ashape = (16, 8, 5)
a = jnp.arange(np.prod(ashape)).reshape(ashape)
a = jnp.arange(math.prod(ashape)).reshape(ashape)
bshape = (2, 7)
b = jnp.arange(np.prod(bshape)).reshape(bshape)
b = jnp.arange(math.prod(bshape)).reshape(bshape)
c, d = fm(a, b)
self.assertAllClose(c, a * 2)
self.assertAllClose(d, b * 4)
@ -292,9 +292,9 @@ class XMapTest(XMapTestCase):
out_axes=(['b', ...], {0: 'c'}),
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
ashape = (16, 8, 5)
a = jnp.arange(np.prod(ashape)).reshape(ashape)
a = jnp.arange(math.prod(ashape)).reshape(ashape)
bshape = (2, 7)
b = jnp.arange(np.prod(bshape)).reshape(bshape)
b = jnp.arange(math.prod(bshape)).reshape(bshape)
c, d = fm(a, b)
self.assertAllClose(c, (a * 2).sum(0))
self.assertAllClose(d, b * 4)
@ -330,7 +330,7 @@ class XMapTest(XMapTestCase):
fm = xmap(f, in_axes=['a', ...], out_axes=({}, {1: 'a'}),
axis_resources={'a': ('x', 'y')})
vshape = (4, 5)
v = jnp.arange(np.prod(vshape)).reshape(vshape)
v = jnp.arange(math.prod(vshape)).reshape(vshape)
ans, ans2 = fm(v)
self.assertAllClose(ans, (v * 2).sum(0))
self.assertAllClose(ans2, v.T * 4)
@ -344,7 +344,7 @@ class XMapTest(XMapTestCase):
fyx = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
axis_resources={'a': ('y', 'x')})
vshape = (4, 5)
v = jnp.arange(np.prod(vshape)).reshape(vshape)
v = jnp.arange(math.prod(vshape)).reshape(vshape)
zxy = fxy(v)
zxy_op_sharding = zxy.sharding._to_xla_op_sharding(zxy.ndim)
self.assertListEqual(zxy_op_sharding.tile_assignment_dimensions, [1, 4])
@ -419,7 +419,7 @@ class XMapTest(XMapTestCase):
return h(y)
xshape = (4, 2, 5)
x = jnp.arange(np.prod(xshape), dtype=float).reshape(xshape)
x = jnp.arange(math.prod(xshape), dtype=float).reshape(xshape)
y = f(x)
self.assertAllClose(
y, ((jnp.sin(x * 2) *
@ -825,7 +825,7 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
in_axes=[None, 'a', 'b', ...], out_axes=(['a', 'b', ...], {}),
axis_resources={'a': 'x', 'b': 'y'})
xshape = (8, 2, 4, 5)
x = jnp.arange(np.prod(xshape), dtype=float).reshape(xshape)
x = jnp.arange(math.prod(xshape), dtype=float).reshape(xshape)
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo)
self.assertIsNot(match, None)
@ -1649,7 +1649,7 @@ class XMapErrorTest(jtu.JaxTestCase):
return x
return h(x)
xshape = (2, 5, 6)
x = jnp.arange(np.prod(xshape)).reshape(xshape)
x = jnp.arange(math.prod(xshape)).reshape(xshape)
with self.assertRaisesRegex(RuntimeError,
"Changing the physical mesh is not allowed.*"):
f(x)