From 29aa9bfc8f5dbdf635e877395dc2082cefc8bc87 Mon Sep 17 00:00:00 2001
From: Jake Vanderplas
Date: Tue, 18 Aug 2020 10:17:38 -0700
Subject: [PATCH] Cleanup: avoid jnp.prod & np.prod on array shapes (#4086)

---
 benchmarks/pmap_benchmark.py    |  3 ++-
 jax/api.py                      |  2 +-
 jax/nn/initializers.py          |  5 +++--
 jax/random.py                   |  4 ++--
 jax/test_util.py                |  4 ++--
 jax/third_party/numpy/linalg.py |  2 +-
 tests/host_callback_test.py     |  3 ++-
 tests/lax_autodiff_test.py      |  5 +++--
 tests/lax_numpy_test.py         |  5 +++--
 tests/lax_test.py               |  7 ++++---
 tests/pmap_test.py              |  6 +++---
 tests/scipy_ndimage_test.py     |  3 ++-
 tests/sharded_jit_test.py       | 27 ++++++++++++++-------------
 13 files changed, 42 insertions(+), 34 deletions(-)

diff --git a/benchmarks/pmap_benchmark.py b/benchmarks/pmap_benchmark.py
index 96ee86b42..5a1dfc583 100644
--- a/benchmarks/pmap_benchmark.py
+++ b/benchmarks/pmap_benchmark.py
@@ -24,6 +24,7 @@ import jax
 from jax import numpy as jnp
 from jax import pmap
 from jax.config import config
+from jax.util import prod
 
 from benchmarks import benchmark
 
@@ -118,7 +119,7 @@ def sharded_device_array_indexing_benchmark():
   nshards = min(8, jax.local_device_count())
   shape = (nshards, 8, 8)
   def benchmark_fn():
-    arr = pmap(lambda x: x)(jnp.arange(jnp.prod(shape)).reshape(shape))
+    arr = pmap(lambda x: x)(jnp.arange(prod(shape)).reshape(shape))
     indices = indices_fn()
     for idx in indices:
       arr[idx]
diff --git a/jax/api.py b/jax/api.py
index 954d01a65..1a8da12ef 100644
--- a/jax/api.py
+++ b/jax/api.py
@@ -1724,7 +1724,7 @@ class ShapeDtypeStruct(object):
     self.shape = shape
     self.dtype = np.dtype(dtype)
 
-  size = property(lambda self: np.prod(self.shape))
+  size = property(lambda self: prod(self.shape))
   ndim = property(lambda self: len(self.shape))
 
   def __len__(self):
diff --git a/jax/nn/initializers.py b/jax/nn/initializers.py
index 950f23411..508ef6d8c 100644
--- a/jax/nn/initializers.py
+++ b/jax/nn/initializers.py
@@ -26,6 +26,7 @@ import jax.numpy as jnp
 from jax import lax
 from jax import ops
 from jax import random
+from jax.util import prod
 
 def zeros(key, shape, dtype=jnp.float32): return jnp.zeros(shape, dtype)
 def ones(key, shape, dtype=jnp.float32): return jnp.ones(shape, dtype)
@@ -41,7 +42,7 @@ def normal(stddev=1e-2, dtype=jnp.float32):
   return init
 
 def _compute_fans(shape, in_axis=-2, out_axis=-1):
-  receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
+  receptive_field_size = prod(shape) / shape[in_axis] / shape[out_axis]
   fan_in = shape[in_axis] * receptive_field_size
   fan_out = shape[out_axis] * receptive_field_size
   return fan_in, fan_out
@@ -85,7 +86,7 @@ def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
   def init(key, shape, dtype=dtype):
     if len(shape) < 2:
       raise ValueError("orthogonal initializer requires at least a 2D shape")
-    n_rows, n_cols = np.prod(shape) // shape[column_axis], shape[column_axis]
+    n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
     matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
     A = random.normal(key, matrix_shape, dtype)
     Q, R = jnp.linalg.qr(A)
diff --git a/jax/random.py b/jax/random.py
index d849fb988..74226903b 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -296,7 +296,7 @@ def _random_bits(key, bit_width, shape):
     raise TypeError("_random_bits got invalid prng key.")
   if bit_width not in (8, 16, 32, 64):
     raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
-  size = np.prod(shape)
+  size = prod(shape)
   max_count = int(np.ceil(bit_width * size / 32))
   if max_count >= jnp.iinfo(np.uint32).max:
     # TODO(mattjj): just split the key here
@@ -560,7 +560,7 @@ def choice(key, a, shape=(), replace=True, p=None):
   if a.ndim not in [0, 1]:
     raise ValueError("a must be an integer or 1-dimensional")
   n_inputs = int(a) if a.ndim == 0 else len(a)
-  n_draws = np.prod(shape).astype(int)
+  n_draws = prod(shape)
   if n_draws == 0:
     return jnp.zeros(shape, dtype=a.dtype)
   if n_inputs <= 0:
diff --git a/jax/test_util.py b/jax/test_util.py
index bb82f7439..244506bdb 100644
--- a/jax/test_util.py
+++ b/jax/test_util.py
@@ -33,7 +33,7 @@ from . import core
 from . import dtypes as _dtypes
 from . import lax
 from .config import flags, bool_env
-from .util import partial
+from .util import partial, prod
 from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
 from .lib import xla_bridge
 from .interpreters import xla
@@ -674,7 +674,7 @@ def rand_int(rng, low=0, high=None):
 
 def rand_unique_int(rng, high=None):
   def fn(shape, dtype):
-    return rng.choice(np.arange(high or np.prod(shape), dtype=dtype),
+    return rng.choice(np.arange(high or prod(shape), dtype=dtype),
                       size=shape, replace=False)
   return fn
 
diff --git a/jax/third_party/numpy/linalg.py b/jax/third_party/numpy/linalg.py
index b145cd7eb..9d1bc3bbd 100644
--- a/jax/third_party/numpy/linalg.py
+++ b/jax/third_party/numpy/linalg.py
@@ -7,7 +7,7 @@ from jax.numpy._util import _wraps
 
 def _isEmpty2d(arr):
   # check size first for efficiency
-  return arr.size == 0 and jnp.product(arr.shape[-2:]) == 0
+  return arr.size == 0 and np.product(arr.shape[-2:]) == 0
 
 
 def _assertNoEmpty2d(*arrays):
diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py
index 4d824dde4..275635a79 100644
--- a/tests/host_callback_test.py
+++ b/tests/host_callback_test.py
@@ -35,6 +35,7 @@ from jax import test_util as jtu
 from jax.config import config
 from jax.experimental import host_callback as hcb
 from jax.lib import xla_bridge
+from jax.util import prod
 import numpy as np
 
 config.parse_flags_with_absl()
@@ -593,7 +594,7 @@ where: 10
     if jtu.device_under_test() == "tpu":
       if dtype in (jnp.int16,):
         raise SkipTest(f"transfering {dtype} not supported on TPU")
-    args = [jnp.arange(np.prod(shape), dtype=dtype).reshape(shape)]
+    args = [jnp.arange(prod(shape), dtype=dtype).reshape(shape)]
     if nr_args > 1:
       args = args * nr_args
     jit_fun1 = api.jit(lambda xs: hcb.id_print(
diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py
index 0b14a53db..3d8b0fb90 100644
--- a/tests/lax_autodiff_test.py
+++ b/tests/lax_autodiff_test.py
@@ -30,6 +30,7 @@ from jax import dtypes
 from jax import lax
 from jax import test_util as jtu
 from jax.test_util import check_grads
+from jax.util import prod
 
 from jax.config import config
 config.parse_flags_with_absl()
@@ -799,7 +800,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
     # too, since we don't guarantee the same ordering of values with equal keys.
     # To avoid that case, we generate unique keys (globally in the key array).
     def args_maker():
-      flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
+      flat_keys = np.arange(prod(shape), dtype=key_dtype)
       keys = self.rng().permutation(flat_keys).reshape(shape)
       values = rng(shape, val_dtype)
       return keys, values
@@ -817,7 +818,7 @@ class LaxAutodiffTest(jtu.JaxTestCase):
       for k in [1, 3]
       for rng_factory in [jtu.rand_default]))
   def testTopKGrad(self, shape, dtype, k, rng_factory):
-    flat_values = np.arange(np.prod(shape, dtype=int), dtype=dtype)
+    flat_values = np.arange(prod(shape), dtype=dtype)
     values = self.rng().permutation(flat_values).reshape(shape)
     fun = lambda vs: lax.top_k(vs, k=k)[0]
     check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 69d6d94df..20c95c053 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -44,6 +44,7 @@ from jax import dtypes
 from jax import tree_util
 from jax.interpreters import partial_eval, xla
 from jax.test_util import check_grads
+from jax.util import prod
 
 from jax.config import config
 config.parse_flags_with_absl()
@@ -1315,7 +1316,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     if shape in scalar_shapes or len(shape) == 0:
       cond_shape = (0,)
     elif axis is None:
-      cond_shape = (np.prod(shape),)
+      cond_shape = (prod(shape),)
     else:
       cond_shape = (shape[axis],)
 
@@ -1353,7 +1354,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
     if shape in scalar_shapes or len(shape) == 0:
       cond_shape = (0,)
     elif axis is None:
-      cond_shape = (np.prod(shape),)
+      cond_shape = (prod(shape),)
     else:
       cond_shape = (shape[axis],)
 
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 3fecbfdfe..a6d21611f 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -32,6 +32,7 @@ from jax import test_util as jtu
 from jax import lax_reference
 from jax.test_util import check_grads
 import jax.util
+from jax.util import prod
 
 from jax.config import config
 config.parse_flags_with_absl()
@@ -1448,7 +1449,7 @@ class LaxTest(jtu.JaxTestCase):
     # too, since we don't guarantee the same ordering of values with equal keys.
     # To avoid that case, we generate unique keys (globally in the key array).
     def args_maker():
-      flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
+      flat_keys = np.arange(prod(shape), dtype=key_dtype)
       keys = self.rng().permutation(flat_keys).reshape(shape)
       values = rng(shape, val_dtype)
       return keys, values
@@ -1498,7 +1499,7 @@ class LaxTest(jtu.JaxTestCase):
     # too, since we don't guarantee the same ordering of values with equal keys.
     # To avoid that case, we generate unique keys (globally in the key array).
     def args_maker():
-      flat_keys = np.arange(np.prod(shape, dtype=int), dtype=key_dtype)
+      flat_keys = np.arange(prod(shape), dtype=key_dtype)
       keys = self.rng().permutation(flat_keys).reshape(shape)
       values = rng(shape, val_dtype)
       return keys, values
@@ -1517,7 +1518,7 @@ class LaxTest(jtu.JaxTestCase):
       for rng_factory in [jtu.rand_default]))
   def testTopK(self, shape, dtype, k, rng_factory):
     def args_maker():
-      flat_values = np.arange(np.prod(shape, dtype=int), dtype=dtype)
+      flat_values = np.arange(prod(shape), dtype=dtype)
       values = self.rng().permutation(flat_values).reshape(shape)
       return [values]
     def reference_top_k(x):
diff --git a/tests/pmap_test.py b/tests/pmap_test.py
index 2b13c55eb..dab2cb0a2 100644
--- a/tests/pmap_test.py
+++ b/tests/pmap_test.py
@@ -1025,7 +1025,7 @@ class PmapTest(jtu.JaxTestCase):
     # Manually construct a ShardedDeviceArray with the wrong sharding for the
     # subsequent pmap
     shard_shape = (3,2)
-    shard = jnp.arange(jnp.prod(jnp.array(shard_shape))).reshape(shard_shape)
+    shard = jnp.arange(prod(shard_shape)).reshape(shard_shape)
     bufs = [xla.device_put(shard, d) for d in xla_bridge.devices()[:4]]
     aval = ShapedArray((6,4), shard.dtype)
     sharding_spec = pxla.ShardingSpec(
@@ -1620,7 +1620,7 @@ class ShardedDeviceArrayTest(jtu.JaxTestCase):
     if jax.device_count() < shape[0]:
       raise SkipTest(f"requires {shape[0]} devices")
 
-    x = jnp.arange(jnp.prod(jnp.array(shape))).reshape(shape)
+    x = jnp.arange(prod(shape)).reshape(shape)
     sharded_x = pmap(lambda x: x)(x)
 
     num_threads = 10
@@ -1816,7 +1816,7 @@ class ShardArgsTest(jtu.JaxTestCase):
     nshards = len(indices)
     if jax.device_count() < nshards:
      raise SkipTest
-    x = np.arange(np.prod(shape)).reshape(shape)
+    x = np.arange(prod(shape)).reshape(shape)
     arg = make_arg(x)
     bufs = pxla.shard_args(jax.devices()[:nshards], [indices], [arg])
 
diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py
index 5eced29a7..968ec748d 100644
--- a/tests/scipy_ndimage_test.py
+++ b/tests/scipy_ndimage_test.py
@@ -25,6 +25,7 @@ from jax import grad
 from jax import test_util as jtu
 from jax import dtypes
 from jax.scipy import ndimage as lsp_ndimage
+from jax.util import prod
 
 from jax.config import config
 config.parse_flags_with_absl()
@@ -83,7 +84,7 @@ class NdimageTest(jtu.JaxTestCase):
                     mode, cval, impl, round_, rng_factory):
 
     def args_maker():
-      x = np.arange(np.prod(shape), dtype=dtype).reshape(shape)
+      x = np.arange(prod(shape), dtype=dtype).reshape(shape)
       coords = [(size - 1) * rng(coords_shape, coords_dtype) for size in shape]
       if round_:
         coords = [c.round().astype(int) for c in coords]
diff --git a/tests/sharded_jit_test.py b/tests/sharded_jit_test.py
index 460f50762..a9bea9892 100644
--- a/tests/sharded_jit_test.py
+++ b/tests/sharded_jit_test.py
@@ -31,6 +31,7 @@ from jax import tree_util
 from jax.interpreters import pxla
 from jax.interpreters.sharded_jit import sharded_jit, with_sharding_constraint
 from jax.interpreters.sharded_jit import PartitionSpec as P
+from jax.util import prod
 import jax.numpy as jnp
 
 from jax.config import config
@@ -53,7 +54,7 @@ class ShardedJitTest(jtu.JaxTestCase):
       return x + y
 
     shape = (8, 8)
-    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
     actual = f(x, x + 1)
     expected = x + (x + 1)
     self.assertAllClose(actual, expected, check_dtypes=False)
@@ -72,7 +73,7 @@ class ShardedJitTest(jtu.JaxTestCase):
       return a1 + a2 + b + c1 + c2 + c3
 
     def _make_arg(*shape):
-      return np.arange(np.prod(shape)).reshape(shape)
+      return np.arange(prod(shape)).reshape(shape)
 
     a = (_make_arg(4, 4), 1)
     b = _make_arg(4, 4)
@@ -102,7 +103,7 @@ class ShardedJitTest(jtu.JaxTestCase):
      return x + 1, ((x + 2, x + 3), x + 4)
 
    shape = (4, 4)
-    x = np.arange(np.prod(shape)).reshape(shape)
+    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))
 
@@ -134,7 +135,7 @@ class ShardedJitTest(jtu.JaxTestCase):
      return y * 2
 
    shape = (8, 8)
-    x = np.arange(np.prod(shape)).reshape(shape)
+    x = np.arange(prod(shape)).reshape(shape)
    expected = (x + 1) * 2
 
    # Matching sharded_jit partitions
@@ -173,7 +174,7 @@ class ShardedJitTest(jtu.JaxTestCase):
          lambda i: with_sharding_constraint(i + 1., P(2, 1)),
          x)
 
-    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    expected = x + 10.
    actual = sharded_jit(f, in_parts=None, out_parts=None)(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
@@ -195,7 +196,7 @@ class ShardedJitTest(jtu.JaxTestCase):
      return vjp_f(p)
 
    shape = (4, 4)
-    x = jnp.arange(jnp.prod(shape), dtype=jnp.float32).reshape(shape)
+    x = jnp.arange(prod(shape), dtype=jnp.float32).reshape(shape)
    actual = f(x)
    expected = expected_f(x)
    self.assertAllClose(actual, expected, check_dtypes=False)
@@ -222,7 +223,7 @@ class ShardedJitTest(jtu.JaxTestCase):
      (y, z), token = lax.infeed(token, infeed_shapes, partitions=infeed_parts)
      return x @ y.T + z
 
-    x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+    x = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    y = x + 1
    shard_size = shape[0] // jax.local_device_count()
    y_shards = [y[i:i+shard_size] for i in range(0, shape[0], shard_size)]
@@ -303,7 +304,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
    if num_shards > jax.local_device_count():
      raise SkipTest("requires %d devices" % num_shards)
 
-    x = np.arange(np.prod(shape, dtype=dtype)).reshape(shape)
+    x = np.arange(prod(shape)).reshape(shape)
    y = x + 1
    result = pmap(
        sharded_jit(f, in_parts=in_partitions, out_parts=out_partitions))(x, y)
@@ -395,7 +396,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
      return a1 + a2 + b + c1 + c2 + c3
 
    def _make_arg(*shape):
-      return np.arange(np.prod(shape)).reshape(shape)
+      return np.arange(prod(shape)).reshape(shape)
 
    a = (_make_arg(2, 4, 4), _make_arg(2))
    b = _make_arg(2, 4, 4)
@@ -419,7 +420,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
      return x + 1, ((x + 2, x + 3), x + 4)
 
    shape = (2, 4, 4)
-    x = np.arange(np.prod(shape)).reshape(shape)
+    x = np.arange(prod(shape)).reshape(shape)
    in_parts = (P(2, 1),)
    out_parts = (P(2, 1), ((P(1, 2), None), P(2, 1)))
 
@@ -438,7 +439,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
      return jnp.sum(args)
 
    shape = (2, 4, 4)
-    args = [np.arange(np.prod(shape)).reshape(shape)] * num_args
+    args = [np.arange(prod(shape)).reshape(shape)] * num_args
    in_partitions = (P(2, 1),) * num_args
    out_partitions = None
    result = pmap(sharded_jit(
@@ -463,7 +464,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
      return jnp.dot(x, x) * 2
 
    shape = (2, 8, 8)
-    x = np.arange(np.prod(shape)).reshape(shape)
+    x = np.arange(prod(shape)).reshape(shape)
 
    result = pmap(f)(x)
    expected = pmap(expected_f)(x)
@@ -477,7 +478,7 @@ class PmapOfShardedJitTest(jtu.JaxTestCase):
    in_partitions = (P(2, 1), None, None)
    out_partitions = P(2, 1)
    in_axes = (None, None, 0)
-    x = y = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+    x = y = np.arange(prod(shape), dtype=np.float32).reshape(shape)
    dummy = np.arange(replicas, dtype=np.float32) + 1
    num_shards = replicas * np.prod(in_partitions[0])
    if num_shards > jax.local_device_count():
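
For context on the cleanup this patch performs: np.prod and jnp.prod applied to a shape tuple return NumPy/JAX scalar values rather than plain Python ints, whereas jax.util.prod (which, around this release, is essentially a pure-Python product over an iterable) returns an int that can be used directly wherever a static size is needed. The snippet below is a minimal illustrative sketch, not part of the patch, and assumes that behavior of jax.util.prod:

# Sketch of the difference the patch is cleaning up.
import numpy as np
import jax.numpy as jnp
from jax.util import prod  # assumed: a plain Python product over an iterable

shape = (3, 4)

np.prod(shape)              # NumPy integer scalar (e.g. int64), not a Python int
np.prod(())                 # 1.0 -- the empty product silently becomes a float
jnp.prod(jnp.array(shape))  # a JAX array; under jit this is a traced value,
                            # so it cannot serve as a static Python size
prod(shape)                 # 12 -- a plain Python int, usable as a static size
prod(())                    # 1  -- still an int for zero-dimensional shapes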