np.all and np.any should lead to monoid reducers

fixes #108
This commit is contained in:
Matthew Johnson 2018-12-14 08:07:12 -08:00
parent 87ee4b7c56
commit 693365c239
4 changed files with 29 additions and 14 deletions

View File

@ -111,7 +111,12 @@ def convert_element_type(operand, new_dtype):
return operand
def bitcast_convert_type(operand, new_dtype):
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
new_dtype = xla_bridge.canonicalize_dtype(new_dtype)
old_dtype = _dtype(operand)
if old_dtype != new_dtype:
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)
else:
return operand
def clamp(min, operand, max):
return clamp_p.bind(min, operand, max)
@ -269,9 +274,9 @@ def _get_monoid_reducer(monoid_op, x):
if (type(aval) is ConcreteArray) and aval.shape == ():
if monoid_op is add:
return aval.val == 0 and _reduce_sum
elif monoid_op is max:
elif monoid_op is max or monoid_op is bitwise_or and aval.dtype == onp.bool_:
return aval.val == _get_max_identity(aval.dtype) and _reduce_max
elif monoid_op is min:
elif monoid_op is min or monoid_op is bitwise_and and aval.dtype == onp.bool_:
return aval.val == _get_min_identity(aval.dtype) and _reduce_min
def _get_max_identity(dtype):
@ -279,12 +284,16 @@ def _get_max_identity(dtype):
return onp.array(-onp.inf, dtype)
elif onp.issubdtype(dtype, onp.integer):
return onp.array(onp.iinfo(dtype).min, dtype)
elif onp.issubdtype(dtype, onp.bool_):
return onp.array(False, onp.bool_)
def _get_min_identity(dtype):
if onp.issubdtype(dtype, onp.floating):
return onp.array(onp.inf, dtype)
elif onp.issubdtype(dtype, onp.integer):
return onp.array(onp.iinfo(dtype).max, dtype)
elif onp.issubdtype(dtype, onp.bool_):
return onp.array(True, onp.bool_)
def _reduce_sum(operand, axes):
return reduce_sum_p.bind(operand, axes=tuple(axes), input_shape=operand.shape)
@ -1828,7 +1837,7 @@ def _reduction_computation(c, jaxpr, consts, init_value):
reduce_p = standard_primitive(reduce_shape_rule, _input_dtype, 'reduce',
reduce_translation_rule)
batching.defreducer(reduce_p)
# batching.defreducer(reduce_p) # TODO batching rule for general reduce
def reduce_sum_shape_rule(operand, axes, input_shape):

View File

@ -25,8 +25,7 @@ from .. import core
from ..abstract_arrays import UnshapedArray, ShapedArray, ConcreteArray
from ..interpreters.xla import DeviceArray
from .. import lax
from ..util import memoize
from ..util import get_module_functions
from ..util import memoize, partial, get_module_functions
from ..lib import xla_bridge
# To provide the same module-level names as Numpy, we need to redefine builtins
@ -591,15 +590,16 @@ around = round
### Reducers
def _make_reduction(np_fun, op, init_val):
def _make_reduction(np_fun, op, init_val, preproc=None):
"""Creates reduction function given a binary operation and monoid identity."""
@_wraps(op)
@_wraps(np_fun)
def reduction(a, axis=None, dtype=None, out=None, keepdims=False):
if out is not None:
raise ValueError("reduction does not support `out` argument.")
a = a if isinstance(a, ndarray) else asarray(a)
a = preproc(a) if preproc else a
dims = _reduction_dims(a, axis)
result_dtype = _dtype(np_fun(onp.ones((), dtype=dtype or _dtype(a))))
if _dtype(a) != result_dtype:
@ -614,7 +614,6 @@ def _make_reduction(np_fun, op, init_val):
return reduction
def _reduction_dims(a, axis):
if axis is None:
return onp.arange(ndim(a))
@ -625,7 +624,6 @@ def _reduction_dims(a, axis):
else:
raise TypeError("Unexpected type of axis argument: {}".format(type(axis)))
def _reduction_init_val(a, init_val):
a_dtype = xla_bridge.canonicalize_dtype(_dtype(a))
try:
@ -635,13 +633,14 @@ def _reduction_init_val(a, init_val):
sign, iinfo = onp.sign(init_val), onp.iinfo(a_dtype)
return onp.array(iinfo.min if sign < 0 else iinfo.max, dtype=a_dtype)
_cast_to_bool = partial(lax.convert_element_type, new_dtype=onp.bool_)
sum = _make_reduction(onp.sum, lax.add, 0)
prod = _make_reduction(onp.prod, lax.mul, 1)
max = _make_reduction(onp.max, lax.max, -onp.inf)
min = _make_reduction(onp.min, lax.min, onp.inf)
all = alltrue = _make_reduction(onp.all, logical_and, True)
any = sometrue = _make_reduction(onp.any, logical_or, False)
all = alltrue = _make_reduction(onp.all, lax.bitwise_and, True, _cast_to_bool)
any = sometrue = _make_reduction(onp.any, lax.bitwise_or, False, _cast_to_bool)
@_wraps(onp.mean)

View File

@ -284,6 +284,13 @@ class BatchingTest(jtu.JaxTestCase):
jacrev(func)(xs) # don't crash
jacfwd(func)(xs) # don't crash
def testAny(self):
# test modeling the code in https://github.com/google/jax/issues/108
ans = vmap(np.any)(np.array([[True, False], [False, False]]))
expected = np.array([True, False])
self.assertAllClose(ans, expected, check_dtypes=True)
if __name__ == '__main__':
absltest.main()

View File

@ -142,8 +142,8 @@ JAX_REDUCER_RECORDS = [
]
JAX_REDUCER_NO_DTYPE_RECORDS = [
op_record("all", 1, bool_dtypes, all_shapes, jtu.rand_default(), []),
op_record("any", 1, bool_dtypes, all_shapes, jtu.rand_default(), []),
op_record("all", 1, default_dtypes + bool_dtypes, all_shapes, jtu.rand_some_zero(), []),
op_record("any", 1, default_dtypes + bool_dtypes, all_shapes, jtu.rand_some_zero(), []),
op_record("max", 1, default_dtypes, nonempty_shapes, jtu.rand_default(), []),
op_record("min", 1, default_dtypes, nonempty_shapes, jtu.rand_default(), []),
]