source sync

PiperOrigin-RevId: 222923229
Dougal Maclaurin 2018-11-26 18:50:27 -08:00 committed by Roy Frostig
parent 599ea38175
commit ca2634ea5d
5 changed files with 15 additions and 7 deletions

View File

@@ -21,6 +21,7 @@ import six
 from . import core
 from . import ad_util
+from .util import prod
 from .lib import xla_bridge
@@ -80,7 +81,7 @@ class ShapedArray(UnshapedArray):
     self.shape = shape
   ndim = property(lambda self: len(self.shape))
-  size = property(lambda self: int(onp.prod(self.shape)))
+  size = property(lambda self: prod(self.shape))
   def __eq__(self, other):
     return (type(self) is type(other)

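The hunks above switch ShapedArray.size from a numpy reduction to the pure-Python prod helper added to util in this commit (its definition appears in the last file below). A minimal sketch of the equivalence, using only the two expressions shown in the diff; the shape value is just an example:

    import functools
    from operator import mul
    import numpy as onp

    def prod(xs):
      # The new helper from util: fold multiplication over the dims, starting at 1.
      return functools.reduce(mul, xs, 1)

    shape = (2, 3, 4)
    assert int(onp.prod(shape)) == 24   # old path: numpy product, cast back to int
    assert prod(shape) == 24            # new path: plain Python, already an int
    assert prod(()) == 1                # rank-0 shape still reports size 1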
View File

@@ -28,7 +28,7 @@ from .. import core
 from .. import ad_util
 from ..abstract_arrays import ConcreteArray, ShapedArray, make_shaped_array, array_types
 from ..core import AbstractTuple, JaxTuple, pack, valid_jaxtype
-from ..util import partial, partialmethod, memoize, unzip2, concatenate, safe_map
+from ..util import partial, partialmethod, memoize, unzip2, concatenate, safe_map, prod
 from ..linear_util import transformation_with_aux, memoize as linear_memoize
 from ..lib import xla_bridge as xb
 from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
@@ -190,7 +190,7 @@ class DeviceArray(DeviceValue):
     self.shape = xla_shape.dimensions()
     self.dtype = xla_shape.element_type()
     self.ndim = len(self.shape)
-    self.size = int(onp.prod(self.shape))
+    self.size = prod(self.shape)
     self._npy_value = None
   @property

View File

@@ -37,7 +37,7 @@ from .interpreters import partial_eval as pe
 from .interpreters import xla
 from .interpreters import ad
 from .interpreters import batching
-from .util import curry, safe_zip, unzip2
+from .util import curry, safe_zip, unzip2, prod
 from .tree_util import build_tree
 from .lib import xla_bridge
@@ -424,7 +424,7 @@ def full_like(x, fill_value, dtype=None, shape=None):
 def collapse(operand, start_dimension, stop_dimension):
   lo, hi = start_dimension, stop_dimension
-  size = onp.product(operand.shape[lo:hi])
+  size = prod(operand.shape[lo:hi])
   new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:]
   return reshape(operand, new_shape)
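The collapse hunk above now computes the merged dimension with prod instead of onp.product. A small sketch of the shape arithmetic it performs; collapsed_shape is a hypothetical stand-alone helper mirroring the three lines shown, not part of the diff:

    import functools
    from operator import mul

    def collapsed_shape(shape, lo, hi):
      # Merge dimensions lo..hi-1 into a single dimension of their product.
      size = functools.reduce(mul, shape[lo:hi], 1)
      return shape[:lo] + (size,) + shape[hi:]

    assert collapsed_shape((2, 3, 4), 0, 2) == (6, 4)
    assert collapsed_shape((2, 3, 4), 1, 3) == (2, 12)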
@@ -1407,7 +1407,7 @@ def reshape_shape_rule(operand, new_sizes, dimensions, **unused_kwargs):
   if not onp.all(onp.greater_equal(new_sizes, 0)):
     msg = 'reshape new_sizes must all be positive, got {}.'
     raise TypeError(msg.format(new_sizes))
-  if onp.prod(onp.shape(operand)) != onp.prod(new_sizes):
+  if prod(onp.shape(operand)) != prod(new_sizes):
     msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
     raise TypeError(msg.format(new_sizes, onp.shape(operand)))
   if dimensions is not None:

View File

@@ -26,6 +26,7 @@ from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
 from ..interpreters.xla import DeviceArray
 from ..lib import xla_bridge
 import jax.lax as lax
+from ..util import memoize
 # To provide the same module-level names as Numpy, we need to redefine builtins
 # and also use some common names (like 'shape' and 'dtype') at the top-level.
@@ -105,6 +106,7 @@ def _promote_shapes(*args):
             if len(shp) != nd else arg for arg, shp in zip(args, shapes)]
+@memoize
 def _broadcast_shapes(*shapes):
   """Apply Numpy broadcasting rules to the given shapes."""
   if len(shapes) == 1:
@@ -120,6 +122,7 @@ def _broadcast_shapes(*shapes):
 def _promote_dtypes(*args):
   """Convenience function to apply Numpy argument dtype promotion."""
+  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
   if len(args) < 2:
     return args
   else:

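The net effect in this file is that _broadcast_shapes is now wrapped in the memoize decorator imported from ..util (its tail is visible in the last file below), which works because shapes are passed as hashable tuples. A rough sketch of that caching pattern; memoize_simple and broadcast_demo are simplified stand-ins for illustration, not the actual implementations:

    import functools

    def memoize_simple(fun):
      # Dict cache keyed on the (hashable) positional arguments.
      cache = {}
      @functools.wraps(fun)
      def memoized_fun(*args):
        if args not in cache:
          cache[args] = fun(*args)
        return cache[args]
      return memoized_fun

    @memoize_simple
    def broadcast_demo(*shapes):
      # Toy broadcasting: equal-rank shapes whose dims are equal or 1.
      return tuple(max(dims) for dims in zip(*shapes))

    assert broadcast_demo((1, 3), (2, 1)) == (2, 3)
    assert broadcast_demo((1, 3), (2, 1)) == (2, 3)  # second call is served from the cache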
View File

@@ -18,7 +18,7 @@ from __future__ import print_function
 import functools
 import itertools as it
+from operator import mul
 allow_memoize_hash_failures = False
@@ -139,3 +139,7 @@ def memoize(fun):
       else:
         raise
   return memoized_fun
+
+def prod(xs):
+  return functools.reduce(mul, xs, 1)
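Finally, the new util.prod that the other files import. A short usage sketch relying only on the two-line definition above:

    import functools
    from operator import mul

    def prod(xs):
      return functools.reduce(mul, xs, 1)

    assert prod([7]) == 7        # any iterable of numbers works
    assert prod((2, 5)) == 10
    assert prod(()) == 1         # empty sequence: the multiplicative identity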