Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
This reverts commit 99401c5a844cc19c6ce66cc26997f999c9ecf6d8.
parent fe1e1041b6
commit 71461a37f3
103  jax/lax/lax.py
@@ -1020,44 +1020,25 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
   else:
     return transpose_p.bind(operand, permutation=permutation)
 
-def reduce(operand: Union[Array, Sequence[Array]],
-           init_value: Union[Array, Sequence[Array]],
-           computation: Callable,
-           dimensions: Sequence[int]) -> Union[Array, Tuple[Array, ...]]:
+def reduce(operand: Array, init_value: Array, computation: Callable,
+           dimensions: Sequence[int]) -> Array:
   """Wraps XLA's `Reduce
   <https://www.tensorflow.org/xla/operation_semantics#reduce>`_
   operator.
   """
-  return_tuple = isinstance(operand, Sequence)
-  if not return_tuple:
-    operand = (operand,)
-  if not isinstance(init_value, Sequence):
-    init_value = (init_value,)
-  if len(operand) == 0:
-    raise TypeError("reduce requires at least one operand")
-  if len(operand) != len(init_value):
-    raise TypeError("reduce: length of operands tuple must match length of init_values tuple; got "
-                    f"len(operand)={len(operand)}, len(init_value)={len(init_value)}.")
-
-  monoid_reducer = _get_monoid_reducer(computation, init_value[0])
-  if len(operand) == 1 and monoid_reducer:
-    out = (monoid_reducer(operand[0], dimensions),)
+  monoid_reducer = _get_monoid_reducer(computation, init_value)
+  if monoid_reducer:
+    return monoid_reducer(operand, dimensions)
   else:
-    jaxpr, consts = _reduction_jaxpr(computation, *(_abstractify(v) for v in init_value))
-    # TODO(mattjj): handle consts correctly
-    # TODO(mattjj): don't pass computation
-    out = reduce_p.bind(*operand, *init_value, computation=computation,
-                        jaxpr=jaxpr, consts=consts, dimensions=tuple(dimensions))
-  return tuple(out) if return_tuple else out[0]
+    jaxpr, consts = _reduction_jaxpr(computation, _abstractify(init_value))
+    return reduce_p.bind(operand, init_value, computation=computation,
+                         jaxpr=jaxpr, consts=consts, dimensions=tuple(dimensions))
 
 @cache()
-def _reduction_jaxpr(computation, *avals):
-  pvals = tuple(pe.PartialVal.unknown(aval) for aval in avals)
-  if len(pvals) == 1:
-    comp = lu.wrap_init(lambda x, y: (computation(x, y),))
-  else:
-    comp = lu.wrap_init(computation)
-  jaxpr, _, consts = pe.trace_to_jaxpr(comp, 2 * pvals, instantiate=True)
+def _reduction_jaxpr(computation, aval):
+  pval = pe.PartialVal.unknown(aval)
+  comp = lu.wrap_init(lambda x, y: (computation(x, y),))
+  jaxpr, _, consts = pe.trace_to_jaxpr(comp, (pval, pval), instantiate=False)
   return jaxpr, consts
 
 def _get_monoid_reducer(monoid_op: Callable, x: Array) -> Optional[Callable]:
@@ -4097,50 +4078,26 @@ batching.primitive_batchers[scatter_p] = (
   partial(_scatter_batching_rule, scatter))
 
 
-def _reduce_abstract_eval(*args, dimensions, **kwargs):
-  N = len(args) // 2
-  operands, init_values = args[:N], args[N:]
-  if len(operands) != len(init_values):
-    raise TypeError("Expected number of operands to equal number of init_values; "
-                    f"got {len(operands)} and {len(init_values)}")
-  if any(operand.shape != operands[0].shape for operand in operands[1:]):
-    shapes = " ".join(str(operand.shape) for operand in operands)
-    raise TypeError(f"Arguments to reduce must have equal shapes, got: {shapes}")
-  shape = tuple(onp.delete(operands[0].shape, dimensions))
-  return tuple(
-    ShapedArray(shape, dtype=dtypes.canonicalize_dtype(operand.dtype))
-    for operand in operands
-  )
+def _reduce_shape_rule(operand, init_value, *, computation, jaxpr, consts,
+                       dimensions):
+  return tuple(onp.delete(operand.shape, dimensions))
 
-def _reduce_translation_rule(c, *args, computation, jaxpr, consts, dimensions):
-  N = len(args) // 2
-  operands, init_values = args[:N], args[N:]
-  assert len(operands) == len(init_values)
-  shapes = [c.get_shape(v) for v in init_values]
-  axis_env = xla.AxisEnv(1)  # no parallel primitives inside reductions
-  subc = xla_bridge.make_computation_builder("variadic_reduction_computation")
-  assert len(consts) == 0, "Reduction computations cannot have constants"
-  args = [xb.parameter(subc, 2 * i + j, shape)
-          for i, shape in enumerate(shapes) for j in range(2)]
-  out = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
-  xla_computation = subc.build(xops.Tuple(subc, out))
-  return xops.Reduce(c, operands, init_values, xla_computation, dimensions)
+def _reduce_translation_rule(c, operand, init_value, *, computation, jaxpr,
+                             consts, dimensions):
+  xla_computation = _reduction_computation(c, jaxpr, consts, init_value)
+  return xops.Reduce(c, [operand], [init_value], xla_computation, dimensions)
 
 def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr, consts,
                        dimensions):
-  if len(batched_args) != 2:
-    # TODO(jakevdp): implement this after generalizing reduce implementation.
-    raise NotImplementedError("reduce batch rule for more than one array.")
   operand, init_value = batched_args
   operand_bdim, init_value_bdim = batch_dims
-  if init_value_bdim is not None:
-    # TODO(jakevdp): implement this via loop and stack.
-    raise NotImplementedError("batched reduce with different init_val per batch")
-  assert operand_bdim is not None
-  new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
-  new_operand_bdim = operand_bdim - int(onp.sum(onp.less(dimensions, operand_bdim)))
-  out = reduce(operand, init_value, computation, new_dimensions)
-  return (out,), (new_operand_bdim,)
+  if init_value_bdim is None:
+    assert operand_bdim is not None
+    new_dimensions = [d + bool(d >= operand_bdim) for d in dimensions]
+    new_operand_bdim = operand_bdim - int(onp.sum(onp.less(dimensions, operand_bdim)))
+    return reduce(operand, init_value, computation, new_dimensions), new_operand_bdim
+  else:
+    raise NotImplementedError  # loop and stack
 
 def _reduction_computation(c, jaxpr, consts, init_value):
   shape = c.get_shape(init_value)
@@ -4165,12 +4122,8 @@ def _reducer_masking_rule(prim, identity, padded_vals, logical_shapes,
   bind = prim.bind if input_shape is None else partial(prim.bind, input_shape=padded_shape)
   return bind(masked_val, axes=axes)
 
-
-reduce_p = Primitive('reduce')
-reduce_p.multiple_results = True
-reduce_p.def_impl(partial(xla.apply_primitive, reduce_p))
-reduce_p.def_abstract_eval(_reduce_abstract_eval)
-xla.translations[reduce_p] = _reduce_translation_rule
+reduce_p = standard_primitive(_reduce_shape_rule, _input_dtype, 'reduce',
+                              _reduce_translation_rule)
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
 
 
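For context, a minimal usage sketch of the single-operand reduce signature restored above; the array, the 0.0 init value, and the lax.add reducer are illustrative and are not taken from this commit. The remaining hunks below are from the lax test suite, where the testMultiReduce test exercising the removed tuple form is deleted.

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(2, 3)

# Restored signature: reduce(operand, init_value, computation, dimensions).
# Sum over axis 0 with the additive identity 0.0 and lax.add as the reducer;
# this matches the _get_monoid_reducer fast path and lowers to a sum reduction.
col_sums = lax.reduce(x, 0.0, lax.add, (0,))  # shape (3,)

# The variadic form removed by this revert accepted tuples of operands and
# init values and returned a tuple of results, schematically:
#   lax.reduce((a, b), (a_init, b_init), computation, dimensions)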
@@ -32,7 +32,6 @@ from jax import lax
 from jax import test_util as jtu
 from jax import lax_reference
 from jax.test_util import check_grads
-from jax.lax.lax import _get_min_identity, _get_max_identity
 from jax.lib import xla_client
 import jax.util
 
@@ -1222,29 +1221,6 @@ class LaxTest(jtu.JaxTestCase):
     numpy_op = lambda x: lax_reference.transpose(x, perm)
     self._CheckAgainstNumpy(op, numpy_op, args_maker)
 
-  @parameterized.named_parameters(jtu.cases_from_list(
-      {"testcase_name": "_inshape={}_reducedims={}"
-       .format(jtu.format_shape_dtype_string(shape, dtype), dims),
-       "shape": shape, "dtype": dtype, "dims": dims}
-      for dtype in default_dtypes
-      for shape, dims in [
-          [(3, 4, 5), (0,)], [(3, 4, 5), (1, 2)],
-          [(3, 4, 5), (0, 2)], [(3, 4, 5), (0, 1, 2)]
-      ]))
-  def testMultiReduce(self, shape, dtype, dims):
-    rng = jtu.rand_default(self.rng())
-    op = lambda a1, b1, a2, b2: (lax.min(a1, a2), lax.max(b1, b2))
-
-    np_fun = lambda a, b: (a.min(axis=dims), b.max(axis=dims))
-    def jnp_fun(a, b):
-      # device_put here to ensure dtype below is correct.
-      a, b = jax.device_put(a), jax.device_put(b)
-      init_val = (_get_min_identity(a.dtype), _get_max_identity(b.dtype))
-      return lax.reduce((a, b), init_val, op, dims)
-    args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)]
-    self._CompileAndCheck(jnp_fun, args_maker)
-    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
-
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": "_op={}_inshape={}_reducedims={}_initval={}"
        .format(op.__name__, jtu.format_shape_dtype_string(shape, dtype), dims,