Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
Replace jax._src.util.prod with math.prod.

math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
This commit is contained in:
parent 4f48f94649
commit 8fb1fd318d
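Background on the replacement itself: math.prod multiplies the entries of an iterable with a start value of 1, so it is a drop-in replacement for the small prod helper that this change deletes from jax._src.util (see the util.py hunk below). A minimal sketch of the equivalence, using only the standard library; the shape value here is illustrative and not taken from the diff:

import math

shape = (4, 3, 2)              # hypothetical array shape
assert math.prod(shape) == 24  # element count, as the old helper computed it
assert math.prod(()) == 1      # empty product is 1, matching the removed helper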
@@ -18,13 +18,15 @@ python3 pmap_benchmark.py
To make it run faster, set env var TARGET_TOTAL_SECS to a low number (e.g. 2).
"""
import math
from absl import app
import jax
from jax import numpy as jnp
from jax import pmap
from jax.config import config
from jax._src.util import prod
from benchmarks import benchmark

@@ -118,7 +120,7 @@ def sharded_device_array_indexing_benchmark():
nshards = min(8, jax.local_device_count())
shape = (nshards, 8, 8)
def benchmark_fn():
arr = pmap(lambda x: x)(jnp.arange(prod(shape)).reshape(shape))
arr = pmap(lambda x: x)(jnp.arange(math.prod(shape)).reshape(shape))
indices = indices_fn()
for idx in indices:
arr[idx]
@@ -26,6 +26,7 @@ import collections
import functools
from functools import partial
import inspect
import math
import typing
from typing import (Any, Callable, Generator, Hashable, Iterable, List, Literal,
NamedTuple, Optional, Sequence, Tuple, TypeVar, Union, overload)

@@ -65,7 +66,7 @@ from jax._src.lib import pmap_lib
from jax._src.sharding import PmapSharding
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix, _generate_key_paths
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
from jax._src.util import (unzip2, curry, safe_map, safe_zip, split_list,
wrap_name, cache, wraps, HashableFunction,
weakref_lru_cache)

@@ -1034,7 +1035,7 @@ def xla_computation(fun: Callable,
if axis_env is None:
return xla.AxisEnv(nreps, (), ())
else:
nreps = nreps * prod(size for name, size in axis_env)
nreps = nreps * math.prod(size for name, size in axis_env)
names, sizes = unzip2(axis_env)
return xla.AxisEnv(nreps, names, sizes)

@@ -3232,7 +3233,7 @@ class ShapeDtypeStruct:
self.sharding = sharding
self.named_shape = {} if named_shape is None else dict(named_shape)
size = property(lambda self: prod(self.shape))
size = property(lambda self: math.prod(self.shape))
ndim = property(lambda self: len(self.shape))
def __len__(self):
@@ -14,6 +14,7 @@
from __future__ import annotations
import math
import operator as op
import numpy as np
import functools

@@ -26,7 +27,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src.config import config
from jax._src.util import prod, safe_zip, use_cpp_class, use_cpp_method
from jax._src.util import safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import api

@@ -208,7 +209,7 @@ class ArrayImpl(basearray.Array):
@property
def size(self):
return prod(self.shape)
return math.prod(self.shape)
@property
def sharding(self):

@@ -552,12 +553,13 @@ def make_array_from_callback(
Example:
>>> import math
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np
...
>>> input_shape = (8, 8)
>>> global_input_data = np.arange(prod(input_shape)).reshape(input_shape)
>>> global_input_data = np.arange(math.prod(input_shape)).reshape(input_shape)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> inp_sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
...

@@ -601,6 +603,7 @@ def make_array_from_single_device_arrays(
Example:
>>> import math
>>> from jax.sharding import Mesh
>>> from jax.sharding import PartitionSpec as P
>>> import numpy as np

@@ -608,7 +611,7 @@ def make_array_from_single_device_arrays(
>>> shape = (8, 8)
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> sharding = jax.sharding.NamedSharding(global_mesh, P('x', 'y'))
>>> inp_data = np.arange(prod(shape)).reshape(shape)
>>> inp_data = np.arange(math.prod(shape)).reshape(shape)
...
>>> arrays = [
... jax.device_put(inp_data[index], d)
@@ -22,6 +22,7 @@ from functools import partial, partialmethod, total_ordering
import gc
import inspect
import itertools as it
import math
import operator
from operator import attrgetter
import threading

@@ -44,7 +45,7 @@ from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
tuple_delete, as_hashable_function,
HashableFunction, HashableWrapper, weakref_lru_cache)
import jax._src.pretty_printer as pp

@@ -1481,7 +1482,7 @@ class ShapedArray(UnshapedArray):
return ShapedArray(shape, dtype, weak_type, named_shape)
ndim = property(lambda self: len(self.shape))
size = property(lambda self: prod(self.shape))
size = property(lambda self: math.prod(self.shape))
broadcast: ClassVar[Optional[aval_method]] = None
transpose: ClassVar[Optional[aval_method]] = None

@@ -1628,7 +1629,7 @@ class DShapedArray(UnshapedArray):
self.weak_type = weak_type
ndim = property(lambda self: len(self.shape))
size = property(lambda self: prod(self.shape))
size = property(lambda self: math.prod(self.shape))
def str_short(self, short_dtypes=False) -> str:
del short_dtypes # ignored
@@ -15,6 +15,7 @@
from collections import Counter
import dataclasses
import functools
import math
import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict, NamedTuple

@@ -26,7 +27,7 @@ from jax._src.lib import xla_client as xc
from jax._src.config import config
from jax._src.interpreters import pxla
from jax.interpreters import xla, mlir
from jax._src.util import prod, safe_zip
from jax._src.util import safe_zip
from jax._src.interpreters.pxla import PartitionSpec
Shape = Tuple[int, ...]

@@ -87,7 +88,7 @@ def _get_shard_indices_replica_ids_uncached(
out[device] = (index, replica_id)
shard_shape = get_shard_shape(global_shape, global_mesh, mesh_axes)
expected_unique_shards = prod(
expected_unique_shards = math.prod(
[g // s for g, s in safe_zip(global_shape, shard_shape) if g != 0 or s != 0])
if expected_unique_shards != unique_shards:
raise RuntimeError(

@@ -104,7 +105,7 @@ def get_shard_shape(global_shape, global_mesh, mesh_axes) -> Shape:
if not mesh_axis:
chunk_size.append(size)
elif isinstance(mesh_axis, tuple):
m = prod([global_mesh.shape[ma] for ma in mesh_axis])
m = math.prod([global_mesh.shape[ma] for ma in mesh_axis])
chunk_size.append(size // m)
else:
chunk_size.append(size // global_mesh.shape[mesh_axis])

@@ -323,7 +324,7 @@ class GlobalDeviceArray:
@property
def size(self):
return prod(self.shape)
return math.prod(self.shape)
@property
def mesh(self):

@@ -455,7 +456,7 @@ class GlobalDeviceArray:
>>> global_input_shape = (8, 8)
>>> mesh_axes = P('x', 'y')
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def cb(index):
... return global_input_data[index]

@@ -505,7 +506,7 @@ class GlobalDeviceArray:
>>> global_input_shape = (8, 2)
>>> mesh_axes = P('x')
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def batched_cb(indices):
... assert len(indices) == len(global_mesh.local_devices)

@@ -555,7 +556,7 @@ class GlobalDeviceArray:
>>> global_input_shape = (8, 2)
>>> mesh_axes = P(('x', 'y'))
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
>>> global_input_data = np.arange(math.prod(global_input_shape)).reshape(global_input_shape)
...
>>> def cb(cb_inp):
... dbs = []
@@ -37,6 +37,7 @@ import dataclasses
from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
import operator as op
import sys
import threading

@@ -78,7 +79,7 @@ from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log,
unzip2, HashableFunction)

@@ -284,7 +285,8 @@ def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
if not has_unstacked:
op_sharding_proto = sharding_spec_sharding_proto(self)
return _op_sharding_to_numpy_indices(
op_sharding_proto, shape, prod(self.mesh_shape)).reshape(self.mesh_shape)
op_sharding_proto, shape, math.prod(self.mesh_shape)
).reshape(self.mesh_shape)
axis_indices: List[Sequence[Index]] = []
shard_indices_shape = []

@@ -312,7 +314,7 @@ def sharding_spec_indices(self, shape: Tuple[int, ...]) -> np.ndarray:
# with each dimension having size equal to the number of shards across the corresponding
# logical array dimension, and each element containing the multi-dimensional index that
# is used to extract the corresponding shard of the logical array.
shard_indices = np.empty([prod(shard_indices_shape)], dtype=np.object_)
shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
for i, idxs in enumerate(it.product(*axis_indices)):
shard_indices[i] = idxs
shard_indices = shard_indices.reshape(shard_indices_shape)

@@ -766,7 +768,7 @@ class _ShardedDeviceArray(_SDA_BASE_CLASS): # type: ignore
@property
def size(self):
return prod(self.aval.shape)
return math.prod(self.aval.shape)
@property
def ndim(self):

@@ -2252,7 +2254,7 @@ core.axis_substitution_rules[xla_pmap_p] = _pmap_axis_subst
def _unravel_index_hlo(axis_env):
div = mlir.ir_constant(
np.array(axis_env.nreps // util.prod(axis_env.sizes), np.uint32))
np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32))
mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
return hlo.RemOp(
hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result

@@ -4035,7 +4037,7 @@ class DynamicAxisEnv(list):
@property
def nreps(self):
return prod(frame.hard_size for frame in self)
return math.prod(frame.hard_size for frame in self)
class _ThreadLocalState(threading.local):
def __init__(self):
@@ -19,6 +19,7 @@ import dataclasses
import functools
from functools import partial
import itertools as it
import math
import operator
import re
from typing import (Any, Callable, Dict, NamedTuple, Optional, Protocol,

@@ -37,7 +38,7 @@ from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.interpreters import ad
from jax._src.util import (prod, safe_zip, safe_map, partition_list)
from jax._src.util import (safe_zip, safe_map, partition_list)
from jax._src.typing import Shape

@@ -311,7 +312,7 @@ def axis_groups(axis_env: AxisEnv, name) -> Tuple[Tuple[int, ...]]:
if not isinstance(name, (list, tuple)):
name = (name,)
mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes))
trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes))
assert not ragged
mesh_spec = axis_env.sizes + (trailing_size,)
return _axis_groups(mesh_spec, mesh_axes)

@@ -326,10 +327,10 @@ def _axis_groups(mesh_spec, mesh_axes):
Returns:
A tuple of replica groups (i.e. tuples containing replica ids).
"""
iota = np.arange(prod(mesh_spec)).reshape(mesh_spec)
iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
groups = np.reshape(
np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
(prod(np.take(mesh_spec, mesh_axes)), -1))
(math.prod(np.take(mesh_spec, mesh_axes)), -1))
return tuple(unsafe_map(tuple, groups.T))
@@ -14,6 +14,7 @@
from functools import partial
import math
from typing import Union, Sequence
import numpy as np

@@ -30,7 +31,6 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
from jax._src.util import prod
__all__ = [
"fft",

@@ -150,7 +150,7 @@ def _irfft_transpose(t, fft_lengths):
full(2.0, shape=(n - 2 + is_odd,)),
full(1.0, shape=(1 - is_odd,))],
dimension=0)
scale = 1 / prod(fft_lengths)
scale = 1 / math.prod(fft_lengths)
out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x
assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
# Use JAX's convention for complex gradients
@@ -17,6 +17,7 @@ import enum
import functools
from functools import partial
import itertools
import math
import operator
from typing import (Any, Callable, Optional, Sequence, Tuple, List,
TypeVar, Union, cast as type_cast, overload)

@@ -70,7 +71,7 @@ from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import PmapSharding
from jax._src.typing import Array, ArrayLike, DTypeLike, Shape
from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis,
from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis,
split_list)
xb = xla_bridge

@@ -1422,7 +1423,7 @@ def collapse(operand: Array, start_dimension: int,
collapsed (raveled) into a single dimension.
"""
lo, hi = start_dimension, stop_dimension
size = prod(operand.shape[lo:hi])
size = math.prod(operand.shape[lo:hi])
new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:]
return reshape(operand, new_shape)
@@ -15,6 +15,7 @@
import inspect
import functools
from functools import partial
import math
from typing import cast, Any, Callable, List, Literal, Optional, Tuple, TypeVar, Union, overload
import warnings

@@ -50,7 +51,6 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike
from jax._src.util import prod
xops = xla_client.ops

@@ -1313,7 +1313,7 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (geqrf); b/261671778")
a_aval, taus_aval = ctx.avals_out
*batch_dims, m, n = a_aval.shape
batch = prod(batch_dims)
batch = math.prod(batch_dims)
if batch == 0 or m == 0 or n == 0:
return mlir.full_like_aval(ctx, 0, a_aval), mlir.full_like_aval(ctx, 0, taus_aval)
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional, Sequence, Tuple, Union, cast as type_cast
import jax
from jax._src.numpy import lax_numpy as jnp
from jax._src.util import prod
from jax._src.lax import lax
from jax._src.lax import convolution

@@ -92,7 +92,7 @@ def conv_general_dilated_patches(
lhs_spec, rhs_spec, out_spec = dimension_numbers
spatial_size = prod(filter_shape)
spatial_size = math.prod(filter_shape)
n_channels = lhs_array.shape[lhs_spec[1]]
# Move separate `lhs` spatial locations into separate `rhs` channels.
@@ -17,6 +17,7 @@ Parallelization primitives.
from functools import partial
import itertools
import math
import string
from typing import Sequence, Union
import warnings

@@ -40,7 +41,7 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
from jax._src.util import (
unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis)
unzip2, canonicalize_axis, safe_map, safe_zip, moveaxis)
unsafe_map, map = map, safe_map # type: ignore

@@ -828,7 +829,7 @@ def psum_bind(*args, axes, axis_index_groups):
assert not pos_axes
size = len(axis_index_groups[0])
else:
size = prod([core.axis_frame(name).size for name in named_axes]) # type: ignore
size = math.prod([core.axis_frame(name).size for name in named_axes]) # type: ignore
return tuple(lax._const(x, size) * pos_reduce(x) for x in args)
return core.AxisPrimitive.bind(
psum_p, *args, axes=axes, axis_index_groups=axis_index_groups)

@@ -1563,9 +1564,12 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
'`axis_index` translation rule does not support multiple axis names.')
axis_name, = axis_name
axis_pos = list(axis_env.names).index(axis_name)
nreplicas = axis_env.nreps // prod(axis_env.sizes)
div = mlir.ir_constant(np.array(nreplicas * prod(axis_env.sizes[axis_pos+1:]),
dtype=np.uint32))
nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
div = mlir.ir_constant(
np.array(
nreplicas * math.prod(axis_env.sizes[axis_pos + 1 :]), dtype=np.uint32
)
)
mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(axis_context,
@@ -17,7 +17,7 @@ Common neural network layer initializers, consistent with definitions
used in Keras and Sonnet.
"""
import math
from typing import Any, Literal, Protocol, Sequence, Tuple, Union
import numpy as np

@@ -27,7 +27,6 @@ from jax import lax
from jax import random
from jax._src import core
from jax._src import dtypes
from jax._src.util import prod
KeyArray = random.KeyArray
Array = Any

@@ -549,7 +548,7 @@ def orthogonal(scale: RealNumeric = 1.0,
dtype = dtypes.canonicalize_dtype(dtype)
if len(shape) < 2:
raise ValueError("orthogonal initializer requires at least a 2D shape")
n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
n_rows, n_cols = math.prod(shape) // shape[column_axis], shape[column_axis]
matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
A = random.normal(key, matrix_shape, dtype)
Q, R = jnp.linalg.qr(A)
@@ -27,6 +27,7 @@ rules for the underlying :code:`lax` primitives.
import builtins
import collections
from functools import partial
import math
import operator
import types
from typing import (

@@ -82,7 +83,7 @@ from jax._src.numpy.util import ( # noqa: F401
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DimSize, DType, DTypeLike, Shape
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip,
from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl

@@ -534,7 +535,7 @@ def histogramdd(sample: ArrayLike, bins: Union[ArrayLike, List[ArrayLike]] = 10,
dedges = [diff(bin_edges) for bin_edges in bin_edges_by_dim]
xy = ravel_multi_index(tuple(bin_idx_by_dim), nbins, mode='clip')
hist = bincount(xy, weights, length=_prod(nbins))
hist = bincount(xy, weights, length=math.prod(nbins))
hist = reshape(hist, nbins)
core = D*(slice(1, -1),)
hist = hist[core]

@@ -914,7 +915,7 @@ def resize(a: ArrayLike, new_shape: Shape) -> Array:
arr = ravel(a)
new_size = _prod(new_shape)
new_size = math.prod(new_shape)
if arr.size == 0 or new_size == 0:
return zeros_like(arr, shape=new_shape)
@@ -14,6 +14,7 @@
import builtins
from functools import partial
import math
import operator
from typing import overload, Any, Callable, Literal, Optional, Sequence, Tuple, Union
import warnings

@@ -31,7 +32,7 @@ from jax._src.numpy.util import (
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis, prod as _prod)
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
_all = builtins.all

@@ -371,7 +372,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = Non
if axis is None:
weights_sum = lax.full((), core.dimension_as_value(a.size), dtype=avg.dtype)
elif isinstance(axis, tuple):
weights_sum = lax.full_like(avg, _prod(core.dimension_as_value(a.shape[d]) for d in axis))
weights_sum = lax.full_like(avg, math.prod(core.dimension_as_value(a.shape[d]) for d in axis))
else:
weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index]
else:
@@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
import math
import operator
from textwrap import dedent as _dedent
from typing import Optional, Tuple, Union

@@ -31,7 +32,6 @@ from jax._src.numpy.lax_numpy import (
sort, where, zeros)
from jax._src.numpy.util import _check_arraylike, _wraps
from jax._src.typing import Array, ArrayLike
from jax._src.util import prod as _prod
_lax_const = lax_internal._const

@@ -230,11 +230,11 @@ def _unique_sorted_mask(ar: Array, axis: int) -> Tuple[Array, Array, Array]:
# is fixed to match numpy.
aux = where(isnan(aux), _lax_const(aux, np.nan), aux)
size, *out_shape = aux.shape
if _prod(out_shape) == 0:
if math.prod(out_shape) == 0:
size = 1
perm = zeros(1, dtype=int)
else:
perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
perm = lexsort(aux.reshape(size, math.prod(out_shape)).T[::-1])
aux = aux[perm]
if aux.size:
if dtypes.issubdtype(aux.dtype, np.inexact):
@@ -15,6 +15,7 @@
import abc
from functools import partial, reduce
import math
import operator as op
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence

@@ -44,7 +45,7 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
from jax._src.sharding import (
NamedSharding, PmapSharding, GSPMDSharding)
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
from jax._src.util import canonicalize_axis, safe_map, safe_zip
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

@@ -1141,7 +1142,7 @@ def threefry_random_bits(key: jax.Array, bit_width, shape):
return _threefry_random_bits_original(key, bit_width, shape)
def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape):
if all(core.is_constant_dim(d) for d in shape) and prod(shape) > 2 ** 64:
if all(core.is_constant_dim(d) for d in shape) and math.prod(shape) > 2 ** 64:
raise NotImplementedError('random bits array of size exceeding 2 ** 64')
k1, k2 = key

@@ -1160,7 +1161,7 @@ def _threefry_random_bits_partitionable(key: jax.Array, bit_width, shape):
@partial(jit, static_argnums=(1, 2), inline=True)
def _threefry_random_bits_original(key: jax.Array, bit_width, shape):
size = prod(shape)
size = math.prod(shape)
# Compute ceil(bit_width * size / 32) in a way that is friendly to shape
# polymorphism
max_count, r = divmod(bit_width * size, 32)
@@ -14,8 +14,9 @@
from functools import partial
from typing import Optional, Sequence, Union
import math
from operator import index
from typing import Optional, Sequence, Union
import warnings
import numpy as np

@@ -40,7 +41,7 @@ from jax._src.numpy.lax_numpy import (
_arraylike, _check_arraylike, _convert_and_clip_integer,
_promote_dtypes_inexact)
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import prod, canonicalize_axis
from jax._src.util import canonicalize_axis
RealArray = ArrayLike

@@ -507,7 +508,7 @@ def choice(key: KeyArray,
else:
axis = canonicalize_axis(axis, arr.ndim)
n_inputs = arr.shape[axis]
n_draws = prod(shape)
n_draws = math.prod(shape)
if n_draws == 0:
return jnp.zeros(shape, dtype=arr.dtype)
if n_inputs <= 0:

@@ -1025,7 +1026,7 @@ def _gamma_grad(sample, a, *, log_space):
def _gamma_impl(key, a, *, log_space, use_vmap=False):
# split key to match the shape of a
a_shape = jnp.shape(a)
split_count = prod(a_shape[key.ndim:])
split_count = math.prod(a_shape[key.ndim:])
keys = key.flatten()
keys = vmap(_split, in_axes=(0, None))(keys, split_count)
keys = keys.flatten()
@@ -14,6 +14,7 @@
from collections import namedtuple
from functools import partial
import math
from typing import Optional, Tuple
import jax

@@ -24,7 +25,7 @@ from jax._src.api import vmap
from jax._src.numpy.lax_numpy import _check_arraylike
from jax._src.numpy.util import _wraps
from jax._src.typing import ArrayLike, Array
from jax._src.util import canonicalize_axis, prod
from jax._src.util import canonicalize_axis
import scipy

@@ -80,7 +81,7 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k
axis = canonicalize_axis(axis, x.ndim)
x = jnp.moveaxis(x, axis, 0)
x = x.reshape(x.shape[0], prod(x.shape[1:]))
x = x.reshape(x.shape[0], math.prod(x.shape[1:]))
vals, counts = vmap(_mode_helper, in_axes=1)(x)
return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape))
@@ -14,6 +14,7 @@
"""Module for state types."""
from __future__ import annotations
import math
from typing import Any, Generic, List, Sequence, Set, Tuple, TypeVar, Union
from jax._src import core

@@ -21,7 +22,7 @@ from jax._src import effects
from jax._src import pretty_printer as pp
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.util import safe_map, safe_zip, prod
from jax._src.util import safe_map, safe_zip
xc = xla_client
xb = xla_bridge

@@ -89,7 +90,7 @@ class AbstractRef(core.AbstractValue, Generic[Aval]):
return AbstractRef(self.inner_aval.join(other.inner_aval))
ndim = property(lambda self: len(self.shape))
size = property(lambda self: prod(self.shape))
size = property(lambda self: math.prod(self.shape))
@property
def shape(self):
@@ -17,6 +17,7 @@ import inspect
import io
import functools
from functools import partial
import math
import re
import os
import tempfile

@@ -46,7 +47,7 @@ from jax._src.config import (flags, bool_env, config,
raise_persistent_cache_errors,
persistent_cache_min_compile_time_secs)
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
from jax._src.util import prod, unzip2
from jax._src.util import unzip2
from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)

@@ -668,7 +669,7 @@ def rand_int(rng, low=0, high=None):
def rand_unique_int(rng, high=None):
def fn(shape, dtype):
return rng.choice(np.arange(high or prod(shape), dtype=dtype),
return rng.choice(np.arange(high or math.prod(shape), dtype=dtype),
size=shape, replace=False)
return fn

@@ -755,7 +756,7 @@ def sample_product_testcases(*args, **kw):
"""Non-decorator form of sample_product."""
args = [list(arg) for arg in args]
kw = [(k, list(v)) for k, v in kw.items()]
n = prod(len(a) for a in args) * prod(len(v) for _, v in kw)
n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
testcases = []
for i in _choice(n, min(n, FLAGS.jax_num_generated_cases)):
testcase = {}

@@ -1054,7 +1055,7 @@ def with_mesh(named_shape: MeshSpec) -> Generator[None, None, None]:
"""Test utility for setting up meshes given mesh data from `schedules`."""
# This is similar to the `with_mesh` function above, but isn't a decorator.
axis_names, shape = unzip2(named_shape)
size = prod(shape)
size = math.prod(shape)
local_devices = list(api.local_devices())
if len(local_devices) < size:
raise unittest.SkipTest(f"Test requires {size} local devices")

@@ -1094,7 +1095,7 @@ def restore_spmd_manual_lowering_flag():
config.update('experimental_xmap_spmd_lowering_manual', old_spmd_manual_lowering_flag)
def create_global_mesh(mesh_shape, axis_names):
size = prod(mesh_shape)
size = math.prod(mesh_shape)
if len(api.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} global devices.")
devices = sorted(api.devices(), key=lambda d: d.id)
@@ -269,12 +269,6 @@ def weakref_lru_cache(call: Callable, maxsize=2048):
"""
return xc.weakref_lru_cache(config._trace_context, call, maxsize)
def prod(xs):
out = 1
for x in xs:
out *= x
return out
class Unhashable:
__slots__ = ["val"]
@@ -83,6 +83,7 @@ TODO:
- Support RNNs other than LSTM.
"""
from functools import partial
import math
from typing import Any, Dict, List, Tuple
import jax

@@ -92,7 +93,6 @@ from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.custom_derivatives import custom_vjp
from jax._src.typing import Array, Shape
from jax._src.util import prod
import jax.numpy as jnp
try:
from jax._src.lib import gpu_rnn

@@ -161,7 +161,7 @@ def get_num_params_in_lstm(input_size: int, hidden_size: int, num_layers: int,
"""Get param count in LSTM."""
layer_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers,
bidirectional)
param_count = sum([prod(shape) for shape in layer_shapes])
param_count = sum([math.prod(shape) for shape in layer_shapes])
return param_count

@@ -201,7 +201,7 @@ def unpack_lstm_weights(
for w_kind in [W_ih, W_hh]:
shape = flat_shapes[flat_shapes_offset]
flat_shapes_offset += 1
num_elems = prod(shape)
num_elems = math.prod(shape)
w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape)
w_offsets += num_elems

@@ -211,7 +211,7 @@ def unpack_lstm_weights(
for w_kind in [b_ih, b_hh]:
shape = flat_shapes[flat_shapes_offset]
flat_shapes_offset += 1
num_elems = prod(shape)
num_elems = math.prod(shape)
w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape)
w_offsets += num_elems
return W_ih, W_hh, b_ih, b_hh
@@ -17,6 +17,7 @@ import enum
from functools import partial, lru_cache
import inspect
import itertools as it
import math
import operator as op
from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence,
Set, Tuple, TypeVar, Union, Protocol)

@@ -37,7 +38,7 @@ from jax._src import util
from jax._src.core import Tracer
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, fft, linalg)
from jax._src.util import (prod, HashableFunction, unzip2, as_hashable_function,
from jax._src.util import (HashableFunction, unzip2, as_hashable_function,
memoize, partition_list, merge_lists)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax.interpreters import batching

@@ -150,7 +151,7 @@ def _check_specs_vs_args(
msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
raise ValueError(msg)
in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat))
fail = [a if any(a.shape[d] % prod(mesh.shape[n] for n in ns)
fail = [a if any(a.shape[d] % math.prod(mesh.shape[n] for n in ns)
for d, ns in names.items()) else no_fail
for a, names in zip(in_avals, in_names_flat)]
if any(f is not no_fail for f in fail):

@@ -201,10 +202,10 @@ def _spec_divisibility_error(
f"parameter '{list(ba.arguments.keys())[arg_key.key]}',")
names = _canonicalize_spec(spec)
for d, ns in names.items():
if aval.shape[d] % prod(mesh.shape[n] for n in ns):
if aval.shape[d] % math.prod(mesh.shape[n] for n in ns):
axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
total = 'total ' if len(ns) > 1 else ''
sz = prod(mesh.shape[n] for n in ns)
sz = math.prod(mesh.shape[n] for n in ns)
msgs.append(
f"args{fail_key.pprint()} of shape {aval.str_short()}{extra} "
f"corresponds to in_specs{spec_key.pprint()} of value {spec}, "

@@ -382,7 +383,7 @@ pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):
return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
return aval.update(tuple(sz // math.prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)))
else:
raise NotImplementedError # TODO(mattjj): add table with handlers

@@ -390,7 +391,7 @@ def _shard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
if isinstance(aval, core.ShapedArray):
return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
return aval.update(tuple(sz * math.prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)),
named_shape={k: v for k, v in aval.named_shape.items()
if k not in mesh.shape})

@@ -880,7 +881,7 @@ def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
check_rep):
mb_div = lambda x, y: x / y if y != 1 else x
out_cts = [ad.Zero(_shard_aval(mesh, ns, x.aval)) if type(x) is ad.Zero
else mb_div(x, prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
else mb_div(x, math.prod(map(mesh.shape.get, _unmentioned(mesh, ns))))
for ns, x in zip(out_names, out_cts)]
args = [x if type(x) is not ad.UndefinedPrimal else
ad.UndefinedPrimal(_shard_aval(mesh, ns, x.aval))
@@ -14,11 +14,11 @@
"""Base JAX Sparse object."""
import abc
import math
from typing import Sequence, Tuple
import jax
from jax._src import core
from jax._src import util
from jax._src.typing import Array

@@ -37,7 +37,7 @@ class JAXSparse(abc.ABC):
@property
def size(self) -> int:
return util.prod(self.shape)
return math.prod(self.shape)
@property
def ndim(self) -> int:
@@ -171,7 +171,6 @@ from jax._src.lax.lax import (
population_count_p as population_count_p,
pow as pow,
pow_p as pow_p,
prod as prod,
random_gamma_grad as random_gamma_grad,
random_gamma_grad_p as random_gamma_grad_p,
real as real,

@@ -369,3 +368,5 @@ from jax.lax import linalg as linalg
from jax._src.pjit import with_sharding_constraint
from jax._src.dispatch import device_put_p
from math import prod # TODO(phawkins): remove this accidental export
@@ -13,9 +13,8 @@
# limitations under the License.
import functools
from functools import partial
import operator
import math
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo

@@ -60,8 +59,6 @@ def _real_type(dtype):
"""Returns the real equivalent of 'dtype'."""
return np.finfo(dtype).dtype
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a):
"""LU decomposition."""

@@ -71,7 +68,7 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a):
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
batch = math.prod(batch_dims)
if batch > 1 and m == n and m // batch <= 128:
lwork, opaque = gpu_blas.build_getrf_batched_descriptor(

@@ -118,7 +115,7 @@ def _geqrf_hlo(platform, gpu_solver, dtype, a):
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
batch = math.prod(batch_dims)
lwork, opaque = gpu_solver.build_geqrf_descriptor(
np.dtype(dtype), batch, m, n)

@@ -156,7 +153,7 @@ def _geqrf_batched_hlo(platform, gpu_blas, dtype, a):
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
batch = math.prod(batch_dims)
lwork, opaque = gpu_blas.build_geqrf_batched_descriptor(
np.dtype(dtype), batch, m, n)

@@ -220,7 +217,7 @@ def _orgqr_hlo(platform, gpu_solver, dtype, a, tau):
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
batch = math.prod(batch_dims)
tau_dims = ir.RankedTensorType(tau.type).shape
assert tau_dims[:-1] == dims[:-2]

@@ -266,7 +263,7 @@ def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
assert m == n
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
batch = math.prod(batch_dims)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
if have_jacobi_solver and n <= 32:

@@ -317,7 +314,7 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = _prod(batch_dims)
b = math.prod(batch_dims)
if ir.ComplexType.isinstance(a_type.element_type):
singular_vals_type = ir.ComplexType(a_type.element_type).element_type
else:
@@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial
import math
from absl.testing import absltest

@@ -21,7 +22,6 @@ import numpy as np
import jax
from jax import lax
from jax._src import test_util as jtu
from jax._src.util import prod
from jax.config import config

@@ -104,7 +104,7 @@ class AnnTest(jtu.JaxTestCase):
is_max_k=[True, False],
)
def test_autodiff(self, shape, dtype, k, is_max_k):
vals = np.arange(prod(shape), dtype=dtype)
vals = np.arange(math.prod(shape), dtype=dtype)
vals = self.rng().permutation(vals).reshape(shape)
if is_max_k:
fn = lambda vs: lax.approx_max_k(vs, k=k)[0]
@@ -14,6 +14,7 @@
"""Tests for GlobalDeviceArray."""
import contextlib
import math
import os
import unittest
from absl.testing import absltest

@@ -27,7 +28,7 @@ from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.util import prod, safe_zip
from jax._src.util import safe_zip
from jax.interpreters import pxla
from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import (

@@ -72,7 +73,7 @@ def tearDownModule():
def create_array(shape, sharding, global_data=None):
if global_data is None:
global_data = np.arange(prod(shape)).reshape(shape)
global_data = np.arange(math.prod(shape)).reshape(shape)
return array.make_array_from_callback(
shape, sharding, lambda idx: global_data[idx]), global_data

@@ -310,7 +311,7 @@ class JaxArrayTest(jtu.JaxTestCase):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
s = sharding.NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
di_map = s.devices_indices_map(shape)
bufs = [jax.device_put(inp_data[di_map[d]], d)
for d in jax.local_devices()]

@@ -333,7 +334,7 @@ class JaxArrayTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
# sharding device ids = {0, 1}
s = sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {2, 3}
bufs = [jax.device_put(inp_data, d) for d in jax.devices()[2:4]]
with self.assertRaisesRegex(

@@ -349,7 +350,7 @@ class JaxArrayTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
# Sharding device ids = {0, 1}
s = sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {0, 0}
bufs = [jax.device_put(inp_data, jax.devices()[0]) for _ in range(2)]
with self.assertRaisesRegex(

@@ -366,7 +367,7 @@ class JaxArrayTest(jtu.JaxTestCase):
mesh = jax.sharding.Mesh(np.array([jax.devices()[1], jax.devices()[2]]), ('x'))
# sharding device ids = {1, 2}
s = sharding.NamedSharding(mesh, P('x'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
# _arrays device ids = {0, 1}
bufs = [jax.device_put(inp_data, d) for d in jax.devices()[:2]]
with self.assertRaisesRegex(

@@ -389,7 +390,7 @@ class JaxArrayTest(jtu.JaxTestCase):
shape = (8, 4)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
mps = sharding.NamedSharding(mesh, pspec)
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
str_expected_shard_shape = str(expected_shard_shape).replace(
r"(", r"\(").replace(r")", r"\)")

@@ -403,7 +404,7 @@ class JaxArrayTest(jtu.JaxTestCase):
shape = (8, 2)
mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
s = sharding.NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape), dtype=np.int32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
indices = s.devices_indices_map(shape)
bufs = [jax.device_put(inp_data[indices[d]], d) for d in mesh.local_devices]
with self.assertRaisesRegex(

@@ -725,7 +726,7 @@ class ShardingTest(jtu.JaxTestCase):
self.skipTest('Test needs >= 2 devices.')
shape = (2, 2)
num_elements = prod(shape)
num_elements = math.prod(shape)
inp_data = np.arange(num_elements).reshape(shape)
out = jax.pmap(lambda x: x)(inp_data)
self.assertIsInstance(out.sharding, sharding.PmapSharding)

@@ -954,7 +955,7 @@ class RngShardingTest(jtu.JaxTestCase):
mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
s = sharding.NamedSharding(mesh, pspec)
n = prod(global_shape)
n = math.prod(global_shape)
global_x = jnp.arange(n).astype('uint32').reshape(global_shape)
x = array.make_array_from_callback(global_x.shape, s, lambda i: global_x[i])
@@ -14,6 +14,7 @@
from functools import partial
import hashlib
import math
import os
import random
import sys

@@ -29,7 +30,6 @@ from jax.experimental.maps import xmap
from jax.experimental.pjit import pjit
import jax
from jax import jit, lax, pmap
from jax._src.util import prod
import jax._src.test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_client

@@ -284,11 +284,11 @@ class CompilationCacheTest(jtu.JaxTestCase):
return x + y
shape = (8, 8)
x = np.arange(prod(shape), dtype=np.int64).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.int64).reshape(shape)
f(x, x + 1)
files_in_directory = len(os.listdir(tmpdir))
self.assertEqual(files_in_directory, 1)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f(x, x + 1)
files_in_directory = len(os.listdir(tmpdir))
self.assertEqual(files_in_directory, 2)
@@ -13,15 +13,17 @@
# limitations under the License.
"""Tests for GlobalDeviceArray."""
from absl.testing import absltest
from absl.testing import parameterized
import math
import unittest
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax._src import core
from jax._src import test_util as jtu
from jax._src.util import prod, safe_zip
from jax._src.util import safe_zip
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh

@@ -34,7 +36,7 @@ config.parse_flags_with_absl()
def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
if global_data is None:
global_data = np.arange(prod(global_shape)).reshape(global_shape)
global_data = np.arange(math.prod(global_shape)).reshape(global_shape)
return GlobalDeviceArray.from_callback(
global_shape, global_mesh, mesh_axes, lambda idx: global_data[idx]), global_data

@@ -75,7 +77,7 @@ class GDATest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]

@@ -128,7 +130,7 @@ class GDATest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
global_input_shape = (8, 4, 2)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]

@@ -163,7 +165,7 @@ class GDATest(jtu.JaxTestCase):
expected_replica_ids):
global_mesh = jtu.create_global_mesh((8,), ('x'))
global_input_shape = (16,)
global_input_data = np.arange(prod(global_input_shape)).reshape(-1)
global_input_data = np.arange(math.prod(global_input_shape)).reshape(-1)
def cb(index):
return global_input_data[index]

@@ -214,7 +216,7 @@ class GDATest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]

@@ -240,7 +242,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(('x', 'y'))
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(indices):
self.assertEqual(len(indices), len(global_mesh.local_devices))

@@ -260,7 +262,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x')
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
def cb(cb_inp):
self.assertLen(cb_inp, 4)

@@ -288,7 +290,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(('x', 'y'))
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
gda = GlobalDeviceArray.from_callback(

@@ -307,7 +309,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(None,)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
input_gda = GlobalDeviceArray.from_callback(

@@ -340,7 +342,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
indices = get_shard_indices(global_input_shape, global_mesh, mesh_axes)
dbs = [

@@ -358,7 +360,7 @@ class GDATest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(('x', 'y'))
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
@@ -16,6 +16,7 @@
import collections
from functools import partial
import itertools
import math
from unittest import SkipTest
from absl.testing import absltest

@@ -28,7 +29,6 @@ from jax import dtypes
from jax import lax
from jax._src import test_util as jtu
from jax.test_util import check_grads
from jax._src.util import prod
from jax.config import config
config.parse_flags_with_absl()

@@ -832,7 +832,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = np.arange(prod(shape), dtype=key_dtype)
flat_keys = np.arange(math.prod(shape), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values

@@ -847,7 +847,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
k=[1, 3],
)
def testTopKGrad(self, shape, dtype, k):
flat_values = np.arange(prod(shape), dtype=dtype)
flat_values = np.arange(math.prod(shape), dtype=dtype)
values = self.rng().permutation(flat_values).reshape(shape)
fun = lambda vs: lax.top_k(vs, k=k)[0]
check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
@@ -19,6 +19,7 @@ from functools import partial
import inspect
import io
import itertools
import math
from typing import cast, Iterator, Optional, List, Tuple
import unittest
from unittest import SkipTest

@@ -46,7 +47,7 @@ from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.util import prod, safe_zip
from jax._src.util import safe_zip
from jax._src import array
from jax.config import config

@@ -1358,7 +1359,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if shape in scalar_shapes or len(shape) == 0:
cond_shape = (0,)
elif axis is None:
cond_shape = (prod(shape),)
cond_shape = (math.prod(shape),)
else:
cond_shape = (shape[axis],)

@@ -1394,7 +1395,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if shape in scalar_shapes or len(shape) == 0:
cond_shape = (0,)
elif axis is None:
cond_shape = (prod(shape),)
cond_shape = (math.prod(shape),)
else:
cond_shape = (shape[axis],)

@@ -1488,7 +1489,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
[dict(shape=shape, axis=axis, idx=idx)
for shape in nonempty_nonscalar_array_shapes
for axis in [None] + list(range(-len(shape), len(shape)))
for idx in (range(-prod(shape), prod(shape))
for idx in (range(-math.prod(shape), math.prod(shape))
if axis is None else
range(-shape[axis], shape[axis]))],
dtype=all_dtypes,

@@ -3372,7 +3373,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
idx_shape=all_shapes,
)
def testUnravelIndex(self, shape, idx_shape, dtype):
size = prod(shape)
size = math.prod(shape)
rng = jtu.rand_int(self.rng(), low=-((2 * size) // 3), high=(2 * size) // 3)
def np_fun(index, shape):
@ -15,6 +15,7 @@ from __future__ import annotations

from functools import partial
import itertools
import math
import operator
import types
import unittest
@ -44,7 +45,6 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src import lax_reference
from jax._src.util import prod
from jax._src.lax import lax as lax_internal
from jax._src.internal_test_util import lax_test_util

@ -1998,7 +1998,7 @@ class LaxTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = np.arange(prod(shape), dtype=key_dtype)
flat_keys = np.arange(math.prod(shape), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
@ -2035,7 +2035,7 @@ class LaxTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = np.arange(prod(shape), dtype=key_dtype)
flat_keys = np.arange(math.prod(shape), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
@ -2051,7 +2051,7 @@ class LaxTest(jtu.JaxTestCase):
)
def testTopK(self, shape, dtype, k):
def args_maker():
flat_values = np.arange(prod(shape), dtype=dtype)
flat_values = np.arange(math.prod(shape), dtype=dtype)
values = self.rng().permutation(flat_values).reshape(shape)
return [values]
def reference_top_k(x):

@ -15,6 +15,7 @@

from functools import partial
import itertools
import math
from typing import Optional, cast
import unittest

@ -32,7 +33,7 @@ from jax._src import test_util as jtu
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.util import prod, safe_map, safe_zip
from jax._src.util import safe_map, safe_zip

from jax.config import config
config.parse_flags_with_absl()
@ -667,7 +668,7 @@ class LaxVmapTest(jtu.JaxTestCase):
# Note also that we chose 3 * 5 * 3 * 5 such that it fits in the range of
# values a bfloat16 can represent exactly to avoid ties.
def testTopK(self, shape, dtype, k, bdims):
rng = jtu.rand_int(self.rng(), high=prod(shape))
rng = jtu.rand_int(self.rng(), high=math.prod(shape))
# _CheckBatching doesn't work with tuple outputs, so test outputs separately.
op1 = lambda x: lax.top_k(x, k=k)[0]
self._CheckBatching(op1, 5, bdims, (shape,), (dtype,), rng)

@ -16,6 +16,7 @@ import os
import re
from functools import partial, lru_cache
import logging
import math
import threading
import unittest
from collections import OrderedDict, namedtuple
@ -51,7 +52,7 @@ from jax._src.interpreters import pxla
from jax.interpreters import mlir
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.util import prod, curry, unzip2, safe_zip
from jax._src.util import curry, unzip2, safe_zip

from jax.config import config
config.parse_flags_with_absl()
@ -85,7 +86,7 @@ def create_gda(global_shape, global_mesh, mesh_axes, global_data=None,
dtype=np.float32):
if global_data is None:
global_data = np.arange(
prod(global_shape), dtype=dtype).reshape(global_shape)
math.prod(global_shape), dtype=dtype).reshape(global_shape)

if isinstance(mesh_axes, Sharding):
mesh_axes = mesh_axes.spec
@ -98,7 +99,7 @@ def create_array(global_shape, global_mesh, mesh_axes, global_data=None,
dtype=np.float32):
if global_data is None:
global_data = np.arange(
prod(global_shape), dtype=dtype).reshape(global_shape)
math.prod(global_shape), dtype=dtype).reshape(global_shape)

if isinstance(mesh_axes, Sharding):
sharding = mesh_axes
@ -144,7 +145,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x

shape = (2, 2)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
actual = f(x)
expected = x
self.assertAllClose(actual, expected, check_dtypes=False)
@ -164,7 +165,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x + y

shape = (8, 8)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
actual = f(x, x + 1)
expected = x + (x + 1)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -182,7 +183,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x + y

shape = (8, 8)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
if config.jax_array:
out = jax.jit(f)(x, x + 1)
self.assertArraysEqual(out, x + x + 1)
@ -205,7 +206,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return jnp.pad(out, [[0, 1]])

shape = (4,)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
actual = f(x, x + 1)
expected = x + (x + 1)
self.assertAllClose(actual[:3], expected[:3], check_dtypes=False)
@ -222,7 +223,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x + y

shape = (8, 8)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
with jtu.create_global_mesh((2,), ('x')) as mesh:
actual = f(x, x + 1)
expected = x + (x + 1)
@ -281,7 +282,7 @@ class PJitTest(jtu.BufferDonationTestCase):
def testMeshDecorator(self):
x = jnp.arange(8)
mesh_shape = (2, 2)
size = prod(mesh_shape)
size = math.prod(mesh_shape)
if len(jax.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} global devices.")
mesh_devices = np.array(jax.devices()[:size]).reshape(mesh_shape)
@ -347,7 +348,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return y * 2

shape = (8, 8)
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -374,7 +375,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return y * 2

shape = (8, 8)
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -402,7 +403,7 @@ class PJitTest(jtu.BufferDonationTestCase):
y = with_sharding_constraint(y, ops)
return y * 2

x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
expected = (x + 1) * 2
actual = f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -426,7 +427,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x

shape = (8, 8)
v = np.arange(prod(shape)).reshape(shape)
v = np.arange(math.prod(shape)).reshape(shape)
x = [{"a": v, "b": v * 2}, v * 3]
actual = f(x)

@ -458,7 +459,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x

shape = (8, 8)
v = np.arange(prod(shape)).reshape(shape)
v = np.arange(math.prod(shape)).reshape(shape)
x = [{"a": v, "b": v * 2}, v * 3]
actual = f(x)

@ -487,7 +488,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x

shape = (2, 8, 8)
v = np.arange(prod(shape)).reshape(shape)
v = np.arange(math.prod(shape)).reshape(shape)
x = [{'a': v, 'b': v * 2}, v * 3]
actual = f(x)

@ -513,7 +514,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x

shape = (2, 8, 8)
v = np.arange(prod(shape)).reshape(shape)
v = np.arange(math.prod(shape)).reshape(shape)
x = [{'a': v, 'b': v * 2}, v * 3]

mlir_str = str(f.lower(x).compiler_ir())
@ -1084,7 +1085,7 @@ class PJitTest(jtu.BufferDonationTestCase):
input_shape = (8, 4)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
seeds = jnp.arange(
prod(input_shape), dtype=np.uint32).reshape(input_shape)
math.prod(input_shape), dtype=np.uint32).reshape(input_shape)

with mesh:
def make_keys(seeds):
@ -1179,7 +1180,7 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

@ -1214,7 +1215,7 @@ class GDAPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

@ -1288,7 +1289,7 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

@ -1321,7 +1322,7 @@ class GDAPjitTest(jtu.JaxTestCase):
@jtu.with_mesh([('x', 4), ('y', 2)])
def test_pjit_gda_non_gda_inputs(self):
input_shape = (8, 2)
input_data = np.arange(prod(input_shape)).reshape(input_shape)
input_data = np.arange(math.prod(input_shape)).reshape(input_shape)

with parallel_functions_output_gda(True):
@partial(pjit,
@ -1353,7 +1354,8 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32
).reshape(global_input_shape)
def cb(index):
return global_input_data[index]

@ -1374,7 +1376,8 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x')
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32
).reshape(global_input_shape)
def cb(index):
return global_input_data[index]

@ -1400,7 +1403,7 @@ class GDAPjitTest(jtu.JaxTestCase):
input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(input_shape), dtype=np.float32).reshape(input_shape)
math.prod(input_shape), dtype=np.float32).reshape(input_shape)
def cb(index):
return input_data[index]

@ -1439,7 +1442,8 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(None)
global_input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32
).reshape(global_input_shape)

def cb(index):
return global_input_data[index]
@ -1500,7 +1504,7 @@ class GDAPjitTest(jtu.JaxTestCase):
global_input_shape = (8, 2)
mesh_axes = P(None)
global_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)

with parallel_functions_output_gda(True):
f = pjit(lambda x: x, in_shardings=mesh_axes, out_shardings=mesh_axes)
@ -1584,7 +1588,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
raise unittest.SkipTest('GDA and Array cannot be together.')
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

with parallel_functions_output_gda(True):
with global_mesh:
@ -1610,7 +1614,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
raise unittest.SkipTest('AutoSharding is not supported on stream_executor yet.')
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

with jax_array(True):
with global_mesh:
@ -1634,7 +1638,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (4, 2)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

with ctx(True):
with global_mesh:
@ -1662,7 +1666,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (4, 2)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

with global_mesh:
f = pjit(lambda x, y, z: (x, y, z), in_shardings=AUTO, out_shardings=AUTO)
@ -1685,7 +1689,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
global_input_shape = (8, 4)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

in_resource = pspec

@ -1718,7 +1722,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names)
global_input_shape = (8, 4)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

in_resource = NamedSharding(global_mesh, pspec)

@ -1748,7 +1752,7 @@ class AutoShardingPjitTest(jtu.JaxTestCase):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(
prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)

with jax_array(True):
with global_mesh:
@ -1835,7 +1839,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
input_shape = (8, 2)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_data = np.arange(
prod(input_shape), dtype=np.float32).reshape(input_shape)
math.prod(input_shape), dtype=np.float32).reshape(input_shape)
with jax_array(True):
with global_mesh:
f = pjit(lambda x: x,
@ -1854,7 +1858,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
input_shape = (8, 2)
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
input_data = np.arange(
prod(input_shape), dtype=np.float32).reshape(input_shape)
math.prod(input_shape), dtype=np.float32).reshape(input_shape)
with jax_array(True):
with global_mesh:
f = pjit(
@ -2026,7 +2030,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_globally_sharded_key_array_result_8x4_single_device(self):
input_shape = (8, 4)
seeds = jnp.arange(
prod(input_shape), dtype=np.uint32).reshape(input_shape)
math.prod(input_shape), dtype=np.uint32).reshape(input_shape)

@pjit
def make_keys(seeds):
@ -2080,7 +2084,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
m2 = jtu.create_global_mesh((2, 2), ('x', 'y'))
spec = P('x', 'y')

a1 = jnp.arange(prod(input_shape)).reshape(input_shape)
a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape)

with jax_array(True):
with m1:
@ -2096,7 +2100,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
m2 = jtu.create_global_mesh((2, 2), ('x', 'y'))
spec = P('x', 'y')

a1 = jnp.arange(prod(input_shape)).reshape(input_shape)
a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape)

with jax_array(True):
with m1:
@ -2250,7 +2254,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
shape = (8, 2)
mesh = jax.sharding.Mesh(mesh_devices, ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape), dtype=np.float32).reshape(shape)
inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

# Explicitly put on the ordering of devices which does not match the mesh
# ordering to make sure we reorder them in the constructor and the output
@ -2271,7 +2275,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax_array(True)
def test_not_xlacompatible_sharding_error(self):
shape = (8, 2)
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
ts = TempSharding(jax.devices())
arr = array.make_array_from_callback(
shape, ts, lambda idx: inp_data[idx])
@ -2333,7 +2337,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax_array(True)
def test_pjit_uncommitted_array_and_committed_array(self):
shape = (8, 2)
uarr = jnp.arange(prod(shape), dtype=np.float32).reshape(shape)
uarr = jnp.arange(math.prod(shape), dtype=np.float32).reshape(shape)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
carr, inp_data = create_array(shape, mesh, P('x', 'y'))
with mesh:
@ -2357,7 +2361,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_pjit_uncommitted_array_multi_devices(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp = np.arange(prod(shape), dtype=np.int32).reshape(shape)
inp = np.arange(math.prod(shape), dtype=np.int32).reshape(shape)
arr = array.ArrayImpl(
core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)),
[jax.device_put(inp, d) for d in mesh.devices.flat], committed=False)
@ -2602,7 +2606,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
shape = (8, 2)
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
arr = jax.device_put(inp_data, s)
out = pjit(lambda x: x)(arr)
self.assertArraysEqual(out, inp_data)
@ -2640,7 +2644,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_multi_device_pjit_mul(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
arr1 = jax.device_put(inp_data, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(inp_data, NamedSharding(mesh, P(None, 'y')))

@ -2655,7 +2659,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_single_device_pjit_cpp_dispatch(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((1,), ('x',))
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)

f = pjit(lambda x: x @ x.T, in_shardings=None, out_shardings=None)
with jtu.count_pjit_cache_miss() as count:
@ -3379,7 +3383,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
def test_with_sharding_constraint_spmd_axis_name(self):
mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl'))
shape = (8, 4, 2, 2)
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)

def f(inp):
return with_sharding_constraint(inp, P('data', None, None))
@ -3723,7 +3727,7 @@ class PJitErrorTest(jtu.JaxTestCase):
def test_pjit_with_deleted_input_at_first_call(self, committed):
shape = (8,)
mesh = jtu.create_global_mesh((1,), ('x',))
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
if committed:
s = NamedSharding(mesh, P('x',))
x = jax.device_put(inp_data, s)
@ -3742,7 +3746,7 @@ class PJitErrorTest(jtu.JaxTestCase):
def test_pjit_with_deleted_input_at_subsequent_call(self, committed):
shape = (8,)
mesh = jtu.create_global_mesh((1,), ('x',))
inp_data = np.arange(prod(shape)).reshape(shape)
inp_data = np.arange(math.prod(shape)).reshape(shape)
if committed:
s = NamedSharding(mesh, P('x',))
x = jax.device_put(inp_data, s)

@ -17,6 +17,7 @@ from concurrent.futures import ThreadPoolExecutor
from functools import partial
import itertools as it
import gc
import math
import os
from random import shuffle
from typing import Optional, cast
@ -44,7 +45,7 @@ from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
from jax._src import config as jax_config
from jax._src import device_array
from jax._src import xla_bridge
from jax._src.util import prod, safe_map, safe_zip
from jax._src.util import safe_map, safe_zip
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src import array
@ -119,7 +120,7 @@ def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,
aval = ShapedArray(input_shape, dtype)

if input_data is None:
input_data = np.arange(prod(input_shape)).reshape(input_shape)
input_data = np.arange(math.prod(input_shape)).reshape(input_shape)

sharding_spec = pxla._create_pmap_sharding_spec(aval, in_axes, sharded_dim_size)

@ -167,9 +168,9 @@ class PythonPmapTest(jtu.JaxTestCase):
msg = "device mesh shape {} not compatible with device count {}"
raise SkipTest(msg.format(device_mesh_shape, device_count)) from err
else:
if device_count % prod(device_mesh_shape):
if device_count % math.prod(device_mesh_shape):
msg = "device mesh size {} does not divide available device count {}"
raise SkipTest(msg.format(prod(device_mesh_shape), device_count))
raise SkipTest(msg.format(math.prod(device_mesh_shape), device_count))
else:
return device_mesh_shape

@ -177,7 +178,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)

ans = f(x)
@ -193,7 +194,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompile(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = f(x)
lowered = f.lower(x)
compiled = lowered.compile()
@ -210,7 +211,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompileInTreeMismatch(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f_exe = f.lower(x).compile()
self.assertRaisesRegex(
TypeError, "function compiled for .*, called with .*",
@ -235,7 +236,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompileArgTypeMismatch(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=int).reshape(shape)
x = np.arange(math.prod(shape), dtype=int).reshape(shape)
x_f32 = x.astype(jnp.float32)
x_i32 = x.astype(jnp.int32)
f_exe = f.lower(x_f32).compile()
@ -250,7 +251,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompileMultiArg(self):
f = self.pmap(lambda x, y: x - lax.pmean(y, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = y = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = f(x, y)
f_exe = f.lower(x, y).compile()
ans = f_exe(x, y)
@ -267,7 +268,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerAsText(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x)
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
@ -277,7 +278,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompilerIR(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
@ -289,14 +290,14 @@ class PythonPmapTest(jtu.JaxTestCase):
# TODO(frostig): remove (deprecated)
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
self.assertIsNotNone(f.compiler_ir())

def testLowerCompileAsText(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
self.assertIsInstance(f.as_text(), (str, type(None)))

@ -304,7 +305,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCostAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x)
f.cost_analysis() # doesn't raise

@ -312,7 +313,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompileCostAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
f.cost_analysis() # doesn't raise

@ -320,21 +321,21 @@ class PythonPmapTest(jtu.JaxTestCase):
def testLowerCompileMemoryAnalysis(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
f.memory_analysis() # doesn't raise

def testLowerCompileExecutable(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
f = f.lower(x).compile()
self.assertIsNotNone(f.runtime_executable())

def testLowerShapedArray(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
x_shape = core.ShapedArray(x.shape, x.dtype)
self.assertAllClose(f.lower(x_shape).compile()(x), f(x))

@ -342,7 +343,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.broadcast_to(np.mean(x, 0), x.shape)

ans = f(x)
@ -352,7 +353,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x] * jax.device_count())
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -361,7 +362,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: lax.all_gather(x, 'i'), axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
x = (x % 2).astype(np.bool_)
expected = np.array([x] * jax.device_count())
ans = f(x)
@ -371,7 +372,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: lax.all_gather(x, 'i', axis=-1), axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x.T] * jax.device_count())
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -381,7 +382,7 @@ class PythonPmapTest(jtu.JaxTestCase):

device_count = jax.device_count()
shape = (device_count, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x] * device_count).reshape(device_count, -1)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -392,7 +393,7 @@ class PythonPmapTest(jtu.JaxTestCase):

device_count = jax.device_count()
shape = (device_count, 4, 3)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.array([x.transpose(1, 0, 2).reshape(4, -1)] * device_count)
ans = f(x)
self.assertAllClose(ans, expected, check_dtypes=False)
@ -406,7 +407,7 @@ class PythonPmapTest(jtu.JaxTestCase):

device_count = jax.device_count()
shape = (4, device_count, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
self.assertAllClose(vmap(f)(x), jnp.stack([f(xs) for xs in x], axis=0))

def testReduceScatter(self):
@ -414,7 +415,7 @@ class PythonPmapTest(jtu.JaxTestCase):

device_count = jax.device_count()
shape = (device_count, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.sum(x, axis=0)
ans = f(x)
for i, actual in enumerate(ans):
@ -425,7 +426,7 @@ class PythonPmapTest(jtu.JaxTestCase):

device_count = jax.device_count()
shape = (device_count, 4 * device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.sum(x, axis=0)
ans = f(x)
scatter_len = len(expected) // device_count
@ -444,7 +445,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, axis_name='i')

shape = (replicas, 4 * replicas)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
ans = f(x)

group_1_result = np.sum(x[0::2,:], axis=0)
@ -504,7 +505,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i')

shape = (jax.device_count(), 4 * 2)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape).view(np.complex64)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape).view(np.complex64)
expected = x - np.sum(x, 0)

ans = f(x)
@ -566,7 +567,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return np.repeat(np.sum(x, axis, keepdims=True), x.shape[axis], axis)

shape = (jax.device_count(), 1, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

ans = f(x)
expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
@ -591,7 +592,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(self.pmap(f, 'i'), 'j')

shape = mesh_shape + (4,)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

ans = f(x)
expected = x
@ -605,7 +606,7 @@ class PythonPmapTest(jtu.JaxTestCase):
mesh_shape = (jax.device_count(),)
shape = mesh_shape + (4,)
x = np.array(3., dtype=np.float32)
y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

f_expected = np.broadcast_to(x, mesh_shape)
f_ans = f(x, y)
@ -657,7 +658,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, axis_name='j', in_axes=(None, 0))

x = 3.
y = np.arange(prod(mesh_shape), dtype=np.float32).reshape(mesh_shape)
y = np.arange(math.prod(mesh_shape), dtype=np.float32).reshape(mesh_shape)
expected = np.broadcast_to(x - np.sum(y, 1, keepdims=True), mesh_shape)

ans = f(x, y)
@ -673,7 +674,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return jvp(jnp.ones_like(x))

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = np.cos(x)

ans = splitjvp(x)
@ -687,7 +688,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return jnp.sin(x)

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
expected = grad(lambda x: jnp.sum(f(x)))(x)
@ -699,7 +700,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return lax.psum(x, axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)

def testGradOfJvp(self):
@ -714,7 +715,7 @@ class PythonPmapTest(jtu.JaxTestCase):
fun = lambda x: jnp.sum(jvp(jnp.sin, (x,), (jnp.ones_like(x),))[1])

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

ans = grad(lambda x: jnp.sum(splitjvp(x)))(x)
expected = grad(fun)(x)
@ -730,7 +731,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return tot * jnp.ones_like(x) # broadcast to map like pjit does

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
y = 4 + x
ans = grad(lambda x, y: jnp.sum(g(x, y)))(x, y)
expected = grad(lambda x, y: jnp.sum(g(x, y)))(x, y)
@ -764,7 +765,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return grad(lambda w: jnp.sum(g(w)))(x)

shape = mesh_shape + (4,)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

ans = grad(lambda x: jnp.sum(test_fun(x)))(x)
expected = grad(lambda x: jnp.sum(baseline_fun(x)))(x)
@ -775,7 +776,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

# test that we can pass in and out ShardedDeviceArrays
y = f(x)
@ -840,7 +841,7 @@ class PythonPmapTest(jtu.JaxTestCase):
if jax.device_count() < max(in_shape[:1] + out_shape[:1]):
raise SkipTest("not enough devices")

x = np.arange(prod(in_shape)).reshape(in_shape)
x = np.arange(math.prod(in_shape)).reshape(in_shape)
sharded_x = self.pmap(lambda x: x)(x)
self.assertAllClose(sharded_x.reshape(out_shape), x.reshape(out_shape),
check_dtypes=False)
@ -858,7 +859,7 @@ class PythonPmapTest(jtu.JaxTestCase):
shape = (num_pairs, 2, 4)
else:
shape = (device_count, 1, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

ans = f(x)
expected = sum_and_broadcast(sum_and_broadcast(x, 0), 1)
@ -874,7 +875,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, 'i')

shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected_psum = 2. * replicas // 2
expected = x - expected_psum

@ -891,7 +892,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, 'i')

shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
def sum_helper(a):
return np.broadcast_to(a.sum(0, keepdims=True),
(len(a), x.shape[1]))
@ -913,7 +914,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, 'i')

shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
def sum_helper(a):
return np.broadcast_to(a.sum(0, keepdims=True),
(replicas // 2, x.shape[1]))
@ -938,7 +939,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, 'i')

shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

ans = f(x)

@ -963,7 +964,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, 'i')

shape = (replicas, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

ans = f(x)

@ -999,7 +1000,7 @@ class PythonPmapTest(jtu.JaxTestCase):
axis_index_groups=axis_index_groups)

shape = (len(devices), 2 if axis_index_groups else jax.device_count())
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
jtu.check_grads(f, (x,), 2, ["fwd", "rev"], 1e-2, 1e-2, eps=1.)

def testNestedPmapReplicaGroups(self):
@ -1014,7 +1015,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f3 = self.pmap(self.pmap(f, 'j'), 'i')

shape = (2, replicas // 2, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
def sum_helper_f1(a):
return np.broadcast_to(a.sum(1, keepdims=True),
(shape[0], shape[1] // 2, shape[2]))
@ -1030,7 +1031,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected)

shape = (replicas // 2, 2, 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
def sum_helper_f3(a):
return np.broadcast_to(a.sum(0, keepdims=True),
(shape[0] // 2, shape[1], shape[2]))
@ -1197,7 +1198,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.pmax(x, 'i'), axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.max(x, 0)

ans = f(x)
@ -1207,7 +1208,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(lambda x: x - lax.pmin(x, 'i'), axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.min(x, 0)

ans = f(x)
@ -1291,7 +1292,7 @@ class PythonPmapTest(jtu.JaxTestCase):

f = self.pmap(self.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2, 3)
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
@ -1330,7 +1331,7 @@ class PythonPmapTest(jtu.JaxTestCase):
shuffle(devices)
f = self.pmap(self.pmap(lambda x: 3), devices=devices)
shape = (2, len(devices) // 2, 3)
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
ans = f(x)
# self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants
@ -1352,7 +1353,7 @@ class PythonPmapTest(jtu.JaxTestCase):
raise SkipTest("error test doesn't apply with disable_jit")
f = self.pmap(self.pmap(lambda x: 3))
shape = (2, jax.device_count() // 2 + 1, 3)
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
self.assertRaisesRegex(
ValueError,
(r"compiling computation that requires \d+ logical devices, "
@ -1363,7 +1364,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# if jax.device_count() > 1:
# f = pmap(pmap(lambda x: 3), devices=jax.devices()[:-1])
# shape = (2, jax.device_count() // 2, 3)
# x = jnp.arange(prod(shape)).reshape(shape)
# x = jnp.arange(math.prod(shape)).reshape(shape)
# self.assertRaisesRegex(
# ValueError,
# (r"compiling computation that requires \d+ replicas, "
@ -1392,7 +1393,7 @@ class PythonPmapTest(jtu.JaxTestCase):
return g(x)

shape = (device_count, 1, 4)
x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
a, b, c = f(x)

self.assertEqual(a.shape, shape[:-1])
@ -1525,7 +1526,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testPswapaxes(self):
device_count = jax.device_count()
shape = (device_count, 3, device_count, 5)
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)

ans = self.pmap(lambda x: lax.pswapaxes(x, 'i', 1), axis_name='i')(x)
expected = np.swapaxes(x, 0, 2)
@ -1535,7 +1536,7 @@ class PythonPmapTest(jtu.JaxTestCase):
def testGradOfPswapaxes(self):
device_count = jax.device_count()
shape = (device_count, 1, device_count)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
w = np.arange(device_count, dtype=np.float32)

@partial(self.pmap, axis_name='i')
@ -1561,7 +1562,7 @@ class PythonPmapTest(jtu.JaxTestCase):
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
shape = (device_count, device_count // 2)
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)

axis_index_groups = np.arange(device_count, dtype=np.int32)
axis_index_groups = axis_index_groups.reshape((device_count // 2, 2)).T
@ -1582,7 +1583,7 @@ class PythonPmapTest(jtu.JaxTestCase):
if device_count % 2 != 0:
raise SkipTest('test requires an even number of devices')
shape = (device_count, device_count // 2, 1)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
w = np.arange(device_count, dtype=np.float32)

axis_index_groups = np.arange(device_count, dtype=np.int32)
@ -1610,7 +1611,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = lambda x: x - lax.psum(x, 'i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)

ans = jit(self.pmap(f, 'i'))(x)
@ -1656,7 +1657,7 @@ class PythonPmapTest(jtu.JaxTestCase):
f = self.pmap(f, axis_name='i')

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

y = f(x)
self.assertIsInstance(y, jax.Array)
@ -2363,7 +2364,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i',
devices=jax.devices())
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
expected = x - np.sum(x, 0)
ans = f(x)
self.assertAllClose(ans, expected)
@ -2387,7 +2388,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def testNoDevicesError(self):
f = pmap(lambda x: x - lax.psum(x, 'i'), axis_name='i', devices=[])
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
with self.assertRaisesRegex(
ValueError, "'devices' argument to pmap must be non-empty, or None."):
f(x)
@ -2508,7 +2509,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
return jnp.sin(x)

shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)

ans = grad(lambda x: jnp.sum(jnp.sin(x)))(x)
expected = grad(lambda x: jnp.sum(f(x)))(x)
@ -2519,7 +2520,7 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def f(x, y):
return jnp.sin(x + y())
shape = (jax.device_count(), 4)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape)
y = lambda: 3.

ans = f(x, y)
@ -2531,9 +2532,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def f(x, y):
return jnp.sin(x + y)
xshape = (2, jax.device_count(), 4)
x = np.arange(prod(xshape)).reshape(xshape)
x = np.arange(math.prod(xshape)).reshape(xshape)
yshape = (2, 4, jax.device_count())
y = np.arange(prod(yshape)).reshape(yshape)
y = np.arange(math.prod(yshape)).reshape(yshape)

self.assertAllClose(f(x, y),
jnp.sin(x.transpose((1, 0, 2)) + y.transpose((2, 0, 1))))
@ -2544,11 +2545,11 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
fp = pmap(f, in_axes=(1, 2, None))
fv = vmap(f, in_axes=(1, 2, None))
xshape = (5, jax.device_count(), 7)
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
x = np.arange(math.prod(xshape), dtype=np.float32).reshape(xshape)
yshape = (5, 7, jax.device_count())
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
y = np.arange(math.prod(yshape), dtype=np.float32).reshape(yshape)
zshape = (5, 7)
z = np.arange(prod(zshape), dtype=np.float32).reshape(zshape)
z = np.arange(math.prod(zshape), dtype=np.float32).reshape(zshape)

dx, dy, dz = jax.grad(lambda args: fp(*args).sum())((x, y, z))
assert dx.shape == xshape
@ -2563,9 +2564,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
def f(x, y):
return jnp.sin(x + y), y * 2
xshape = (2, jax.device_count(), 4)
x = np.arange(prod(xshape)).reshape(xshape)
x = np.arange(math.prod(xshape)).reshape(xshape)
yshape = (2, 4)
y = np.arange(prod(yshape)).reshape(yshape)
y = np.arange(math.prod(yshape)).reshape(yshape)

self.assertAllClose(f(x, y),
(jnp.sin(x.transpose((1, 0, 2)) + y).transpose((1, 2, 0)), y * 2))
@ -2606,9 +2607,9 @@ class PmapWithDevicesTest(jtu.JaxTestCase):
return f

xshape = (5, 7)
x = np.arange(prod(xshape), dtype=np.float32).reshape(xshape)
x = np.arange(math.prod(xshape), dtype=np.float32).reshape(xshape)
yshape = (5, jax.device_count(), 7)
y = np.arange(prod(yshape), dtype=np.float32).reshape(yshape)
y = np.arange(math.prod(yshape), dtype=np.float32).reshape(yshape)
self.assertAllClose(jax.grad(mk_case(pmap))(x, y),
jax.grad(mk_case(vmap))(x, y))

@ -2624,7 +2625,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
if jax.device_count() < shape[0]:
raise SkipTest(f"requires {shape[0]} devices")

x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
sharded_x = pmap(lambda x: x)(x)

num_threads = 10
@ -2650,7 +2651,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
if jax.device_count() < shape[0]:
raise SkipTest(f"requires {shape[0]} devices")

x = jnp.arange(prod(shape)).reshape(shape)
x = jnp.arange(math.prod(shape)).reshape(shape)
sharded_x = pmap(lambda x: x)(x)
self.assertIsNone(sharded_x._npy_value)

@ -2941,7 +2942,7 @@ class ShardArgsTest(jtu.JaxTestCase):
nshards = len(indices)
if jax.device_count() < nshards:
raise SkipTest
x = np.arange(prod(shape)).reshape(shape)
x = np.arange(math.prod(shape)).reshape(shape)
arg = make_arg(x)
bufs = pxla.shard_args(jax.devices()[:nshards], [indices], [arg])
self.assertEqual(len(bufs), 1)

@ -14,6 +14,7 @@

from functools import partial
import math

import numpy as np

@ -24,7 +25,6 @@ from jax import grad
from jax._src import test_util as jtu
from jax import dtypes
from jax.scipy import ndimage as lsp_ndimage
from jax._src.util import prod

from jax.config import config
config.parse_flags_with_absl()
@ -80,7 +80,7 @@ class NdimageTest(jtu.JaxTestCase):
mode, cval, impl, round_, rng_factory):

def args_maker():
x = np.arange(prod(shape), dtype=dtype).reshape(shape)
x = np.arange(math.prod(shape), dtype=dtype).reshape(shape)
coords = [(size - 1) * rng(coords_shape, coords_dtype) for size in shape]
if round_:
coords = [c.round().astype(int) for c in coords]

@ -11,8 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import math
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from typing import (Any, Sequence, Set, Iterable, Iterator, NamedTuple,
|
||||
@ -31,7 +33,7 @@ from jax.sharding import PartitionSpec as P
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.util import safe_zip, safe_map, prod, partition_list, merge_lists
|
||||
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.experimental.shard_map import shard_map
|
||||
@ -552,13 +554,13 @@ def shmap_reference(
|
||||
|
||||
def make_indexer(mesh: Mesh, spec: P, x: Any
|
||||
) -> Callable[[Tuple[int, ...]], Tuple[slice, ...]]:
|
||||
block_shape = [d // prod(mesh.shape[ax] for ax in (elt or ()))
|
||||
block_shape = [d // math.prod(mesh.shape[ax] for ax in (elt or ()))
|
||||
for d, elt in zip(x.shape, spec)]
|
||||
def indexer(idx):
|
||||
starts = [0 if el is None else
|
||||
idx[list(mesh.shape).index(el)] if type(el) is not tuple else
|
||||
sum(idx[list(mesh.shape).index(el[i])]
|
||||
* prod(mesh.shape[e] for e in el[i+1:]) for i in range(len(el)))
|
||||
* math.prod(mesh.shape[e] for e in el[i+1:]) for i in range(len(el)))
|
||||
for el in spec]
|
||||
return tuple(slice(start * size, (start + 1) * size)
|
||||
for start, size in zip(starts, block_shape))
|
||||
@ -647,7 +649,7 @@ def make_in_spec(mesh: Mesh, in_type_base: ShapeDtypeDuck) -> Chooser:
|
||||
return new_type, partition_spec
|
||||
|
||||
def dilate(mesh: Mesh, spec: P, shape: ShapeDtypeDuck) -> ShapeDtypeDuck:
|
||||
new_shape = tuple(d * prod(mesh.shape[ax] for ax in (elt or ()))
|
||||
new_shape = tuple(d * math.prod(mesh.shape[ax] for ax in (elt or ()))
|
||||
for d, elt in zip(shape.shape, spec))
|
||||
return jax.ShapeDtypeStruct(new_shape, shape.dtype)
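
The make_indexer and dilate helpers in the hunk above hinge on the same arithmetic: a block dimension is the global dimension divided by the product of the sizes of the mesh axes it is sharded over, and dilation multiplies it back. A self-contained sketch of that arithmetic under assumed values (a hypothetical 'x'/'y' mesh of sizes 4 and 2, and a spec that shards only dim 0), using plain tuples and dicts rather than the real Mesh and PartitionSpec objects:

import math

mesh_shape = {'x': 4, 'y': 2}            # hypothetical mesh axis sizes
spec = (('x',), None)                    # dim 0 sharded over 'x', dim 1 replicated
global_shape = (8, 2)

# Block shape: divide each dim by the product of the mesh axes it is sharded over.
block_shape = [d // math.prod(mesh_shape[ax] for ax in (elt or ()))
               for d, elt in zip(global_shape, spec)]
assert block_shape == [2, 2]             # 8 // 4 and 2 // 1

# Dilation is the inverse: multiply each block dim back up to the global dim.
dilated = tuple(d * math.prod(mesh_shape[ax] for ax in (elt or ()))
                for d, elt in zip(block_shape, spec))
assert dilated == (8, 2)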
@ -14,6 +14,7 @@

import functools
import itertools as it
import math
import os
import re
from itertools import product, permutations

@ -45,7 +46,7 @@ from jax._src import config as jax_config
from jax._src.nn import initializers as nn_initializers
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.util import unzip2, prod, safe_zip
from jax._src.util import unzip2, safe_zip
from jax._src.lax import parallel as lax_parallel
from jax._src.lax.parallel import pgather
from jax.interpreters import batching, pxla

@ -80,7 +81,7 @@ def tearDownModule():
def create_array(global_shape, global_mesh, mesh_axes, global_data=None):
if global_data is None:
global_data = np.arange(
prod(global_shape), dtype=np.float32).reshape(global_shape)
math.prod(global_shape), dtype=np.float32).reshape(global_shape)

sharding = NamedSharding(global_mesh, mesh_axes)

@ -1118,7 +1119,7 @@ class XMapGDATest(XMapTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

@ -1146,7 +1147,7 @@ class XMapGDATest(XMapTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

@ -1185,7 +1186,7 @@ class XMapGDATest(XMapTestCase):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
global_input_shape = (8, 2)
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

@ -1224,7 +1225,7 @@ class XMapGDATest(XMapTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return input_data[index]

@ -1248,7 +1249,7 @@ class XMapGDATest(XMapTestCase):
global_input_shape = (8, 2)
mesh_axes = P('x', 'y')
input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
math.prod(global_input_shape)).reshape(global_input_shape)
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, lambda idx: input_data[idx])
with jax_config.parallel_functions_output_gda(True):
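
The GlobalDeviceArray callbacks in the hunks above reduce to plain NumPy indexing into a globally sized array built with math.prod. A pure-NumPy stand-in for that pattern, with a hypothetical index in place of the ones the runtime would supply:

import math
import numpy as np

global_input_shape = (8, 2)
input_data = np.arange(
    math.prod(global_input_shape)).reshape(global_input_shape)

def cb(index):
  # Return the shard of the global data selected by a device's index.
  return input_data[index]

shard = cb((slice(0, 2), slice(None)))   # hypothetical 2x2 block for one device
assert shard.shape == (2, 2)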