mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
parent 107689e91f
commit 141fabbbf5
@@ -26,6 +26,8 @@ Operators
abs
add
acos
argmax
argmin
asin
atan
atan2
@@ -721,6 +721,14 @@ tf_impl[lax.reduce_min_p] = (
tf_impl[lax.reduce_or_p] = axes_to_axis(tf.reduce_any)
tf_impl[lax.reduce_and_p] = axes_to_axis(tf.reduce_all)

def _argminmax(fn, operand, axes, index_dtype):
  axis, = axes
  # TODO(phawkins): handle axes larger than 2^31.
  return fn(operand, axis=axis, output_type=tf.dtypes.as_dtype(index_dtype))

tf_impl[lax.argmin_p] = functools.partial(_argminmax, tf.math.argmin)
tf_impl[lax.argmax_p] = functools.partial(_argminmax, tf.math.argmax)
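For reference, tf.math.argmin and tf.math.argmax take an output_type argument, which is how the converter honours the index_dtype requested on the JAX side. A small standalone illustration with made-up input values:

import tensorflow as tf

x = tf.constant([[3., 1., 4.], [1., 5., 9.]])
# The default index dtype is int64; output_type lets the lowering match the
# index_dtype JAX asked for, here int32.
tf.math.argmax(x, axis=1, output_type=tf.int32)  # [2, 2]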


_add_fn = tf.function(tf.math.add)
_ge_fn = tf.function(tf.math.greater_equal)
@@ -34,6 +34,10 @@ from .lax import (
  after_all,
  after_all_p,
  and_p,
  argmax,
  argmax_p,
  argmin,
  argmin_p,
  asin,
  asinh,
  asinh_p,
@@ -1028,6 +1028,18 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
  else:
    return transpose_p.bind(operand, permutation=permutation)

def argmin(operand: Array, axis: int,
           index_dtype: DType) -> Tuple[Array, Array]:
  """Computes the index of the minimum element along ``axis``."""
  return argmin_p.bind(operand, axes=(axis,),
                       index_dtype=dtypes.canonicalize_dtype(index_dtype))

def argmax(operand: Array, axis: int,
           index_dtype: DType) -> Tuple[Array, Array]:
  """Computes the index of the maximum element along ``axis``."""
  return argmax_p.bind(operand, axes=(axis,),
                       index_dtype=dtypes.canonicalize_dtype(index_dtype))
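A quick usage sketch of the new lax-level API; the array values are illustrative:

import jax.numpy as jnp
from jax import lax

x = jnp.array([[3., 1., 4.],
               [1., 5., 9.]])
lax.argmin(x, axis=1, index_dtype=jnp.int32)  # [1, 0]
lax.argmax(x, axis=0, index_dtype=jnp.int32)  # [0, 1, 1]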

def reduce(operand: Array, init_value: Array, computation: Callable,
           dimensions: Sequence[int]) -> Array:
  """Wraps XLA's `Reduce
@@ -4294,6 +4306,73 @@ _masking_defreducer(reduce_min_p,
                    lambda shape, dtype: onp.broadcast_to(onp.array(onp.inf, dtype), shape))


def _argminmax_shape_rule(operand, *, axes, index_dtype):
  axis, = axes
  return tuple(onp.delete(operand.shape, axis))

def _argminmax_dtype_rule(operand, *, axes, index_dtype):
  return index_dtype

def _argminmax_translation_rule(value_comparator, identity,
                                c, operand, *, axes, index_dtype):
  axis, = axes
  shape = c.get_shape(operand)
  dtype = shape.numpy_dtype()

  subc = xb.make_computation_builder("argminmax_comparator")
  value_shape = xc.Shape.array_shape(shape.xla_element_type(), ())
  index_shape = xc.Shape.array_shape(index_dtype, ())
  x_value = xb.parameter(subc, 0, value_shape)
  x_index = xb.parameter(subc, 1, index_shape)
  y_value = xb.parameter(subc, 2, value_shape)
  y_index = xb.parameter(subc, 3, index_shape)
  which_value = value_comparator(x_value, y_value)
  which_index = xops.Or(which_value, xops.And(xops.Eq(x_value, y_value),
                                              xops.Lt(x_index, y_index)))
  xops.Tuple(subc, [xops.Select(which_value, x_value, y_value),
                    xops.Select(which_index, x_index, y_index)])
  comparator = subc.build()

  iota_shape = xc.Shape.array_shape(index_dtype, shape.dimensions())
  iota = xc.ops.Iota(c, iota_shape, axis)
  out = xops.Reduce(
    c, [operand, iota],
    [xb.constant(c, identity(dtype)),
     xb.constant(c, onp.array(0, index_dtype))], comparator, [axis])
  return xops.GetTupleElement(out, 1)

def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
  axis, = axes
  idxs = tie_in(a, broadcasted_iota(index_dtype, a.shape, axis))
  maxval = onp.array(dtypes.iinfo(index_dtype).max, dtype=index_dtype)
  maxval = broadcast(tie_in(a, maxval), a.shape)
  mask_idxs = select(eq(a, expand_dims(op(a, (axis,)), (axis,))), idxs,
                     maxval)
  return _reduce_min(mask_idxs, (axis,))
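The GPU path keeps a mask-and-reduce-min lowering via xla.lower_fun rather than the variadic Reduce. A rough jax.numpy rendering of the same masking idea, purely for illustration; the function name and the use of jnp ops are not the library code:

import jax.numpy as jnp

def argmax_by_masking(a, axis):
  # Index of each position along `axis`, shaped for broadcasting against `a`.
  shape = [1] * a.ndim
  shape[axis] = a.shape[axis]
  idxs = jnp.arange(a.shape[axis]).reshape(shape)
  # Keep an index wherever the maximum is attained, poison every other slot
  # with the largest representable index, then reduce with min; this also
  # breaks ties toward the lowest index.
  hits = a == jnp.max(a, axis=axis, keepdims=True)
  masked = jnp.where(hits, idxs, jnp.iinfo(jnp.int32).max)
  return jnp.min(masked, axis=axis)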

_argmin_translation_rule = partial(_argminmax_translation_rule, xops.Lt,
                                   _get_min_identity)
_argmax_translation_rule = partial(_argminmax_translation_rule, xops.Gt,
                                   _get_max_identity)

argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
                              'argmin', _argmin_translation_rule)
batching.defreducer(argmin_p)
ad.defjvp_zero(argmin_p)
xla.backend_specific_translations['gpu'][argmin_p] = xla.lower_fun(
    partial(_argminmax_gpu_translation_rule, _reduce_min),
    multiple_results=False)

argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
                              'argmax', _argmax_translation_rule)
batching.defreducer(argmax_p)
ad.defjvp_zero(argmax_p)
xla.backend_specific_translations['gpu'][argmax_p] = xla.lower_fun(
    partial(_argminmax_gpu_translation_rule, _reduce_max),
    multiple_results=False)


def _reduce_logical_shape_rule(operand, *, axes):
  if operand.dtype != onp.bool_:
    msg = "logical reduction requires operand dtype bool, got {}."
@@ -3094,7 +3094,18 @@ def argmax(a, axis=None):
  if axis is None:
    a = ravel(a)
    axis = 0
  return _argminmax("argmax", max, a, axis)
  if a.shape[axis] == 0:
    raise ValueError("attempt to get argmax of an empty sequence")
  return lax.argmax(a, _canonicalize_axis(axis, a.ndim), int64)

@_wraps(np.argmin)
def argmin(a, axis=None):
  if axis is None:
    a = ravel(a)
    axis = 0
  if a.shape[axis] == 0:
    raise ValueError("attempt to get argmin of an empty sequence")
  return lax.argmin(a, _canonicalize_axis(axis, a.ndim), int64)
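A quick jax.numpy-level example of the rewritten entry points; the values are illustrative:

import jax.numpy as jnp

a = jnp.array([[1, 9, 3],
               [7, 2, 9]])
jnp.argmax(a)          # 1, flattened index; ties resolve to the first occurrence
jnp.argmax(a, axis=1)  # [1, 2]
jnp.argmin(a, axis=0)  # [0, 1, 0]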


_NANARG_DOC = """\
@@ -3111,15 +3122,6 @@ def nanargmax(a, axis=None):
  res = argmax(a, axis=axis)
  return where(all(nan_mask, axis=axis), -1, res)


@_wraps(np.argmin)
def argmin(a, axis=None):
  if axis is None:
    a = ravel(a)
    axis = 0
  return _argminmax("argmin", min, a, axis)


@_wraps(np.nanargmin, lax_description=_NANARG_DOC.format("min"))
def nanargmin(a, axis=None):
  if not issubdtype(_dtype(a), inexact):
@@ -3130,19 +3132,6 @@ def nanargmin(a, axis=None):
  return where(all(nan_mask, axis=axis), -1, res)


# TODO(mattjj): redo this lowering with a call to variadic lax.reduce
def _argminmax(name, op, a, axis):
  if a.shape[axis] == 0:
    raise ValueError("attempt to get {} of an empty sequence".format(name))
  shape = [1] * a.ndim
  shape[axis] = a.shape[axis]
  idxs = lax.tie_in(a, arange(a.shape[axis])).reshape(shape)
  maxval = iinfo(dtypes.canonicalize_dtype(idxs.dtype)).max
  maxval = lax.tie_in(a, maxval)
  mask_idxs = where(lax._eq_meet(a, op(a, axis, keepdims=True)), idxs, maxval)
  return min(mask_idxs, axis)


@_wraps(np.sort)
def sort(a, axis=-1, kind='quicksort', order=None):
  if kind != 'quicksort':
@@ -306,10 +306,10 @@ JAX_REDUCER_NO_DTYPE_RECORDS = [
]

JAX_ARGMINMAX_RECORDS = [
    op_record("argmin", 1, all_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
    op_record("argmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
    op_record("nanargmin", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
    op_record("nanargmax", 1, all_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
    op_record("argmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
    op_record("argmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
    op_record("nanargmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
    op_record("nanargmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_nan, []),
]

JAX_OPERATOR_OVERLOADS = [
@@ -482,6 +482,22 @@ class LaxVmapTest(jtu.JaxTestCase):
    fun = lambda operand: lax.reduce(operand, init_val, op, dims)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_inshape={}_reducedims={}_bdims={}"
       .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dim,
               bdims),
       "op": op, "shape": shape, "dtype": dtype,
       "dim": dim, "bdims": bdims}
      for op in [lax.argmin, lax.argmax]
      for dtype in default_dtypes
      for shape in [(3, 4, 5)]
      for dim in range(len(shape))
      for bdims in all_bdims(shape)))
  def testArgminmax(self, op, shape, dtype, dim, bdims):
    rng = jtu.rand_default(self.rng())
    fun = lambda operand: op(operand, dim, onp.int32)
    self._CheckBatching(fun, 5, bdims, (shape,), (dtype,), rng)
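This test exercises the batching.defreducer registrations added above: the new primitives compose with vmap. A small sketch of the pattern, with made-up shapes:

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(24.).reshape(2, 3, 4)
# vmap over the leading batch axis; each call reduces axis 0 of a (3, 4)
# slice, so the batched result has shape (2, 4).
jax.vmap(lambda m: lax.argmax(m, 0, jnp.int32))(x).shape  # (2, 4)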

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_op={}_dtype={}_padding={}"
       .format(op.__name__, onp.dtype(dtype).name, padding),