Mirror of https://github.com/ROCm/jax.git (synced 2025-04-15 19:36:06 +00:00)
Add a JVP rule for the general case of lax.reduce.

Commit: 3fc1fdb148 (parent: 634397dc59)
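As context for the diff below: the commit adds a forward-mode (JVP) rule covering the general, possibly multi-operand form of `lax.reduce`. The following sketch shows the kind of call this enables; the values, shapes, and reducer are illustrative (they mirror the new test added in this commit) rather than taken from the diff itself.

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

# Pair reduction: sum the first operands, multiply the second operands.
def op(xs, ys):
  return (xs[0] + ys[0], xs[1] * ys[1])

init_vals = (np.array(0.0, np.float32), np.array(1.0, np.float32))
reduce_fn = lambda x, y: lax.reduce((x, y), init_vals, op, (0,))

x = jnp.arange(1.0, 7.0).reshape(3, 2)
y = jnp.linspace(0.5, 1.5, 6).reshape(3, 2)

# Forward-mode differentiation through the general reduce, which this
# commit's new JVP rule makes possible.
primals_out, tangents_out = jax.jvp(reduce_fn, (x, y),
                                    (jnp.ones_like(x), jnp.ones_like(y)))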
@@ -13,6 +13,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* New features
  * New profiling APIs: {func}`jax.profiler.start_trace`,
    {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
  * {func}`jax.lax.reduce` is now differentiable.
* Breaking changes:
  * The minimum jaxlib version is now 0.1.64.
  * Some profiler API names have been changed. There are still aliases, so this
@@ -41,6 +42,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
    `jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
    `jax.debug_infs`, `jax.log_compiles`.
  * [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete`

* Bug fixes:
  * [#6136](https://github.com/google/jax/pull/6136) generalized
    `jax.flatten_util.ravel_pytree` to handle integer dtypes.
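The profiling entries in the changelog hunk above refer to the trace APIs; a minimal usage sketch, assuming any writable log directory (the path below is only a placeholder):

import jax
import jax.numpy as jnp

# Collect a trace of this block; "/tmp/jax-trace" is an illustrative path.
with jax.profiler.trace("/tmp/jax-trace"):
  x = jnp.ones((512, 512))
  jnp.dot(x, x).block_until_ready()

# Equivalent explicit form using the start/stop pair.
jax.profiler.start_trace("/tmp/jax-trace")
jnp.dot(x, x).block_until_ready()
jax.profiler.stop_trace()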
@@ -4993,6 +4993,52 @@ def _reduction_computation(c, jaxpr, consts, init_values, singleton=True):
  out_nodes = xops.Tuple(subc, out_nodes)
  return subc.build(out_nodes)

def _reduce_jvp(reducer, init_values, primals, tangents, axes):
  input_shape = np.array(primals[0].shape)

  n = np.prod(input_shape[list(axes)])
  non_axes = np.delete(np.arange(len(input_shape)), axes)

  # Move the reduced axes to the front, and flatten them to 1D.
  permutation = axes + tuple(non_axes)
  new_shape = (n,) + tuple(input_shape[non_axes])
  primals = tuple(reshape(x, new_shape, permutation) for x in primals)
  tangents = tuple(reshape(t, new_shape, permutation) for t in tangents)

  for d in range(len(non_axes) + 1):
    reducer = api.vmap(reducer)
  def _reduce_tree(*xs, axis=0):
    """Reduce by repeatedly splitting the array and multiplying."""
    while xs[0].shape[axis] > 1:
      n = xs[0].shape[axis]
      n1 = (n + 1) // 2
      n2 = n - n1
      xs1 = [slice_in_dim(x, 0, n1) for x in xs]
      xs2 = [slice_in_dim(x, n1, None) for x in xs]
      if n2 != n1:
        paddings = [(0, 0, 0)] * len(xs[0].shape)
        paddings[axis] = (0, 1, 0)
        xs2 = [pad(x2, i, paddings) for x2, i in zip(xs2, init_values)]
      xs = reducer(*(xs1 + xs2))
    if xs[0].shape[axis] == 0:
      return [full(input_shape[non_axes], i) for i in init_values]
    return tuple(squeeze(x, (axis,)) for x in xs)

  return api.jvp(_reduce_tree, primals, tangents)

def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr,
                     consts, dimensions):
  primal_xs, init_values = split_list(primals, [len(primals) // 2])
  tangent_xs, tangent_init = split_list(tangents, [len(tangents) // 2])
  # This test may be too strict, if a value is actually zero but we cannot prove
  # it is symbolically zero.
  if any(type(t) is not ad_util.Zero for t in tangent_init):
    raise NotImplementedError(
      "Gradient of general lax.reduce with non-zero tangents for "
      "initial values to reduction not implemented")
  reducer = core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts))
  return _reduce_jvp(reducer, init_values, primal_xs, tangent_xs, dimensions)

def _masking_defreducer(prim, identity):
  masking.masking_rules[prim] = partial(_reducer_masking_rule, prim, identity)
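The `_reduce_jvp` helper added above linearizes a general reduction by rewriting it as a balanced tree of pairwise reducer applications, padding odd-length halves with the initial values, so that `api.jvp` of that tree yields the JVP. A stripped-down, self-contained illustration of the same idea (not the library code), using addition as the reducer and 0 as its identity:

import jax
import jax.numpy as jnp

def tree_reduce_sum(x):
  # Reduce axis 0 by repeated halving: pad to an even length with the
  # reducer's identity (0 for addition), then combine the halves pairwise.
  while x.shape[0] > 1:
    if x.shape[0] % 2:
      x = jnp.concatenate([x, jnp.zeros((1,) + x.shape[1:], x.dtype)])
    half = x.shape[0] // 2
    x = x[:half] + x[half:]
  return x[0]

x = jnp.arange(5.0)
primal, tangent = jax.jvp(tree_reduce_sum, (x,), (jnp.ones_like(x),))
# primal == 10.0; tangent == 5.0, since each input contributes 1 to the sum.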
@@ -5030,7 +5076,7 @@ reduce_p.def_abstract_eval(
    _reduce_named_shape_rule))
xla.translations[reduce_p] = _reduce_translation_rule
batching.primitive_batchers[reduce_p] = _reduce_batch_rule

ad.primitive_jvps[reduce_p] = _reduce_jvp_rule

def _reduce_number_dtype_rule(name, operand, *args, **kw):
  if not dtypes.issubdtype(operand.dtype, np.number):
@@ -5081,38 +5127,10 @@ def _reduce_prod_translation_rule(c, operand, *, axes):
      xla.primitive_subcomputation(mul_p, scalar, scalar), axes)

def _reduce_prod_jvp_rule(primals, tangents, *, axes):
  operand, = primals
  tangent, = tangents
  input_shape = np.array(operand.shape)

  n = np.prod(input_shape[list(axes)])
  non_axes = np.delete(np.arange(len(input_shape)), axes)

  # Move the reduced axes to the front, and flatten them to 1D.
  permutation = axes + tuple(non_axes)
  new_shape = (n,) + tuple(input_shape[non_axes])
  operand = reshape(operand, new_shape, permutation)
  tangent = reshape(tangent, new_shape, permutation)

  def _reduce_prod_tree(x, axis=0):
    """Reduce by repeatedly splitting the array and multiplying."""
    while x.shape[axis] > 1:
      n = x.shape[axis]
      n1 = (n + 1) // 2
      n2 = n - n1
      x1 = slice_in_dim(x, 0, n1)
      x2 = slice_in_dim(x, n1, None)
      if n2 != n1:
        paddings = [(0, 0, 0)] * len(x.shape)
        paddings[axis] = (0, 1, 0)
        x2 = pad(x2, _const(x, 1), paddings)
      x = x1 * x2
    if x.shape[axis] == 0:
      return full(input_shape[non_axes], _one(x))
    return squeeze(x, (axis,))

  return api.jvp(_reduce_prod_tree, (operand,), (tangent,))

  reducer = lambda x, y: [mul(x, y)]
  primals_out, tangents_out = _reduce_jvp(reducer, [_const(primals[0], 1)],
                                          primals, tangents, axes)
  return primals_out[0], tangents_out[0]

reduce_prod_p = standard_primitive(
  _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),
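In the hunk above, the old product-specific tree reduction inside `_reduce_prod_jvp_rule` (the 38-line body) is replaced by the short 10-line version that calls the new `_reduce_jvp` with a `mul` reducer and 1 as the initial value. One practical property of this tree formulation, both old and new, is that the product JVP never divides by the inputs, so gradients stay finite when an operand contains zeros. A small sketch with illustrative values:

import jax
import jax.numpy as jnp

x = jnp.array([2.0, 0.0, 3.0])

# Gradient of a product containing a zero: each component is the product of
# the other elements, computed without dividing by x.
print(jax.grad(jnp.prod)(x))  # expected: [0., 6., 0.]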
@@ -654,6 +654,30 @@ class LaxAutodiffTest(jtu.JaxTestCase):
    if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
      check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": "_inshape={}_reducedims={}"
       .format(jtu.format_shape_dtype_string(shape, dtype), dims),
       "shape": shape, "dtype": dtype, "dims": dims}
      for dtype in grad_float_dtypes
      for shape, dims in [
          [(3, 4, 5), ()],
          [(3, 4, 5), (0,)],
          [(3, 4, 5), (1, 2)],
          [(3, 4, 5), (0, 2)],
          [(3, 4, 5), (0, 1, 2)],
          [(3, 1), (1,)],
          [(3, 0, 5), (1,)],
      ]))
  def testReducePairGrad(self, shape, dtype, dims):
    rng = jtu.rand_default(self.rng(), scale=1)
    tol = {np.float32: 1e-2, np.float64: 1e-4}
    operands = (rng(shape, dtype), rng(shape, dtype))
    init_vals = (np.array(0, dtype), np.array(1, dtype))
    def op(xs, ys):
      return (xs[0] + ys[0], xs[1] * ys[1])
    reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
    check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                         "_basedilation={}_windowdilation={}")