Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00

Change np.prod -> math.prod

Why? np.prod is generally used here for static operations on shapes, but it has an unfortunate corner-case behavior: np.prod([]) returns a float. math.prod is available as of Python 3.8 and is a better solution here.
commit 5521423d92 (parent fb46d3d084)
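To make the corner case concrete, here is a small illustration (not part of the commit; this is standard NumPy and Python standard-library behavior):

import math
import numpy as np

# np.prod of an empty shape returns the float 1.0 (numpy.float64),
# which can silently promote integer size arithmetic to floats.
print(np.prod(()))        # 1.0
print(type(np.prod(())))  # <class 'numpy.float64'>

# math.prod (Python 3.8+) returns the int 1 and keeps sizes integral.
print(math.prod(()))      # 1

# On a non-empty static shape both give the same value,
# but math.prod needs no int() wrapper or dtype= argument.
shape = (2, 3, 4)
print(np.prod(shape), math.prod(shape))  # 24 24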
@@ -15,6 +15,7 @@

import enum
import functools
+ import math
import operator

import google_benchmark
@@ -54,7 +55,7 @@ def required_devices(num_devices_required):


def create_mesh(shape, axis_names, state):
- size = np.prod(shape)
+ size = math.prod(shape)
if len(jax.devices()) < size:
state.skip_with_error(f"Requires {size} devices")
return None
@@ -421,7 +422,7 @@ def sda_index_8(state):
def _sparse_bcoo_fromdense(state, jit: bool = False, compile: bool = False):
shape = (2000, 2000)
nse = 10000
- size = np.prod(shape)
+ size = math.prod(shape)
rng = np.random.RandomState(1701)
data = rng.randn(nse)
indices = np.unravel_index(
@@ -460,7 +461,7 @@ def sparse_bcoo_fromdense_compile(state):
def _sparse_bcoo_todense(state, jit: bool = False, compile: bool = False):
shape = (2000, 2000)
nse = 10000
- size = np.prod(shape)
+ size = math.prod(shape)
rng = np.random.RandomState(1701)
data = rng.randn(nse)
indices = np.unravel_index(
@@ -600,7 +601,7 @@ def bench_addressable_shards_index(state):
if mesh is None:
return
shape = (8, 2)
- inp = np.arange(np.prod(shape)).reshape(shape)
+ inp = np.arange(math.prod(shape)).reshape(shape)
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
arr = jax.device_put(inp, s)

@@ -614,7 +615,7 @@ def bench_addressable_shards_replica_id(state):
if mesh is None:
return
shape = (64, 32)
- inp = np.arange(np.prod(shape)).reshape(shape)
+ inp = np.arange(math.prod(shape)).reshape(shape)
s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y'))
arr = jax.device_put(inp, s)

@@ -785,7 +786,7 @@ def pjit_aot_4000_device(state):
def host_local_array_to_global_array(state):
global_mesh = create_mesh((4, 2), ('x', 'y'), state)
input_shape = (8, 2)
- input_data = np.arange(np.prod(input_shape)).reshape(input_shape)
+ input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
in_pspec = jax.sharding.PartitionSpec('x', 'y')

while state:
@@ -1869,8 +1869,8 @@ class DimensionHandler:
Raise InconclusiveDimensionOperation if there is no such integer for all
contexts,
"""
- sz1 = int(np.prod(s1))
- sz2 = int(np.prod(s2))
+ sz1 = math.prod(s1)
+ sz2 = math.prod(s2)
if sz1 == 0 and sz2 == 0:
return 1
if sz1 % sz2:
@@ -1624,10 +1624,10 @@ def manual_proto(
tad_perm = ([axis_order[a] for a in replicated_axes] +
[axis_order[a] for a in manual_axes])
tad_shape = [1] * aval.ndim
- tad_shape.append(int(np.prod([named_mesh_shape[a] for a in replicated_axes], dtype=int)))
- tad_shape.append(int(np.prod([named_mesh_shape[a] for a in manual_axes], dtype=int)))
+ tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes]))
+ tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes]))

- raw_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
+ raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = tad_shape
@@ -1256,7 +1256,7 @@ mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu')
@partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)')
def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array:
m = lu.shape[0]
- x = jnp.reshape(b, (m, np.prod(b.shape[1:])))
+ x = jnp.reshape(b, (m, math.prod(b.shape[1:])))
if trans == 0:
x = x[permutation, :]
x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True)
@@ -14,6 +14,7 @@

import enum
from functools import partial
+ import math
from typing import Callable, List, NamedTuple, Optional, Sequence, Tuple, Union
import weakref

@@ -1891,7 +1892,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts,
# on the IDs.
ids_shape = np.array(updates.shape, dtype=np.int64)
ids_shape[dnums.update_window_dims,] = 1
- num_ids = np.prod(ids_shape)
+ num_ids = math.prod(ids_shape)
id_dtype = np.uint32 if (num_ids + 1) < np.iinfo(np.uint32).max else np.uint64
update_ids = lax.add(lax.reshape(lax.iota(id_dtype, num_ids), ids_shape),
lax._ones(updates, dtype=id_dtype))
@@ -19,6 +19,7 @@ from collections import OrderedDict, abc
from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
NamedTuple, Union, Sequence, Mapping)
from functools import wraps, partial, partialmethod, lru_cache
+ import math

from jax import lax
from jax import numpy as jnp
@@ -1573,9 +1574,9 @@ def _get_axis_resource_count(
if local_res_shape is None:
nlocal = None
else:
- nlocal = int(np.prod(map(local_res_shape.get, resources), dtype=np.int64))
+ nlocal = math.prod(map(local_res_shape.get, resources))
resource_count_map[axis] = ResourceCount(
- int(np.prod(map(global_res_shape.get, resources), dtype=np.int64)),
+ math.prod(map(global_res_shape.get, resources)),
nlocal, distributed)
return resource_count_map

@@ -18,6 +18,7 @@ from __future__ import annotations
import collections
import contextlib
import functools
+ import math
import threading
from typing import Any, Hashable, NamedTuple, Set, Sequence, Tuple, Union

@@ -199,7 +200,7 @@ class Mesh(contextlib.ContextDecorator):

@property
def size(self):
- return np.prod(list(self.shape.values()))
+ return math.prod(self.shape.values())

@property
def empty(self):
@@ -163,15 +163,15 @@ def _compute_fans(shape: core.NamedShape,
if isinstance(in_axis, int):
in_size = shape[in_axis]
else:
- in_size = int(np.prod([shape[i] for i in in_axis]))
+ in_size = math.prod([shape[i] for i in in_axis])
if isinstance(out_axis, int):
out_size = shape[out_axis]
else:
- out_size = int(np.prod([shape[i] for i in out_axis]))
+ out_size = math.prod([shape[i] for i in out_axis])
if isinstance(batch_axis, int):
batch_size = shape[batch_axis]
else:
- batch_size = int(np.prod([shape[i] for i in batch_axis]))
+ batch_size = math.prod([shape[i] for i in batch_axis])
receptive_field_size = shape.total / in_size / out_size / batch_size
fan_in = in_size * receptive_field_size
fan_out = out_size * receptive_field_size
@@ -747,7 +747,7 @@ def _quantile(a: Array, q: Array, axis: Optional[Union[int, Tuple[int, ...]]],
dimensions[i], dimensions[s] = dimensions[s], dimensions[i]
do_not_touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx not in axis)
touch_shape = tuple(x for idx,x in enumerate(a.shape) if idx in axis)
- a = lax.reshape(a, do_not_touch_shape + (int(np.prod(touch_shape)),), dimensions)
+ a = lax.reshape(a, do_not_touch_shape + (math.prod(touch_shape),), dimensions)
axis = _canonicalize_axis(-1, a.ndim)
else:
axis = _canonicalize_axis(axis, a.ndim)
@@ -1043,7 +1043,7 @@ def iota_2x32_shape(shape):
Setting aside representation, this function essentially computes the
equivalent of::

- jax.lax.iota(dtype=np.uint64, size=np.prod(shape)).reshape(shape)
+ jax.lax.iota(dtype=np.uint64, size=math.prod(shape)).reshape(shape)

However:

@@ -1069,7 +1069,7 @@ def iota_2x32_shape(shape):
[ 8, 9, 10, 11]], dtype=uint32)]

>>> def reshaped_iota(shape):
- ... return lax.iota(size=np.prod(shape), dtype=np.uint32).reshape(shape)
+ ... return lax.iota(size=math.prod(shape), dtype=np.uint32).reshape(shape)
...
>>> reshaped_iota((3, 4))
Array([[ 0, 1, 2, 3],
@@ -13,6 +13,7 @@
# limitations under the License.

from functools import partial
+ import math
import operator
from typing import Callable, Optional, Tuple, Union, Sequence
import warnings
@@ -232,7 +233,7 @@ def _fft_helper(x: Array, win: Array, detrend_func: Callable[[Array], Array],
else:
step = nperseg - noverlap
batch_shape = list(batch_shape)
- x = x.reshape((int(np.prod(batch_shape)), signal_length))[..., np.newaxis]
+ x = x.reshape((math.prod(batch_shape)), signal_length)[..., np.newaxis]
result = jax.lax.conv_general_dilated_patches(
x, (nperseg,), (step,),
'VALID',
@@ -579,7 +580,7 @@ def _overlap_and_add(x: Array, step_size: int) -> Array:
raise ValueError('Input must have (..., frames, frame_length) shape.')

*batch_shape, nframes, segment_len = x.shape
- flat_batchsize = np.prod(batch_shape, dtype=np.int64)
+ flat_batchsize = math.prod(batch_shape)
x = x.reshape((flat_batchsize, nframes, segment_len))
output_size = step_size * (nframes - 1) + segment_len
nstep_per_segment = 1 + (segment_len - 1) // step_size
@@ -75,7 +75,7 @@ def _sharding_spec_mesh_shape(self):


def get_logical_mesh_ids(mesh_shape):
- return np.arange(np.prod(mesh_shape)).reshape(mesh_shape)
+ return np.arange(math.prod(mesh_shape)).reshape(mesh_shape)


_MeshAxisName = Any
@@ -123,7 +123,7 @@ def sharding_spec_sharding_proto(
assert mesh_shape[maxis] == nchunks
mesh_permutation.append(maxis)
next_sharded_axis += 1
- new_mesh_shape.append(int(np.prod(sharding.chunks)))
+ new_mesh_shape.append(math.prod(sharding.chunks))
elif isinstance(sharding, Unstacked):
raise RuntimeError("Cannot convert unstacked sharding specs to XLA OpSharding")
else:
@@ -187,7 +187,7 @@ def _sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
axis_indices.append(range(axis_size))
shard_indices_shape.append(axis_size)
elif isinstance(sharding, Chunked):
- total_chunks = int(np.prod(sharding.chunks))
+ total_chunks = math.prod(sharding.chunks)
shard_size, ragged = divmod(axis_size, total_chunks)
assert not ragged, (axis_size, total_chunks, dim)
axis_indices.append([slice(i * shard_size, (i + 1) * shard_size)
@@ -487,7 +487,7 @@ def rand_fullrange(rng, standardize_nans=False):
"""Random numbers that span the full range of available bits."""
def gen(shape, dtype, post=lambda x: x):
dtype = np.dtype(dtype)
- size = dtype.itemsize * np.prod(_dims_of_shape(shape), dtype=int)
+ size = dtype.itemsize * math.prod(_dims_of_shape(shape))
vals = rng.randint(0, np.iinfo(np.uint8).max, size=size, dtype=np.uint8)
vals = post(vals).view(dtype)
if shape is PYTHON_SCALAR_SHAPE:
@@ -498,6 +498,7 @@ import atexit
import functools
import itertools
import logging
+ import math
import threading
import traceback
from typing import (Any, Callable, Dict, List, Optional, Sequence,
@@ -1380,7 +1381,7 @@ def _add_transform(params: Dict, name: str, *transform_params) -> Dict:


def _aval_is_empty(aval) -> bool:
- return np.prod(aval.shape) == 0
+ return math.prod(aval.shape) == 0

def _instantiate_zeros(tan, arg):
"""Turn special ad.zero tangents into arrays of 0s for sending to host.
@@ -15,6 +15,7 @@
import builtins
import dataclasses
from functools import partial, wraps
+ import math
import string
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

@@ -665,7 +666,7 @@ def _reduce_monoid(operand, window_dimensions, window_strides, padding,
# TODO(marcvanzee): This may give very large deviations on TPU when using
# floats as inputs. Alternatively, we could implement this using a
# convolution with an all-1's kernel.
- return tf.multiply(tf_pool(operand, "AVG"), np.prod(window_dimensions))
+ return tf.multiply(tf_pool(operand, "AVG"), math.prod(window_dimensions))


def _reduce_window(*args, jaxpr, consts, window_dimensions,
@@ -951,7 +952,7 @@ def _gather_generate_indices(shape: Tuple[int, ...]):
"""
Returns the indices of the according to `shape`:
each element in the output is the index of an element of an array
- of the provided shape. The result's shape is (np.prod(shape), len(shape))
+ of the provided shape. The result's shape is (math.prod(shape), len(shape))

For example, given shape (2,2) it returns (0,0),(0,1),(1,0),(1,1)
"""
@@ -14,6 +14,7 @@
"""Provides JAX and TensorFlow interoperation APIs."""
from functools import partial
import contextlib
+ import math
import operator
import os
import re
@@ -3094,7 +3095,7 @@ def split_to_logical_devices(tensor: TfVal,
# of _shard_values.
if partition_dimensions is None:
return xla_sharding.replicate(tensor, use_sharding_op=True)
- num_partition_splits = np.prod(partition_dimensions)
+ num_partition_splits = math.prod(partition_dimensions)
tile_assignment = np.arange(num_partition_splits).reshape(
partition_dimensions)
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
@@ -304,7 +304,7 @@ class _DimMon(dict):
assert a_l <= a_u
bounds.append((a_l ** exp, a_u ** exp))

- candidates = [np.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)]
+ candidates = [math.prod(atom_bounds) for atom_bounds in itertools.product(*bounds)]
return (min(*candidates), max(*candidates)) # type: ignore


@@ -696,8 +696,8 @@ class DimensionHandlerPoly(core.DimensionHandler):
return _ensure_poly(d1, "ge") >= d2

def divide_shape_sizes(self, s1: Shape, s2: Shape) -> DimSize:
- sz1 = np.prod(s1)
- sz2 = np.prod(s2)
+ sz1 = math.prod(s1)
+ sz2 = math.prod(s2)
if core.symbolic_equal_dim(sz1, sz2): # Takes care also of sz1 == sz2 == 0
return 1
err_msg = f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}"
@@ -62,6 +62,7 @@ import dataclasses
import datetime
from functools import partial
import itertools
+ import math
import os
import re
import sys
@@ -346,7 +347,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
@staticmethod
def eigh_input(shape, dtype):
# In order to keep inputs small, we construct the input programmatically
- operand = jnp.reshape(jnp.arange(np.prod(shape), dtype=dtype), shape)
+ operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
# Make operand self-adjoint
operand = (operand + jnp.conj(jnp.swapaxes(operand, -1, -2))) / 2.
return operand
@@ -417,7 +418,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
@staticmethod
def qr_harness(shape, dtype):
# In order to keep inputs small, we construct the input programmatically
- operand = jnp.reshape(jnp.arange(np.prod(shape), dtype=dtype), shape)
+ operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
return lax.linalg.qr(operand, full_matrices=True)

@parameterized.named_parameters(
@@ -458,7 +459,7 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(

@staticmethod
def lu_harness(shape, dtype):
- operand = jnp.reshape(jnp.arange(np.prod(shape), dtype=dtype), shape)
+ operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=dtype), shape)
return lax.linalg.lu(operand)

def test_tpu_Lu(self):
@@ -16,6 +16,7 @@
Specific JAX primitive conversion tests are in primitives_test."""
import collections
import contextlib
+ import math
import os
import re
from typing import Callable, Dict, Optional, Tuple
@@ -1286,7 +1287,7 @@ class Jax2TfTest(tf_test_util.JaxToTfTestCase):
return carry

shape = (3, 2)
- x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+ x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

jax_comp = jax.xla_computation(f_while)(x)
backend = jax._src.xla_bridge.get_backend()
@@ -1723,7 +1724,7 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
return x + y

shape = (8, 10)
- x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+ x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
in_axis_resources = (P("x"), P("x"))
out_axis_resources = None
res_jax = pjit.pjit(
@@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for the shape-polymorphic jax2tf conversion."""
import contextlib
+ import math
import unittest

from absl.testing import absltest, parameterized
@@ -1280,7 +1281,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
# https://github.com/google/jax/issues/7093
# Also issue #6975.
x_shape = (2, 3, 4)
- xi = np.arange(np.prod(x_shape), dtype=np.int16).reshape(x_shape)
+ xi = np.arange(math.prod(x_shape), dtype=np.int16).reshape(x_shape)
yf = xi.astype(np.float32)
xi_yf = (xi, yf)
zb = np.array([True, False], dtype=np.bool_)
@@ -1350,7 +1351,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
f_tf = tf.function(f_tf, autograph=False)
x_shape = (2, 3, 4)
- x = np.arange(np.prod(x_shape), dtype=np.int32).reshape(x_shape)
+ x = np.arange(math.prod(x_shape), dtype=np.int32).reshape(x_shape)

# When saving the model with gradients, we trace the gradient function
# and we used to get an error when creating zeros_like_aval for a
@@ -1378,7 +1379,7 @@ class ShapePolyTest(tf_test_util.JaxToTfTestCase):
jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],)),
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))

- jax2tf.convert(lambda x: jnp.reshape(x, (np.prod(x.shape),)),
+ jax2tf.convert(lambda x: jnp.reshape(x, (math.prod(x.shape),)),
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))

jax2tf.convert(lambda x: x + x.shape[0] + jnp.sin(x.shape[0]),
@@ -628,7 +628,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
raise unittest.SkipTest("TODO(b/268295912): ShardingRemover crash")

mesh = Mesh(self.devices, axis_names=('x'))
- a = np.arange(np.prod(4 * 4), dtype=np.float32).reshape((4, 4))
+ a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))

@partial(pjit.pjit,
in_shardings=(P('x', None),), out_shardings=P(None, 'x'))
@@ -680,7 +680,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
return lax.ppermute(b, 'x', perm=perm)

with mesh:
- a = np.arange(np.prod(4 * 4)).reshape((4, 4))
+ a = np.arange(4 * 4).reshape((4, 4))
res_jax = f_jax(a)
b0, b1 = np.split(a, 2, axis=0) # The shard_map splits on axis 0
b0, b1 = b1, b0
@@ -705,7 +705,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
if poly is not None:
raise unittest.SkipTest("TODO: Sharding custom calls lack shape refinement")
mesh = Mesh(self.devices, axis_names=('x'))
- a = np.arange(np.prod(4 * 4), dtype=np.float32).reshape((4, 4))
+ a = np.arange(4 * 4, dtype=np.float32).reshape((4, 4))

@partial(pjit.pjit,
in_shardings=(P('x', None),), out_shardings=P('x', None))
@@ -1575,7 +1575,7 @@ def bcoo_update_layout(mat: BCOO, *, n_batch: Optional[int] = None, n_dense: Opt
n = current_n_dense - n_dense
@partial(nfold_vmap, N=current_n_batch + 1)
def _update(d, i):
- new_d = d.reshape(np.prod(d.shape[:n]), *d.shape[n:])
+ new_d = d.reshape(math.prod(d.shape[:n]), *d.shape[n:])
meshes = jnp.meshgrid(*(jnp.arange(s, dtype=i.dtype) for s in d.shape[:n]),
indexing='ij')
new_i = jnp.column_stack([jnp.broadcast_to(i, (new_d.shape[0], i.size)),
@@ -1583,10 +1583,10 @@ def bcoo_update_layout(mat: BCOO, *, n_batch: Optional[int] = None, n_dense: Opt
return new_d, new_i
new_data, new_indices = _update(new_data, new_indices)
new_data = new_data.reshape(*new_data.shape[:current_n_batch],
- np.prod(new_data.shape[current_n_batch:current_n_batch + 2]),
+ math.prod(new_data.shape[current_n_batch:current_n_batch + 2]),
*new_data.shape[current_n_batch + 2:])
new_indices = new_indices.reshape(*new_indices.shape[:current_n_batch],
- np.prod(new_indices.shape[current_n_batch: current_n_batch + 2]),
+ math.prod(new_indices.shape[current_n_batch: current_n_batch + 2]),
*new_indices.shape[current_n_batch + 2:])
current_n_dense = n_dense

@@ -1595,10 +1595,10 @@ def bcoo_update_layout(mat: BCOO, *, n_batch: Optional[int] = None, n_dense: Opt
@partial(nfold_vmap, N=n_batch)
def _update(d, i):
nse = i.shape[-2]
- new_d = d.reshape(np.prod(d.shape[:n + 1]), *d.shape[n + 1:])
+ new_d = d.reshape(math.prod(d.shape[:n + 1]), *d.shape[n + 1:])
meshes = jnp.meshgrid(*(jnp.arange(d, dtype=i.dtype) for d in (*i.shape[:n], nse)),
indexing='ij')
- new_i = i.reshape(np.prod(i.shape[:n + 1]), *i.shape[n + 1:])
+ new_i = i.reshape(math.prod(i.shape[:n + 1]), *i.shape[n + 1:])
new_i = jnp.column_stack((*(m.ravel() for m in meshes[:-1]), new_i))
return new_d, new_i
new_data, new_indices = _update(new_data, new_indices)
@@ -1669,7 +1669,7 @@ def _bcoo_broadcast_in_dim(data: Array, indices: Array, *, spinfo: SparseInfo, s
new_n_dense = props.n_dense and len(shape) - min(broadcast_dimensions[-props.n_dense:])
new_n_sparse = len(shape) - new_n_batch - new_n_dense

- if np.prod(spinfo.shape[props.n_batch: props.n_batch + props.n_sparse]) != np.prod(shape[new_n_batch:new_n_batch + new_n_sparse]):
+ if math.prod(spinfo.shape[props.n_batch: props.n_batch + props.n_sparse]) != math.prod(shape[new_n_batch:new_n_batch + new_n_sparse]):
raise NotImplementedError("Adding sparse dimensions with lengths != 1")
new_data, new_indices = data, indices

@@ -1784,8 +1784,8 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[in
batch_shape, sparse_shape, dense_shape = split_list(mat.shape, [mat.n_batch, mat.n_sparse])
batch_perm, sparse_perm, dense_perm = _validate_permutation(
mat.data, mat.indices, dimensions or tuple(range(mat.ndim)), mat.shape)
- batch_size = np.prod(batch_shape, dtype=int)
- sparse_size = np.prod(sparse_shape, dtype=int)
+ batch_size = math.prod(batch_shape)
+ sparse_size = math.prod(sparse_shape)

cuml_shape = np.cumprod(new_sizes)
if batch_size != 1 and batch_size not in cuml_shape:
@@ -1797,7 +1797,7 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[in

i1 = cuml_shape.searchsorted(batch_size, side='right')
i2 = cuml_shape.searchsorted(batch_size * sparse_size, side='right')
- new_batch_shape, new_sparse_shape, new_dense_shape = split_list(new_sizes, [i1, i2])
+ new_batch_shape, new_sparse_shape, new_dense_shape = split_list(new_sizes, [int(i1), int(i2)])

# Reshape batch & dense dimensions: this is accomplished via a standard reshape.
data = lax.reshape(
@@ -1949,7 +1949,7 @@ def bcoo_slice(mat: BCOO, *, start_indices: Sequence[int], limit_indices: Sequen
keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense))
new_data = jnp.where(keep_data, new_data, 0)

- new_nse = int(np.prod(new_shape_sparse))
+ new_nse = math.prod(new_shape_sparse)
if mat.nse > new_nse:
new_data, new_indices = _bcoo_sum_duplicates(
new_data, new_indices, spinfo=SparseInfo(shape=new_shape), nse=new_nse)
@@ -2029,8 +2029,8 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq
keep_data = lax.expand_dims(keep[..., 0], range(mat.n_batch + 1, mat.n_batch + 1 + mat.n_dense))
new_data = jnp.where(keep_data, new_data, 0)

- if mat.nse > np.prod(size_sparse):
- new_nse = int(np.prod(size_sparse))
+ if mat.nse > math.prod(size_sparse):
+ new_nse = math.prod(size_sparse)
new_data, new_indices = _bcoo_sum_duplicates(
new_data, new_indices, spinfo=SparseInfo(shape=new_shape), nse=new_nse)

@@ -2099,7 +2099,7 @@ def _bcoo_reduce_sum(data: Array, indices: Array, *, spinfo: SparseInfo, axes: S

new_batch_dims = tuple(sorted(set(range(n_batch)) - batch_axes))
new_batch_shape = tuple(data.shape[i] for i in new_batch_dims)
- new_nse = int(nse * np.prod([data.shape[i] for i in batch_axes]))
+ new_nse = nse * math.prod([data.shape[i] for i in batch_axes])

data = lax.reshape(data,
(*new_batch_shape, new_nse, *data.shape[n_batch + 1:]),
@@ -908,7 +908,7 @@ class ShardingTest(jtu.JaxTestCase):
self.skipTest('Test needs >= 4 devices.')
ps = jax.sharding.PmapSharding.default(shape, sharded_dim)

- inp = jnp.arange(np.prod(shape)).reshape(shape)
+ inp = jnp.arange(math.prod(shape)).reshape(shape)
compiled = jax.pmap(lambda x: x, in_axes=sharded_dim).lower(inp).compile()
pmap_in_sharding, = compiled._executable.unsafe_call.in_handler.in_shardings

@@ -13,6 +13,7 @@
# limitations under the License.

from absl.testing import absltest
+ import math
import unittest

import numpy as np
@@ -206,7 +207,7 @@ mlir.register_lowering(

def make_sparse_array(rng, shape, dtype, nnz=0.2):
mat = rng(shape, dtype)
- size = int(np.prod(shape))
+ size = math.prod(shape)
if 0 < nnz < 1:
nnz = nnz * size
nnz = int(nnz)
@@ -3805,16 +3805,15 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testTakeAlongAxis(self, x_shape, i_shape, dtype, index_dtype, axis):
rng = jtu.rand_default(self.rng())

- i_shape = np.array(i_shape)
+ i_shape = list(i_shape)
if axis is None:
- i_shape = [np.prod(i_shape, dtype=np.int64)]
+ i_shape = [math.prod(i_shape)]
else:
# Test the case where the size of the axis doesn't necessarily broadcast.
i_shape[axis] *= 3
- i_shape = list(i_shape)
def args_maker():
x = rng(x_shape, dtype)
- n = np.prod(x_shape, dtype=np.int32) if axis is None else x_shape[axis]
+ n = math.prod(x_shape) if axis is None else x_shape[axis]
if np.issubdtype(index_dtype, np.unsignedinteger):
index_rng = jtu.rand_int(self.rng(), 0, n)
else:
@@ -641,7 +641,7 @@ class LaxTest(jtu.JaxTestCase):
if c == 'N':
self.assertEqual(out_c, patch_c)
elif c == 'C':
- self.assertEqual(out_c * np.prod(filter_shape), patch_c)
+ self.assertEqual(out_c * math.prod(filter_shape), patch_c)
else:
self.assertEqual(out_c, patch_c * filter_shape[filter_spec.index(c)])

@@ -226,8 +226,8 @@ class PJitTest(jtu.BufferDonationTestCase):

x_shape = (8, 6, 4)
y_shape = (4, 2)
- x = jnp.arange(np.prod(x_shape)).reshape(x_shape)
- y = jnp.arange(np.prod(y_shape)).reshape(y_shape)
+ x = jnp.arange(math.prod(x_shape)).reshape(x_shape)
+ y = jnp.arange(math.prod(y_shape)).reshape(y_shape)
actual = f(x, y)
expected = x @ y
self.assertAllClose(actual, expected, check_dtypes=False)
@@ -285,7 +285,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
actual = f(x, x + 1)
expected = x @ (x + 1)
self.assertAllClose(actual, expected, check_dtypes=False)
@@ -753,7 +753,7 @@ class PJitTest(jtu.BufferDonationTestCase):

return x + y + z + w

- x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+ x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
y = x * 2.
z = x * 3.
w = x * 4.
@@ -822,7 +822,7 @@ class PJitTest(jtu.BufferDonationTestCase):
token = lax.outfeed(token, x, partitions=(P(1, nr_devices),))
return x

- x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+ x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

def _dispatch():
with jax.sharding.Mesh(devices, ['d']):
@@ -863,7 +863,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
expected = x @ (x + 1)

lowered = f.lower(x, x + 1)
@@ -896,7 +896,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
exe = f.lower(x, x + 1, a=1, b=2).compile()
out = exe(x, x + 1, a=1, b=2)
self.assertArraysEqual(out, x @ (x + 1))
@@ -910,7 +910,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
exe = f.lower(x, x + 1).compile()

self.assertRaisesRegex(
@@ -926,7 +926,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
x_f32 = x.astype(jnp.float32)
x_i32 = x.astype(jnp.int32)
exe = f.lower(x_f32, x_f32).compile()
@@ -947,7 +947,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1)
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
@@ -963,7 +963,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
@@ -995,7 +995,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1).compile()
self.assertIsNotNone(f.compiler_ir())

@@ -1008,7 +1008,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1).compile()
self.assertIsInstance(f.as_text(), (str, type(None)))

@@ -1022,7 +1022,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1)
f.cost_analysis() # doesn't raise

@@ -1036,7 +1036,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1).compile()
f.cost_analysis() # doesn't raise

@@ -1050,7 +1050,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
f = f.lower(x, x + 1).compile()
f.memory_analysis() # doesn't raise

@@ -1063,7 +1063,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)

f = f.lower(x, x + 1).compile()
self.assertIsNotNone(f.runtime_executable())
@@ -1088,7 +1088,7 @@ class PJitTest(jtu.BufferDonationTestCase):

shape = (8, 8)
aval = core.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
exe = f.lower(aval, x).compile()
self.assertIsInstance(exe, stages.Compiled)
self.assertArraysEqual(exe(x, x), x @ x)
@@ -2158,7 +2158,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
return x if c == 0 else x + 1

shape = (8, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
exe = f.lower(1, x).compile()

self.assertAllClose(exe(x), x + 1, check_dtypes=False)
@@ -3205,7 +3205,7 @@ class PJitErrorTest(jtu.JaxTestCase):
def testNonDivisibleArgs(self, mesh, resources):
x = jnp.ones((3, 2))
spec = P(resources, None)
- mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
+ mesh_size = str(math.prod([dim[1] for dim in mesh]))
error = re.compile(
r"One of pjit arguments.*" + spec_regex(spec) + r".*"
r"implies that the global size of its dimension 0 should be "
@@ -3218,7 +3218,7 @@ class PJitErrorTest(jtu.JaxTestCase):
def testNonDivisibleOuts(self, mesh, resources):
x = jnp.ones((3, 2))
spec = P(resources, None)
- mesh_size = str(np.prod([dim[1] for dim in mesh], dtype=np.int64))
+ mesh_size = str(math.prod([dim[1] for dim in mesh]))
error = re.compile(
r"One of pjit outputs.*" + spec_regex(spec) + r".*"
r"implies that the global size of its dimension 0 should be "
@@ -3456,7 +3456,7 @@ class PJitErrorTest(jtu.JaxTestCase):
return x
return h(x)
xshape = (2, 5, 6)
- x = jnp.arange(np.prod(xshape)).reshape(xshape)
+ x = jnp.arange(math.prod(xshape)).reshape(xshape)
with self.assertRaisesRegex(RuntimeError,
"Changing the physical mesh is not allowed.*"):
f(x)
@@ -3506,7 +3506,7 @@ class UtilTest(jtu.JaxTestCase):
FakeDevice = namedtuple('FakeDevice', ['id'])
mesh_named_shape = OrderedDict([('a', 2), ('b', 3), ('c', 4), ('d', 7), ('e', 4)])
mesh_axes, mesh_shape = unzip2(mesh_named_shape.items())
- devices = [FakeDevice(i) for i in range(np.prod(list(mesh_shape)))]
+ devices = [FakeDevice(i) for i in range(math.prod(mesh_shape))]
mesh = pxla.Mesh(np.array(devices).reshape(*mesh_shape), tuple(mesh_axes))

dims = 5
@@ -581,7 +581,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testAllToAll(self, split_axis, concat_axis):
pmap_in_axis = 0
shape = (jax.device_count(),) * 3
- x = np.arange(np.prod(shape)).reshape(shape)
+ x = np.arange(math.prod(shape)).reshape(shape)

@partial(self.pmap, axis_name='i')
def f(x):
@@ -602,7 +602,7 @@ class PythonPmapTest(jtu.JaxTestCase):
raise SkipTest("test requires at least four devices")
pmap_in_axis = 0
shape = (4, 4, 4)
- x = np.arange(np.prod(shape)).reshape(shape)
+ x = np.arange(math.prod(shape)).reshape(shape)

@partial(self.pmap, axis_name='i')
@partial(self.pmap, axis_name='j')
@@ -2400,7 +2400,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
verify_ref()

shape = (jax.device_count(),) * 5
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
self.assertAllClose(pmap(vmap(f, in_axes=vmap_axis), axis_name='i')(x),
reference(x, split_axis, concat_axis, vmap_axis))

@@ -2413,7 +2413,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
return lax.all_to_all(x, 'i', split_axis=split_axis, concat_axis=concat_axis)

shape = (jax.device_count(),) * 4
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
self.assertAllClose(pmap(f, axis_name='i')(x),
vmap(f, axis_name='i')(x))

@@ -2431,7 +2431,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
return lax.all_to_all(x, axes, split_axis=split_axis, concat_axis=concat_axis)

shape = (2, 2, 4, 4, 4)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
self.assertAllClose(pmap(pmap(f, axis_name='j'), axis_name='i')(x),
vmap(vmap(f, axis_name='j'), axis_name='i')(x))

@@ -2456,7 +2456,7 @@ class VmapPmapCollectivesTest(jtu.JaxTestCase):
if jax.device_count() < 4:
raise SkipTest("test requires at least four devices")
shape = (4, 4, 8)
- x = jnp.arange(np.prod(shape)).reshape(shape)
+ x = jnp.arange(math.prod(shape)).reshape(shape)
f = partial(prim, axis_name='i', tiled=tiled)
self.assertAllClose(vmap(f, axis_name='i')(x), pmap(f, axis_name='i')(x))

@@ -14,6 +14,7 @@


from functools import partial
+ import math
from unittest import SkipTest, skipIf
from typing import Any, Tuple, NamedTuple, Optional
import zlib
@@ -689,7 +690,7 @@ class LaxRandomTest(jtu.JaxTestCase):
for ndim in [1 if is_range else len(input_range_or_shape)]
for axis in range(-ndim, ndim or 1)
for ninputs in [input_range_or_shape if is_range else input_range_or_shape[axis]]
- if replace or np.prod(shape) <= ninputs
+ if replace or math.prod(shape) <= ninputs
],
dtype=jtu.dtypes.floating + jtu.dtypes.integer,
weighted=[True, False],
@@ -702,7 +703,7 @@ class LaxRandomTest(jtu.JaxTestCase):
key = self.seed_prng(0)
is_range = type(input_range_or_shape) is int
x = (input_range_or_shape if is_range else
- self.rng().permutation(np.arange(np.prod(
+ self.rng().permutation(np.arange(math.prod(
input_range_or_shape), dtype=dtype)).reshape(input_range_or_shape))
N = x if is_range else x.shape[axis]
if weighted:
@@ -718,7 +719,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertEqual(np_shape, sample.shape)
if not replace and shape:
def lsort(x):
- if not np.prod(x.shape): return x
+ if not math.prod(x.shape): return x
ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
return jnp.take(x, ind, axis)
self.assertArraysEqual(lsort(sample), lsort(np.unique(sample, axis=axis)))
@@ -741,7 +742,7 @@ class LaxRandomTest(jtu.JaxTestCase):
is_range = type(range_or_shape) is int
x = (range_or_shape if is_range else
self.rng().permutation(np.arange(
- np.prod(range_or_shape), dtype=dtype)).reshape(range_or_shape))
+ math.prod(range_or_shape), dtype=dtype)).reshape(range_or_shape))
shape = ((range_or_shape,) if is_range else range_or_shape)
x_ = np.copy(x)
rand = lambda key, x: random.permutation(key, x, axis, independent=independent)
@@ -750,7 +751,7 @@ class LaxRandomTest(jtu.JaxTestCase):
self.assertFalse(np.all(perm == x)) # seems unlikely!
arr = np.arange(x) if is_range else x
def lsort(x):
- if not np.prod(x.shape): return x
+ if not math.prod(x.shape): return x
ind = np.lexsort(np.swapaxes(x, axis, -1).reshape((-1, x.shape[axis])))
return jnp.take(x, ind, axis)
if not independent:
@@ -1602,7 +1603,7 @@ class KeyArrayTest(jtu.JaxTestCase):
def make_keys(self, *shape, seed=None):
if seed is None:
seed = 28
- seeds = seed + jnp.arange(np.prod(shape), dtype=jnp.uint32)
+ seeds = seed + jnp.arange(math.prod(shape), dtype=jnp.uint32)
make_key = partial(prng.seed_with_impl, prng.threefry_prng_impl)
return jnp.reshape(jax.vmap(make_key)(seeds), shape)

@@ -15,6 +15,7 @@
import contextlib
from functools import partial
import itertools
+ import math
import operator
import random
import unittest
@@ -174,7 +175,7 @@ all_dtypes = jtu.dtypes.integer + jtu.dtypes.floating + jtu.dtypes.complex
def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
def _rand_sparse(shape, dtype, nse=nse):
rand = rand_method(rng)
- size = np.prod(shape).astype(int)
+ size = math.prod(shape)
if 0 <= nse < 1:
nse = nse * size
nse = min(size, int(nse))
@@ -1554,7 +1555,7 @@ class BCOOTest(sptu.SparseTestCase):
self.assertEqual(mat.n_dense, out.n_dense)

# Unnecessary padding eliminated
- max_nse = np.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse])
+ max_nse = math.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse])
self.assertLessEqual(out.nse, max_nse)

@jtu.sample_product(
@@ -1589,7 +1590,7 @@ class BCOOTest(sptu.SparseTestCase):
self.assertEqual(mat.n_dense, out.n_dense)

# Unnecessary padding eliminated
- max_nse = np.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse])
+ max_nse = math.prod(out.shape[out.n_batch: out.n_batch + out.n_sparse])
self.assertLessEqual(out.nse, max_nse)

@jtu.sample_product(
@@ -1644,7 +1645,7 @@ class BCOOTest(sptu.SparseTestCase):
[dict(shape=shape, n_batch=layout.n_batch, n_dense=layout.n_dense, nse=nse)
for shape in [(5,), (5, 8), (8, 5), (3, 4, 5), (3, 4, 3, 2)]
for layout in iter_sparse_layouts(shape)
- for nse in [None, np.prod(shape) - 1]
+ for nse in [None, math.prod(shape) - 1]
],
dtype=jtu.dtypes.floating + jtu.dtypes.complex,
remove_zeros=[True, False],
@@ -2106,7 +2107,7 @@ class BCOOTest(sptu.SparseTestCase):
M = rng(shape, dtype)

def make_bcoo(M):
- return sparse_bcoo._bcoo_fromdense(M, nse=np.prod(M.shape[:-1], dtype=int), n_dense=1)
+ return sparse_bcoo._bcoo_fromdense(M, nse=math.prod(M.shape[:-1]), n_dense=1)

todense = partial(sparse_bcoo._bcoo_todense, spinfo=sparse_util.SparseInfo(shape))

@@ -2586,7 +2587,7 @@ class SparseObjectTest(sptu.SparseTestCase):

self.assertIsInstance(M, Obj)
self.assertEqual(M.shape, shape)
- self.assertEqual(M.size, np.prod(shape))
+ self.assertEqual(M.size, math.prod(shape))
self.assertEqual(M.ndim, len(shape))
self.assertEqual(M.dtype, dtype)
self.assertEqual(M.nse, (M.todense() != 0).sum())
@@ -2757,8 +2758,8 @@ class SparseRandomTest(sptu.SparseTestCase):
batch_shape, sparse_shape, dense_shape = split_list(shape, [n_batch, n_sparse])

approx_expected_num_nonzero = (
- np.ceil(0.2 * np.prod(sparse_shape))
- * np.prod(batch_shape) * np.prod(dense_shape))
+ np.ceil(0.2 * math.prod(sparse_shape))
+ * math.prod(batch_shape) * math.prod(dense_shape))
num_nonzero = (mat_dense != 0).sum()
self.assertAlmostEqual(int(num_nonzero), approx_expected_num_nonzero, delta=2)

@@ -13,6 +13,7 @@
# limitations under the License.

from functools import partial
+ import math
import operator

from absl.testing import absltest
@@ -35,7 +36,7 @@ config.parse_flags_with_absl()
def rand_sparse(rng, nse=0.5, post=lambda x: x, rand_method=jtu.rand_default):
def _rand_sparse(shape, dtype, nse=nse):
rand = rand_method(rng)
- size = np.prod(shape).astype(int)
+ size = math.prod(shape)
if 0 <= nse < 1:
nse = nse * size
nse = min(size, int(nse))
@@ -278,9 +278,9 @@ class XMapTest(XMapTestCase):
out_axes=({0: 'a', 1: 'b'}, ['c', ...]),
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
ashape = (16, 8, 5)
- a = jnp.arange(np.prod(ashape)).reshape(ashape)
+ a = jnp.arange(math.prod(ashape)).reshape(ashape)
bshape = (2, 7)
- b = jnp.arange(np.prod(bshape)).reshape(bshape)
+ b = jnp.arange(math.prod(bshape)).reshape(bshape)
c, d = fm(a, b)
self.assertAllClose(c, a * 2)
self.assertAllClose(d, b * 4)
@@ -292,9 +292,9 @@ class XMapTest(XMapTestCase):
out_axes=(['b', ...], {0: 'c'}),
axis_resources={'a': 'x', 'b': 'y', 'c': 'x'})
ashape = (16, 8, 5)
- a = jnp.arange(np.prod(ashape)).reshape(ashape)
+ a = jnp.arange(math.prod(ashape)).reshape(ashape)
bshape = (2, 7)
- b = jnp.arange(np.prod(bshape)).reshape(bshape)
+ b = jnp.arange(math.prod(bshape)).reshape(bshape)
c, d = fm(a, b)
self.assertAllClose(c, (a * 2).sum(0))
self.assertAllClose(d, b * 4)
@@ -330,7 +330,7 @@ class XMapTest(XMapTestCase):
fm = xmap(f, in_axes=['a', ...], out_axes=({}, {1: 'a'}),
axis_resources={'a': ('x', 'y')})
vshape = (4, 5)
- v = jnp.arange(np.prod(vshape)).reshape(vshape)
+ v = jnp.arange(math.prod(vshape)).reshape(vshape)
ans, ans2 = fm(v)
self.assertAllClose(ans, (v * 2).sum(0))
self.assertAllClose(ans2, v.T * 4)
@@ -344,7 +344,7 @@ class XMapTest(XMapTestCase):
fyx = xmap(f, in_axes=['a', ...], out_axes={1: 'a'},
axis_resources={'a': ('y', 'x')})
vshape = (4, 5)
- v = jnp.arange(np.prod(vshape)).reshape(vshape)
+ v = jnp.arange(math.prod(vshape)).reshape(vshape)
zxy = fxy(v)
zxy_op_sharding = zxy.sharding._to_xla_op_sharding(zxy.ndim)
self.assertListEqual(zxy_op_sharding.tile_assignment_dimensions, [1, 4])
@@ -419,7 +419,7 @@ class XMapTest(XMapTestCase):
return h(y)

xshape = (4, 2, 5)
- x = jnp.arange(np.prod(xshape), dtype=float).reshape(xshape)
+ x = jnp.arange(math.prod(xshape), dtype=float).reshape(xshape)
y = f(x)
self.assertAllClose(
y, ((jnp.sin(x * 2) *
@@ -825,7 +825,7 @@ class XMapTestSPMD(SPMDTestMixin, XMapTest):
in_axes=[None, 'a', 'b', ...], out_axes=(['a', 'b', ...], {}),
axis_resources={'a': 'x', 'b': 'y'})
xshape = (8, 2, 4, 5)
- x = jnp.arange(np.prod(xshape), dtype=float).reshape(xshape)
+ x = jnp.arange(math.prod(xshape), dtype=float).reshape(xshape)
hlo = f.lower(x).compiler_ir(dialect="hlo").as_hlo_text()
match = re.search(r"sharding={devices=\[([0-9,]+)\][0-9,]+}", hlo)
self.assertIsNot(match, None)
@@ -1649,7 +1649,7 @@ class XMapErrorTest(jtu.JaxTestCase):
return x
return h(x)
xshape = (2, 5, 6)
- x = jnp.arange(np.prod(xshape)).reshape(xshape)
+ x = jnp.arange(math.prod(xshape)).reshape(xshape)
with self.assertRaisesRegex(RuntimeError,
"Changing the physical mesh is not allowed.*"):
f(x)