mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
source sync
PiperOrigin-RevId: 222923229
This commit is contained in:
parent
599ea38175
commit
ca2634ea5d
@ -21,6 +21,7 @@ import six
|
||||
|
||||
from . import core
|
||||
from . import ad_util
|
||||
from . util import prod
|
||||
from .lib import xla_bridge
|
||||
|
||||
|
||||
@ -80,7 +81,7 @@ class ShapedArray(UnshapedArray):
|
||||
self.shape = shape
|
||||
|
||||
ndim = property(lambda self: len(self.shape))
|
||||
size = property(lambda self: int(onp.prod(self.shape)))
|
||||
size = property(lambda self: prod(self.shape))
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other)
|
||||
|
@ -28,7 +28,7 @@ from .. import core
|
||||
from .. import ad_util
|
||||
from ..abstract_arrays import ConcreteArray, ShapedArray, make_shaped_array, array_types
|
||||
from ..core import AbstractTuple, JaxTuple, pack, valid_jaxtype
|
||||
from ..util import partial, partialmethod, memoize, unzip2, concatenate, safe_map
|
||||
from ..util import partial, partialmethod, memoize, unzip2, concatenate, safe_map, prod
|
||||
from ..linear_util import transformation_with_aux, memoize as linear_memoize
|
||||
from ..lib import xla_bridge as xb
|
||||
from .partial_eval import trace_to_subjaxpr, merge_pvals, JaxprTrace, PartialVal
|
||||
@ -190,7 +190,7 @@ class DeviceArray(DeviceValue):
|
||||
self.shape = xla_shape.dimensions()
|
||||
self.dtype = xla_shape.element_type()
|
||||
self.ndim = len(self.shape)
|
||||
size = int(onp.prod(self.shape))
|
||||
size = prod(self.shape)
|
||||
self._npy_value = None
|
||||
|
||||
@property
|
||||
|
@ -37,7 +37,7 @@ from .interpreters import partial_eval as pe
|
||||
from .interpreters import xla
|
||||
from .interpreters import ad
|
||||
from .interpreters import batching
|
||||
from .util import curry, safe_zip, unzip2
|
||||
from .util import curry, safe_zip, unzip2, prod
|
||||
from .tree_util import build_tree
|
||||
from .lib import xla_bridge
|
||||
|
||||
@ -424,7 +424,7 @@ def full_like(x, fill_value, dtype=None, shape=None):
|
||||
|
||||
def collapse(operand, start_dimension, stop_dimension):
|
||||
lo, hi = start_dimension, stop_dimension
|
||||
size = onp.product(operand.shape[lo:hi])
|
||||
size = prod(operand.shape[lo:hi])
|
||||
new_shape = operand.shape[:lo] + (size,) + operand.shape[hi:]
|
||||
return reshape(operand, new_shape)
|
||||
|
||||
@ -1407,7 +1407,7 @@ def reshape_shape_rule(operand, new_sizes, dimensions, **unused_kwargs):
|
||||
if not onp.all(onp.greater_equal(new_sizes, 0)):
|
||||
msg = 'reshape new_sizes must all be positive, got {}.'
|
||||
raise TypeError(msg.format(new_sizes))
|
||||
if onp.prod(onp.shape(operand)) != onp.prod(new_sizes):
|
||||
if prod(onp.shape(operand)) != prod(new_sizes):
|
||||
msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
|
||||
raise TypeError(msg.format(new_sizes, onp.shape(operand)))
|
||||
if dimensions is not None:
|
||||
|
@ -26,6 +26,7 @@ from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
|
||||
from ..interpreters.xla import DeviceArray
|
||||
from ..lib import xla_bridge
|
||||
import jax.lax as lax
|
||||
from ..util import memoize
|
||||
|
||||
# To provide the same module-level names as Numpy, we need to redefine builtins
|
||||
# and also use some common names (like 'shape' and 'dtype') at the top-level.
|
||||
@ -105,6 +106,7 @@ def _promote_shapes(*args):
|
||||
if len(shp) != nd else arg for arg, shp in zip(args, shapes)]
|
||||
|
||||
|
||||
@memoize
|
||||
def _broadcast_shapes(*shapes):
|
||||
"""Apply Numpy broadcasting rules to the given shapes."""
|
||||
if len(shapes) == 1:
|
||||
@ -120,6 +122,7 @@ def _broadcast_shapes(*shapes):
|
||||
|
||||
def _promote_dtypes(*args):
|
||||
"""Convenience function to apply Numpy argument dtype promotion."""
|
||||
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
|
||||
if len(args) < 2:
|
||||
return args
|
||||
else:
|
||||
|
@ -18,7 +18,7 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import itertools as it
|
||||
|
||||
from operator import mul
|
||||
|
||||
allow_memoize_hash_failures = False
|
||||
|
||||
@ -139,3 +139,7 @@ def memoize(fun):
|
||||
else:
|
||||
raise
|
||||
return memoized_fun
|
||||
|
||||
|
||||
def prod(xs):
|
||||
return functools.reduce(mul, xs, 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user