Cleanup: avoid jnp.prod & np.prod on array shapes (#4086)

Jake Vanderplas 2020-08-18 10:17:38 -07:00 committed by GitHub
parent decd760020
commit 29aa9bfc8f
13 changed files with 42 additions and 34 deletions
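
For context: jnp.prod(shape) builds a traced device array just to compute a size that is already static, and np.prod returns NumPy scalars (the float 1.0 for an empty shape), so both can require extra handling before feeding arange or reshape. A pure-Python product over the shape tuple avoids this. Below is a minimal sketch, conceptually what jax.util.prod provides; the library's exact implementation may differ.

    import operator
    from functools import reduce

    def prod(xs):
        # Product of an iterable of Python ints, e.g. an array shape tuple.
        # Returns the int 1 for an empty shape, whereas np.prod(()) returns
        # the float 1.0 and jnp.prod(shape) would run a device reduction.
        return reduce(operator.mul, xs, 1)

    # A plain int works directly as a size for arange/reshape, no casting needed.
    assert prod((3, 4, 5)) == 60
    assert isinstance(prod(()), int)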

View File

@ -24,6 +24,7 @@ import jax
from jax import numpy as jnp
from jax import pmap
from jax.config import config
from jax.util import prod
from benchmarks import benchmark
@ -118,7 +119,7 @@ def sharded_device_array_indexing_benchmark():
nshards = min(8, jax.local_device_count())
shape = (nshards, 8, 8)
def benchmark_fn():
arr = pmap(lambda x: x)(jnp.arange(jnp.prod(shape)).reshape(shape))
arr = pmap(lambda x: x)(jnp.arange(prod(shape)).reshape(shape))
indices = indices_fn()
for idx in indices:
arr[idx]

View File

@ -1724,7 +1724,7 @@ class ShapeDtypeStruct(object):
self.shape = shape
self.dtype = np.dtype(dtype)
size = property(lambda self: np.prod(self.shape))
size = property(lambda self: prod(self.shape))
ndim = property(lambda self: len(self.shape))
def __len__(self):
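
A hedged usage sketch (not part of the diff): with a pure-Python product, the size property yields a plain int, including for a rank-0 shape.

    import numpy as np
    import jax

    s = jax.ShapeDtypeStruct((2, 3), np.float32)
    assert s.size == 6 and s.ndim == 2
    assert jax.ShapeDtypeStruct((), np.float32).size == 1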

View File

@ -26,6 +26,7 @@ import jax.numpy as jnp
from jax import lax
from jax import ops
from jax import random
from jax.util import prod
def zeros(key, shape, dtype=jnp.float32): return jnp.zeros(shape, dtype)
def ones(key, shape, dtype=jnp.float32): return jnp.ones(shape, dtype)
@ -41,7 +42,7 @@ def normal(stddev=1e-2, dtype=jnp.float32):
return init
def _compute_fans(shape, in_axis=-2, out_axis=-1):
receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
receptive_field_size = prod(shape) / shape[in_axis] / shape[out_axis]
fan_in = shape[in_axis] * receptive_field_size
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out
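
A worked example for _compute_fans with a hypothetical 3x3 convolution kernel shape (not from the diff), assuming the default in_axis=-2, out_axis=-1:

    from jax.util import prod

    shape = (3, 3, 64, 128)                                       # (kh, kw, in, out)
    receptive_field_size = prod(shape) / shape[-2] / shape[-1]    # 73728 / 64 / 128 = 9.0
    fan_in = shape[-2] * receptive_field_size                     # 64 * 9.0 = 576.0
    fan_out = shape[-1] * receptive_field_size                    # 128 * 9.0 = 1152.0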
@ -85,7 +86,7 @@ def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
def init(key, shape, dtype=dtype):
if len(shape) < 2:
raise ValueError("orthogonal initializer requires at least a 2D shape")
n_rows, n_cols = np.prod(shape) // shape[column_axis], shape[column_axis]
n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
A = random.normal(key, matrix_shape, dtype)
Q, R = jnp.linalg.qr(A)

View File

@ -296,7 +296,7 @@ def _random_bits(key, bit_width, shape):
raise TypeError("_random_bits got invalid prng key.")
if bit_width not in (8, 16, 32, 64):
raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
size = np.prod(shape)
size = prod(shape)
max_count = int(np.ceil(bit_width * size / 32))
if max_count >= jnp.iinfo(np.uint32).max:
# TODO(mattjj): just split the key here
@ -560,7 +560,7 @@ def choice(key, a, shape=(), replace=True, p=None):
if a.ndim not in [0, 1]:
raise ValueError("a must be an integer or 1-dimensional")
n_inputs = int(a) if a.ndim == 0 else len(a)
n_draws = np.prod(shape).astype(int)
n_draws = prod(shape)
if n_draws == 0:
return jnp.zeros(shape, dtype=a.dtype)
if n_inputs <= 0:
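
A hedged illustration of why the explicit cast can be dropped (not from the diff): the pure-Python product of a shape tuple is already a Python int, while np.prod returns a NumPy scalar, and a float one for an empty shape.

    import numpy as np
    from jax.util import prod

    assert isinstance(prod((2, 3)), int)         # 6, a plain Python int
    assert isinstance(np.prod(()), np.floating)  # empty shape -> float 1.0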

View File

@ -33,7 +33,7 @@ from . import core
from . import dtypes as _dtypes
from . import lax
from .config import flags, bool_env
from .util import partial
from .util import partial, prod
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from .lib import xla_bridge
from .interpreters import xla
@ -674,7 +674,7 @@ def rand_int(rng, low=0, high=None):
def rand_unique_int(rng, high=None):
def fn(shape, dtype):
return rng.choice(np.arange(high or np.prod(shape), dtype=dtype),
return rng.choice(np.arange(high or prod(shape), dtype=dtype),
size=shape, replace=False)
return fn

View File

@ -7,7 +7,7 @@ from jax.numpy._util import _wraps
def _isEmpty2d(arr):
# check size first for efficiency
return arr.size == 0 and jnp.product(arr.shape[-2:]) == 0
return arr.size == 0 and np.product(arr.shape[-2:]) == 0
def _assertNoEmpty2d(*arrays):
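
A hedged illustration of the check (not from the diff): since arr.shape is a static Python tuple, plain NumPy suffices here, and the condition holds exactly when the trailing 2D blocks are empty.

    import numpy as np

    arr = np.zeros((4, 0, 2))                  # a stack of 0x2 matrices
    assert arr.size == 0
    assert np.product(arr.shape[-2:]) == 0     # last two dims are (0, 2)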

View File

@ -35,6 +35,7 @@ from jax import test_util as jtu
from jax.config import config
from jax.experimental import host_callback as hcb
from jax.lib import xla_bridge
from jax.util import prod
import numpy as np
config.parse_flags_with_absl()
@ -593,7 +594,7 @@ where: 10
if jtu.device_under_test() == "tpu":
if dtype in (jnp.int16,):
raise SkipTest(f"transfering {dtype} not supported on TPU")
args = [jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)]
args = [jnp.arange(prod(shape), dtype=dtype).reshape(shape)]
if nr_args > 1:
args = args * nr_args
jit_fun1 = api.jit(lambda xs: hcb.id_print(

View File

@ -30,6 +30,7 @@ from jax import dtypes
from jax import lax
from jax import test_util as jtu
from jax.test_util import check_grads
from jax.util import prod
from jax.config import config
config.parse_flags_with_absl()
@ -799,7 +800,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
flat_keys = np.arange(prod(shape), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
@ -817,7 +818,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
for k in [1, 3]
for rng_factory in [jtu.rand_default]))
def testTopKGrad(self, shape, dtype, k, rng_factory):
flat_values = np.arange(np.prod(shape, dtype=int), dtype=dtype)
flat_values = np.arange(prod(shape), dtype=dtype)
values = self.rng().permutation(flat_values).reshape(shape)
fun = lambda vs: lax.top_k(vs, k=k)[0]
check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
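
A hedged sketch of the unique-key construction for a hypothetical shape and key dtype (not from the diff): prod(shape) is already an exact Python int, so it feeds np.arange directly.

    import numpy as np
    from jax.util import prod

    shape = (3, 4)
    flat_keys = np.arange(prod(shape), dtype=np.float32)               # 12 distinct keys
    keys = np.random.default_rng(0).permutation(flat_keys).reshape(shape)
    assert keys.shape == shape and len(np.unique(keys)) == prod(shape)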

View File

@ -44,6 +44,7 @@ from jax import dtypes
from jax import tree_util
from jax.interpreters import partial_eval, xla
from jax.test_util import check_grads
from jax.util import prod
from jax.config import config
config.parse_flags_with_absl()
@ -1315,7 +1316,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if shape in scalar_shapes or len(shape) == 0:
cond_shape = (0,)
elif axis is None:
cond_shape = (np.prod(shape),)
cond_shape = (prod(shape),)
else:
cond_shape = (shape[axis],)
@ -1353,7 +1354,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if shape in scalar_shapes or len(shape) == 0:
cond_shape = (0,)
elif axis is None:
cond_shape = (np.prod(shape),)
cond_shape = (prod(shape),)
else:
cond_shape = (shape[axis],)

View File

@ -32,6 +32,7 @@ from jax import test_util as jtu
from jax import lax_reference
from jax.test_util import check_grads
import jax.util
from jax.util import prod
from jax.config import config
config.parse_flags_with_absl()
@ -1448,7 +1449,7 @@ class LaxTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
flat_keys = np.arange(prod(shape), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
@ -1498,7 +1499,7 @@ class LaxTest(jtu.JaxTestCase):
# too, since we don't guarantee the same ordering of values with equal keys.
# To avoid that case, we generate unique keys (globally in the key array).
def args_maker():
flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
flat_keys = np.arange(prod(shape), dtype=key_dtype)
keys = self.rng().permutation(flat_keys).reshape(shape)
values = rng(shape, val_dtype)
return keys, values
@ -1517,7 +1518,7 @@ class LaxTest(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]))
def testTopK(self, shape, dtype, k, rng_factory):
def args_maker():
flat_values = np.arange(np.prod(shape, dtype=int), dtype=dtype)
flat_values = np.arange(prod(shape), dtype=dtype)
values = self.rng().permutation(flat_values).reshape(shape)
return [values]
def reference_top_k(x):

View File

@ -1025,7 +1025,7 @@ class PmapTest(jtu.JaxTestCase):
# Manually construct a ShardedDeviceArray with the wrong sharding for the
# subsequent pmap
shard_shape = (3,2)
shard = jnp.arange(jnp.prod(jnp.array(shard_shape))).reshape(shard_shape)
shard = jnp.arange(prod(shard_shape)).reshape(shard_shape)
bufs = [xla.device_put(shard, d) for d in xla_bridge.devices()[:4]]
aval = ShapedArray((6,4), shard.dtype)
sharding_spec = pxla.ShardingSpec(
@ -1620,7 +1620,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
if jax.device_count() < shape[0]:
raise SkipTest(f"requires {shape[0]} devices")
x = jnp.arange(jnp.prod(jnp.array(shape))).reshape(shape)
x = jnp.arange(prod(shape)).reshape(shape)
sharded_x = pmap(lambda x: x)(x)
num_threads = 10
@ -1816,7 +1816,7 @@ class ShardArgsTest(jtu.JaxTestCase):
nshards = len(indices)
if jax.device_count() < nshards:
raise SkipTest
x = np.arange(np.prod(shape)).reshape(shape)
x = np.arange(prod(shape)).reshape(shape)
arg = make_arg(x)
bufs = pxla.shard_args(jax.devices()[:nshards],
[indices], [arg])

View File

@ -25,6 +25,7 @@ from jax import grad
from jax import test_util as jtu
from jax import dtypes
from jax.scipy import ndimage as lsp_ndimage
from jax.util import prod
from jax.config import config
config.parse_flags_with_absl()
@ -83,7 +84,7 @@ class NdimageTest(jtu.JaxTestCase):
mode, cval, impl, round_, rng_factory):
def args_maker():
x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
x = np.arange(prod(shape), dtype=dtype).reshape(shape)
coords = [(size - 1) * rng(coords_shape, coords_dtype) for size in shape]
if round_:
coords = [c.round().astype(int) for c in coords]

View File

@ -31,6 +31,7 @@ from jax import tree_util
from jax.interpreters import pxla
from jax.interpreters.sharded_jit import sharded_jit, with_sharding_constraint
from jax.interpreters.sharded_jit import PartitionSpec as P
from jax.util import prod
import jax.numpy as jnp
from jax.config import config
@ -53,7 +54,7 @@ class ShardedJitTest(jtu.JaxTestCase):
return x + y
shape = (8, 8)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
actual = f(x, x + 1)
expected = x + (x + 1)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -72,7 +73,7 @@ class ShardedJitTest(jtu.JaxTestCase):
return a1 + a2 + b + c1 + c2 + c3
def _make_arg(*shape):
return np.arange(np.prod(shape)).reshape(shape)
return np.arange(prod(shape)).reshape(shape)
a = (_make_arg(4, 4), 1)
b = _make_arg(4, 4)
@ -102,7 +103,7 @@ class ShardedJitTest(jtu.JaxTestCase):
return x + 1, ((x + 2, x + 3), x + 4)
shape = (4, 4)
x = np.arange(np.prod(shape)).reshape(shape)
x = np.arange(prod(shape)).reshape(shape)
in_parts = (P(2, 1),)
out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))
@ -134,7 +135,7 @@ class ShardedJitTest(jtu.JaxTestCase):
return y * 2
shape = (8, 8)
x = np.arange(np.prod(shape)).reshape(shape)
x = np.arange(prod(shape)).reshape(shape)
expected = (x + 1) * 2
# Matching sharded_jit partitions
@ -173,7 +174,7 @@ class ShardedJitTest(jtu.JaxTestCase):
lambda i: with_sharding_constraint(i + 1., P(2, 1)),
x)
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
expected = x + 10.
actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -195,7 +196,7 @@ class ShardedJitTest(jtu.JaxTestCase):
return vjp_f(p)
shape = (4, 4)
x = jnp.arange(jnp.prod(shape), dtype=jnp.float32).reshape(shape)
x = jnp.arange(prod(shape), dtype=jnp.float32).reshape(shape)
actual = f(x)
expected = expected_f(x)
self.assertAllClose(actual, expected, check_dtypes=False)
@ -222,7 +223,7 @@ class ShardedJitTest(jtu.JaxTestCase):
(y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
return x @ y.T + z
x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
y = x + 1
shard_size = shape[0] // jax.local_device_count()
y_shards = [y[i:i+shard_size] for i in range(0, shape[0], shard_size)]
@ -303,7 +304,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
if num_shards > jax.local_device_count():
raise SkipTest("requires %d devices" % num_shards)
x = np.arange(np.prod(shape, dtype=dtype)).reshape(shape)
x = np.arange(prod(shape)).reshape(shape)
y = x + 1
result = pmap(
sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions))(x, y)
@ -395,7 +396,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
return a1 + a2 + b + c1 + c2 + c3
def _make_arg(*shape):
return np.arange(np.prod(shape)).reshape(shape)
return np.arange(prod(shape)).reshape(shape)
a = (_make_arg(2, 4, 4), _make_arg(2))
b = _make_arg(2, 4, 4)
@ -419,7 +420,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
return x + 1, ((x + 2, x + 3), x + 4)
shape = (2, 4, 4)
x = np.arange(np.prod(shape)).reshape(shape)
x = np.arange(prod(shape)).reshape(shape)
in_parts = (P(2, 1),)
out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))
@ -438,7 +439,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
return jnp.sum(args)
shape = (2, 4, 4)
args = [np.arange(np.prod(shape)).reshape(shape)] * num_args
args = [np.arange(prod(shape)).reshape(shape)] * num_args
in_partitions = (P(2, 1),) * num_args
out_partitions = None
result = pmap(sharded_jit(
@ -463,7 +464,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
return jnp.dot(x, x) * 2
shape = (2, 8, 8)
x = np.arange(np.prod(shape)).reshape(shape)
x = np.arange(prod(shape)).reshape(shape)
result = pmap(f)(x)
expected = pmap(expected_f)(x)
@ -477,7 +478,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
in_partitions = (P(2, 1), None, None)
out_partitions = P(2, 1)
in_axes = (None, None, 0)
x = y = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
dummy = np.arange(replicas, dtype=np.float32) + 1
num_shards = replicas * np.prod(in_partitions[0])
if num_shards > jax.local_device_count():