mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Cleanup: avoid jnp.prod & np.prod on array shapes (#4086)
This commit is contained in:
parent
decd760020
commit
29aa9bfc8f
@ -24,6 +24,7 @@ import jax
|
||||
from jax import numpy as jnp
|
||||
from jax import pmap
|
||||
from jax.config import config
|
||||
from jax.util import prod
|
||||
|
||||
from benchmarks import benchmark
|
||||
|
||||
@ -118,7 +119,7 @@ def sharded_device_array_indexing_benchmark():
|
||||
nshards = min(8, jax.local_device_count())
|
||||
shape = (nshards, 8, 8)
|
||||
def benchmark_fn():
|
||||
arr = pmap(lambda x: x)(jnp.arange(jnp.prod(shape)).reshape(shape))
|
||||
arr = pmap(lambda x: x)(jnp.arange(prod(shape)).reshape(shape))
|
||||
indices = indices_fn()
|
||||
for idx in indices:
|
||||
arr[idx]
|
||||
|
@ -1724,7 +1724,7 @@ class ShapeDtypeStruct(object):
|
||||
self.shape = shape
|
||||
self.dtype = np.dtype(dtype)
|
||||
|
||||
size = property(lambda self: np.prod(self.shape))
|
||||
size = property(lambda self: prod(self.shape))
|
||||
ndim = property(lambda self: len(self.shape))
|
||||
|
||||
def __len__(self):
|
||||
|
@ -26,6 +26,7 @@ import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import ops
|
||||
from jax import random
|
||||
from jax.util import prod
|
||||
|
||||
def zeros(key, shape, dtype=jnp.float32): return jnp.zeros(shape, dtype)
|
||||
def ones(key, shape, dtype=jnp.float32): return jnp.ones(shape, dtype)
|
||||
@ -41,7 +42,7 @@ def normal(stddev=1e-2, dtype=jnp.float32):
|
||||
return init
|
||||
|
||||
def _compute_fans(shape, in_axis=-2, out_axis=-1):
|
||||
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
receptive_field_size = prod(shape) / shape[in_axis] / shape[out_axis]
|
||||
fan_in = shape[in_axis] * receptive_field_size
|
||||
fan_out = shape[out_axis] * receptive_field_size
|
||||
return fan_in, fan_out
|
||||
@ -85,7 +86,7 @@ def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
|
||||
def init(key, shape, dtype=dtype):
|
||||
if len(shape) < 2:
|
||||
raise ValueError("orthogonal initializer requires at least a 2D shape")
|
||||
n_rows, n_cols = np.prod(shape) // shape[column_axis], shape[column_axis]
|
||||
n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
|
||||
matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
|
||||
A = random.normal(key, matrix_shape, dtype)
|
||||
Q, R = jnp.linalg.qr(A)
|
||||
|
@ -296,7 +296,7 @@ def _random_bits(key, bit_width, shape):
|
||||
raise TypeError("_random_bits got invalid prng key.")
|
||||
if bit_width not in (8, 16, 32, 64):
|
||||
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
|
||||
size = np.prod(shape)
|
||||
size = prod(shape)
|
||||
max_count = int(np.ceil(bit_width * size / 32))
|
||||
if max_count >= jnp.iinfo(np.uint32).max:
|
||||
# TODO(mattjj): just split the key here
|
||||
@ -560,7 +560,7 @@ def choice(key, a, shape=(), replace=True, p=None):
|
||||
if a.ndim not in [0, 1]:
|
||||
raise ValueError("a must be an integer or 1-dimensional")
|
||||
n_inputs = int(a) if a.ndim == 0 else len(a)
|
||||
n_draws = np.prod(shape).astype(int)
|
||||
n_draws = prod(shape)
|
||||
if n_draws == 0:
|
||||
return jnp.zeros(shape, dtype=a.dtype)
|
||||
if n_inputs <= 0:
|
||||
|
@ -33,7 +33,7 @@ from . import core
|
||||
from . import dtypes as _dtypes
|
||||
from . import lax
|
||||
from .config import flags, bool_env
|
||||
from .util import partial
|
||||
from .util import partial, prod
|
||||
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
|
||||
from .lib import xla_bridge
|
||||
from .interpreters import xla
|
||||
@ -674,7 +674,7 @@ def rand_int(rng, low=0, high=None):
|
||||
|
||||
def rand_unique_int(rng, high=None):
|
||||
def fn(shape, dtype):
|
||||
return rng.choice(np.arange(high or np.prod(shape), dtype=dtype),
|
||||
return rng.choice(np.arange(high or prod(shape), dtype=dtype),
|
||||
size=shape, replace=False)
|
||||
return fn
|
||||
|
||||
|
2
jax/third_party/numpy/linalg.py
vendored
2
jax/third_party/numpy/linalg.py
vendored
@ -7,7 +7,7 @@ from jax.numpy._util import _wraps
|
||||
|
||||
def _isEmpty2d(arr):
|
||||
# check size first for efficiency
|
||||
return arr.size == 0 and jnp.product(arr.shape[-2:]) == 0
|
||||
return arr.size == 0 and np.product(arr.shape[-2:]) == 0
|
||||
|
||||
|
||||
def _assertNoEmpty2d(*arrays):
|
||||
|
@ -35,6 +35,7 @@ from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax.experimental import host_callback as hcb
|
||||
from jax.lib import xla_bridge
|
||||
from jax.util import prod
|
||||
import numpy as np
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -593,7 +594,7 @@ where: 10
|
||||
if jtu.device_under_test() == "tpu":
|
||||
if dtype in (jnp.int16,):
|
||||
raise SkipTest(f"transfering {dtype} not supported on TPU")
|
||||
args = [jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)]
|
||||
args = [jnp.arange(prod(shape), dtype=dtype).reshape(shape)]
|
||||
if nr_args > 1:
|
||||
args = args * nr_args
|
||||
jit_fun1 = api.jit(lambda xs: hcb.id_print(
|
||||
|
@ -30,6 +30,7 @@ from jax import dtypes
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax.test_util import check_grads
|
||||
from jax.util import prod
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -799,7 +800,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
# too, since we don't guarantee the same ordering of values with equal keys.
|
||||
# To avoid that case, we generate unique keys (globally in the key array).
|
||||
def args_maker():
|
||||
flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
|
||||
flat_keys = np.arange(prod(shape), dtype=key_dtype)
|
||||
keys = self.rng().permutation(flat_keys).reshape(shape)
|
||||
values = rng(shape, val_dtype)
|
||||
return keys, values
|
||||
@ -817,7 +818,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
for k in [1, 3]
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testTopKGrad(self, shape, dtype, k, rng_factory):
|
||||
flat_values = np.arange(np.prod(shape, dtype=int), dtype=dtype)
|
||||
flat_values = np.arange(prod(shape), dtype=dtype)
|
||||
values = self.rng().permutation(flat_values).reshape(shape)
|
||||
fun = lambda vs: lax.top_k(vs, k=k)[0]
|
||||
check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
|
||||
|
@ -44,6 +44,7 @@ from jax import dtypes
|
||||
from jax import tree_util
|
||||
from jax.interpreters import partial_eval, xla
|
||||
from jax.test_util import check_grads
|
||||
from jax.util import prod
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -1315,7 +1316,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if shape in scalar_shapes or len(shape) == 0:
|
||||
cond_shape = (0,)
|
||||
elif axis is None:
|
||||
cond_shape = (np.prod(shape),)
|
||||
cond_shape = (prod(shape),)
|
||||
else:
|
||||
cond_shape = (shape[axis],)
|
||||
|
||||
@ -1353,7 +1354,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if shape in scalar_shapes or len(shape) == 0:
|
||||
cond_shape = (0,)
|
||||
elif axis is None:
|
||||
cond_shape = (np.prod(shape),)
|
||||
cond_shape = (prod(shape),)
|
||||
else:
|
||||
cond_shape = (shape[axis],)
|
||||
|
||||
|
@ -32,6 +32,7 @@ from jax import test_util as jtu
|
||||
from jax import lax_reference
|
||||
from jax.test_util import check_grads
|
||||
import jax.util
|
||||
from jax.util import prod
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -1448,7 +1449,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
# too, since we don't guarantee the same ordering of values with equal keys.
|
||||
# To avoid that case, we generate unique keys (globally in the key array).
|
||||
def args_maker():
|
||||
flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
|
||||
flat_keys = np.arange(prod(shape), dtype=key_dtype)
|
||||
keys = self.rng().permutation(flat_keys).reshape(shape)
|
||||
values = rng(shape, val_dtype)
|
||||
return keys, values
|
||||
@ -1498,7 +1499,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
# too, since we don't guarantee the same ordering of values with equal keys.
|
||||
# To avoid that case, we generate unique keys (globally in the key array).
|
||||
def args_maker():
|
||||
flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
|
||||
flat_keys = np.arange(prod(shape), dtype=key_dtype)
|
||||
keys = self.rng().permutation(flat_keys).reshape(shape)
|
||||
values = rng(shape, val_dtype)
|
||||
return keys, values
|
||||
@ -1517,7 +1518,7 @@ class LaxTest(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]))
|
||||
def testTopK(self, shape, dtype, k, rng_factory):
|
||||
def args_maker():
|
||||
flat_values = np.arange(np.prod(shape, dtype=int), dtype=dtype)
|
||||
flat_values = np.arange(prod(shape), dtype=dtype)
|
||||
values = self.rng().permutation(flat_values).reshape(shape)
|
||||
return [values]
|
||||
def reference_top_k(x):
|
||||
|
@ -1025,7 +1025,7 @@ class PmapTest(jtu.JaxTestCase):
|
||||
# Manually construct a ShardedDeviceArray with the wrong sharding for the
|
||||
# subsequent pmap
|
||||
shard_shape = (3,2)
|
||||
shard = jnp.arange(jnp.prod(jnp.array(shard_shape))).reshape(shard_shape)
|
||||
shard = jnp.arange(prod(shard_shape)).reshape(shard_shape)
|
||||
bufs = [xla.device_put(shard, d) for d in xla_bridge.devices()[:4]]
|
||||
aval = ShapedArray((6,4), shard.dtype)
|
||||
sharding_spec = pxla.ShardingSpec(
|
||||
@ -1620,7 +1620,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
|
||||
if jax.device_count() < shape[0]:
|
||||
raise SkipTest(f"requires {shape[0]} devices")
|
||||
|
||||
x = jnp.arange(jnp.prod(jnp.array(shape))).reshape(shape)
|
||||
x = jnp.arange(prod(shape)).reshape(shape)
|
||||
sharded_x = pmap(lambda x: x)(x)
|
||||
|
||||
num_threads = 10
|
||||
@ -1816,7 +1816,7 @@ class ShardArgsTest(jtu.JaxTestCase):
|
||||
nshards = len(indices)
|
||||
if jax.device_count() < nshards:
|
||||
raise SkipTest
|
||||
x = np.arange(np.prod(shape)).reshape(shape)
|
||||
x = np.arange(prod(shape)).reshape(shape)
|
||||
arg = make_arg(x)
|
||||
bufs = pxla.shard_args(jax.devices()[:nshards],
|
||||
[indices], [arg])
|
||||
|
@ -25,6 +25,7 @@ from jax import grad
|
||||
from jax import test_util as jtu
|
||||
from jax import dtypes
|
||||
from jax.scipy import ndimage as lsp_ndimage
|
||||
from jax.util import prod
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -83,7 +84,7 @@ class NdimageTest(jtu.JaxTestCase):
|
||||
mode, cval, impl, round_, rng_factory):
|
||||
|
||||
def args_maker():
|
||||
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
|
||||
x = np.arange(prod(shape), dtype=dtype).reshape(shape)
|
||||
coords = [(size - 1) * rng(coords_shape, coords_dtype) for size in shape]
|
||||
if round_:
|
||||
coords = [c.round().astype(int) for c in coords]
|
||||
|
@ -31,6 +31,7 @@ from jax import tree_util
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters.sharded_jit import sharded_jit, with_sharding_constraint
|
||||
from jax.interpreters.sharded_jit import PartitionSpec as P
|
||||
from jax.util import prod
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.config import config
|
||||
@ -53,7 +54,7 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
return x + y
|
||||
|
||||
shape = (8, 8)
|
||||
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
actual = f(x, x + 1)
|
||||
expected = x + (x + 1)
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
@ -72,7 +73,7 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
return a1 + a2 + b + c1 + c2 + c3
|
||||
|
||||
def _make_arg(*shape):
|
||||
return np.arange(np.prod(shape)).reshape(shape)
|
||||
return np.arange(prod(shape)).reshape(shape)
|
||||
|
||||
a = (_make_arg(4, 4), 1)
|
||||
b = _make_arg(4, 4)
|
||||
@ -102,7 +103,7 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
return x + 1, ((x + 2, x + 3), x + 4)
|
||||
|
||||
shape = (4, 4)
|
||||
x = np.arange(np.prod(shape)).reshape(shape)
|
||||
x = np.arange(prod(shape)).reshape(shape)
|
||||
in_parts = (P(2, 1),)
|
||||
out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))
|
||||
|
||||
@ -134,7 +135,7 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
return y * 2
|
||||
|
||||
shape = (8, 8)
|
||||
x = np.arange(np.prod(shape)).reshape(shape)
|
||||
x = np.arange(prod(shape)).reshape(shape)
|
||||
expected = (x + 1) * 2
|
||||
|
||||
# Matching sharded_jit partitions
|
||||
@ -173,7 +174,7 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
lambda i: with_sharding_constraint(i + 1., P(2, 1)),
|
||||
x)
|
||||
|
||||
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
expected = x + 10.
|
||||
actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
@ -195,7 +196,7 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
return vjp_f(p)
|
||||
|
||||
shape = (4, 4)
|
||||
x = jnp.arange(jnp.prod(shape), dtype=jnp.float32).reshape(shape)
|
||||
x = jnp.arange(prod(shape), dtype=jnp.float32).reshape(shape)
|
||||
actual = f(x)
|
||||
expected = expected_f(x)
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
@ -222,7 +223,7 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
(y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
|
||||
return x @ y.T + z
|
||||
|
||||
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
|
||||
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
y = x + 1
|
||||
shard_size = shape[0] // jax.local_device_count()
|
||||
y_shards = [y[i:i+shard_size] for i in range(0, shape[0], shard_size)]
|
||||
@ -303,7 +304,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
|
||||
if num_shards > jax.local_device_count():
|
||||
raise SkipTest("requires %d devices" % num_shards)
|
||||
|
||||
x = np.arange(np.prod(shape, dtype=dtype)).reshape(shape)
|
||||
x = np.arange(prod(shape)).reshape(shape)
|
||||
y = x + 1
|
||||
result = pmap(
|
||||
sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions))(x, y)
|
||||
@ -395,7 +396,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
|
||||
return a1 + a2 + b + c1 + c2 + c3
|
||||
|
||||
def _make_arg(*shape):
|
||||
return np.arange(np.prod(shape)).reshape(shape)
|
||||
return np.arange(prod(shape)).reshape(shape)
|
||||
|
||||
a = (_make_arg(2, 4, 4), _make_arg(2))
|
||||
b = _make_arg(2, 4, 4)
|
||||
@ -419,7 +420,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
|
||||
return x + 1, ((x + 2, x + 3), x + 4)
|
||||
|
||||
shape = (2, 4, 4)
|
||||
x = np.arange(np.prod(shape)).reshape(shape)
|
||||
x = np.arange(prod(shape)).reshape(shape)
|
||||
in_parts = (P(2, 1),)
|
||||
out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))
|
||||
|
||||
@ -438,7 +439,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
|
||||
return jnp.sum(args)
|
||||
|
||||
shape = (2, 4, 4)
|
||||
args = [np.arange(np.prod(shape)).reshape(shape)] * num_args
|
||||
args = [np.arange(prod(shape)).reshape(shape)] * num_args
|
||||
in_partitions = (P(2, 1),) * num_args
|
||||
out_partitions = None
|
||||
result = pmap(sharded_jit(
|
||||
@ -463,7 +464,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
|
||||
return jnp.dot(x, x) * 2
|
||||
|
||||
shape = (2, 8, 8)
|
||||
x = np.arange(np.prod(shape)).reshape(shape)
|
||||
x = np.arange(prod(shape)).reshape(shape)
|
||||
result = pmap(f)(x)
|
||||
expected = pmap(expected_f)(x)
|
||||
|
||||
@ -477,7 +478,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
|
||||
in_partitions = (P(2, 1), None, None)
|
||||
out_partitions = P(2, 1)
|
||||
in_axes = (None, None, 0)
|
||||
x = y = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
|
||||
x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
|
||||
dummy = np.arange(replicas, dtype=np.float32) + 1
|
||||
num_shards = replicas * np.prod(in_partitions[0])
|
||||
if num_shards > jax.local_device_count():
|
||||
|
Loading…
x
Reference in New Issue
Block a user