mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
87ee4b7c56
commit
693365c239
17
jax/lax.py
17
jax/lax.py
@ -111,7 +111,12 @@ def convert_element_type(operand, new_dtype):
|
||||
return operand
|
||||
|
||||
def bitcast_convert_type(operand, new_dtype):
|
||||
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
|
||||
new_dtype = xla_bridge.canonicalize_dtype(new_dtype)
|
||||
old_dtype = _dtype(operand)
|
||||
if old_dtype != new_dtype:
|
||||
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
|
||||
else:
|
||||
return operand
|
||||
|
||||
def clamp(min, operand, max):
|
||||
return clamp_p.bind(min, operand, max)
|
||||
@ -269,9 +274,9 @@ def _get_monoid_reducer(monoid_op, x):
|
||||
if (type(aval) is ConcreteArray) and aval.shape == ():
|
||||
if monoid_op is add:
|
||||
return aval.val == 0 and _reduce_sum
|
||||
elif monoid_op is max:
|
||||
elif monoid_op is max or monoid_op is bitwise_or and aval.dtype == onp.bool_:
|
||||
return aval.val == _get_max_identity(aval.dtype) and _reduce_max
|
||||
elif monoid_op is min:
|
||||
elif monoid_op is min or monoid_op is bitwise_and and aval.dtype == onp.bool_:
|
||||
return aval.val == _get_min_identity(aval.dtype) and _reduce_min
|
||||
|
||||
def _get_max_identity(dtype):
|
||||
@ -279,12 +284,16 @@ def _get_max_identity(dtype):
|
||||
return onp.array(-onp.inf, dtype)
|
||||
elif onp.issubdtype(dtype, onp.integer):
|
||||
return onp.array(onp.iinfo(dtype).min, dtype)
|
||||
elif onp.issubdtype(dtype, onp.bool_):
|
||||
return onp.array(False, onp.bool_)
|
||||
|
||||
def _get_min_identity(dtype):
|
||||
if onp.issubdtype(dtype, onp.floating):
|
||||
return onp.array(onp.inf, dtype)
|
||||
elif onp.issubdtype(dtype, onp.integer):
|
||||
return onp.array(onp.iinfo(dtype).max, dtype)
|
||||
elif onp.issubdtype(dtype, onp.bool_):
|
||||
return onp.array(True, onp.bool_)
|
||||
|
||||
def _reduce_sum(operand, axes):
|
||||
return reduce_sum_p.bind(operand, axes=tuple(axes), input_shape=operand.shape)
|
||||
@ -1828,7 +1837,7 @@ def _reduction_computation(c, jaxpr, consts, init_value):
|
||||
|
||||
reduce_p = standard_primitive(reduce_shape_rule, _input_dtype, 'reduce',
|
||||
reduce_translation_rule)
|
||||
batching.defreducer(reduce_p)
|
||||
# batching.defreducer(reduce_p) # TODO batching rule for general reduce
|
||||
|
||||
|
||||
def reduce_sum_shape_rule(operand, axes, input_shape):
|
||||
|
@ -25,8 +25,7 @@ from .. import core
|
||||
from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
|
||||
from ..interpreters.xla import DeviceArray
|
||||
from .. import lax
|
||||
from ..util import memoize
|
||||
from ..util import get_module_functions
|
||||
from ..util import memoize, partial, get_module_functions
|
||||
from ..lib import xla_bridge
|
||||
|
||||
# To provide the same module-level names as Numpy, we need to redefine builtins
|
||||
@ -591,15 +590,16 @@ around = round
|
||||
### Reducers
|
||||
|
||||
|
||||
def _make_reduction(np_fun, op, init_val):
|
||||
def _make_reduction(np_fun, op, init_val, preproc=None):
|
||||
"""Creates reduction function given a binary operation and monoid identity."""
|
||||
|
||||
@_wraps(op)
|
||||
@_wraps(np_fun)
|
||||
def reduction(a, axis=None, dtype=None, out=None, keepdims=False):
|
||||
if out is not None:
|
||||
raise ValueError("reduction does not support `out` argument.")
|
||||
|
||||
a = a if isinstance(a, ndarray) else asarray(a)
|
||||
a = preproc(a) if preproc else a
|
||||
dims = _reduction_dims(a, axis)
|
||||
result_dtype = _dtype(np_fun(onp.ones((), dtype=dtype or _dtype(a))))
|
||||
if _dtype(a) != result_dtype:
|
||||
@ -614,7 +614,6 @@ def _make_reduction(np_fun, op, init_val):
|
||||
|
||||
return reduction
|
||||
|
||||
|
||||
def _reduction_dims(a, axis):
|
||||
if axis is None:
|
||||
return onp.arange(ndim(a))
|
||||
@ -625,7 +624,6 @@ def _reduction_dims(a, axis):
|
||||
else:
|
||||
raise TypeError("Unexpected type of axis argument: {}".format(type(axis)))
|
||||
|
||||
|
||||
def _reduction_init_val(a, init_val):
|
||||
a_dtype = xla_bridge.canonicalize_dtype(_dtype(a))
|
||||
try:
|
||||
@ -635,13 +633,14 @@ def _reduction_init_val(a, init_val):
|
||||
sign, iinfo = onp.sign(init_val), onp.iinfo(a_dtype)
|
||||
return onp.array(iinfo.min if sign < 0 else iinfo.max, dtype=a_dtype)
|
||||
|
||||
_cast_to_bool = partial(lax.convert_element_type, new_dtype=onp.bool_)
|
||||
|
||||
sum = _make_reduction(onp.sum, lax.add, 0)
|
||||
prod = _make_reduction(onp.prod, lax.mul, 1)
|
||||
max = _make_reduction(onp.max, lax.max, -onp.inf)
|
||||
min = _make_reduction(onp.min, lax.min, onp.inf)
|
||||
all = alltrue = _make_reduction(onp.all, logical_and, True)
|
||||
any = sometrue = _make_reduction(onp.any, logical_or, False)
|
||||
all = alltrue = _make_reduction(onp.all, lax.bitwise_and, True, _cast_to_bool)
|
||||
any = sometrue = _make_reduction(onp.any, lax.bitwise_or, False, _cast_to_bool)
|
||||
|
||||
|
||||
@_wraps(onp.mean)
|
||||
|
@ -284,6 +284,13 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
jacrev(func)(xs) # don't crash
|
||||
jacfwd(func)(xs) # don't crash
|
||||
|
||||
def testAny(self):
|
||||
# test modeling the code in https://github.com/google/jax/issues/108
|
||||
|
||||
ans = vmap(np.any)(np.array([[True, False], [False, False]]))
|
||||
expected = np.array([True, False])
|
||||
self.assertAllClose(ans, expected, check_dtypes=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
@ -142,8 +142,8 @@ JAX_REDUCER_RECORDS = [
|
||||
]
|
||||
|
||||
JAX_REDUCER_NO_DTYPE_RECORDS = [
|
||||
op_record("all", 1, bool_dtypes, all_shapes, jtu.rand_default(), []),
|
||||
op_record("any", 1, bool_dtypes, all_shapes, jtu.rand_default(), []),
|
||||
op_record("all", 1, default_dtypes + bool_dtypes, all_shapes, jtu.rand_some_zero(), []),
|
||||
op_record("any", 1, default_dtypes + bool_dtypes, all_shapes, jtu.rand_some_zero(), []),
|
||||
op_record("max", 1, default_dtypes, nonempty_shapes, jtu.rand_default(), []),
|
||||
op_record("min", 1, default_dtypes, nonempty_shapes, jtu.rand_default(), []),
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user