Add a JVP rule for the general case of lax.reduce.

parent 634397dc59
commit 3fc1fdb148
@@ -13,6 +13,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
 * New features
   * New profiling APIs: {func}`jax.profiler.start_trace`,
     {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
+  * {func}`jax.lax.reduce` is now differentiable.
 * Breaking changes:
   * The minimum jaxlib version is now 0.1.64.
   * Some profiler APIs names have been changed. There are still aliases, so this

@@ -41,6 +42,7 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
     `jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
     `jax.debug_infs`, `jax.log_compiles`.
   * [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete`
+
 * Bug fixes:
   * [#6136](https://github.com/google/jax/pull/6136) generalized
     `jax.flatten_util.ravel_pytree` to handle integer dtypes.
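The changelog entry above is the user-visible effect: `lax.reduce` with a custom reduction computation can now be differentiated. A minimal sketch, assuming a jax build that includes this change; `reduce_custom` is an illustrative name, and the reducer x (+) y = x + y + x*y is chosen because it is associative with identity 0 yet is not one of the monoids `lax.reduce` special-cases, so it should exercise the new general JVP rule rather than the pre-existing reduce_sum/reduce_prod paths.

import jax
import jax.numpy as jnp
from jax import lax

# Custom associative reducer: x (+) y = x + y + x*y, identity 0.
# Reducing [x1, x2, x3] this way computes (1+x1)(1+x2)(1+x3) - 1.
def reduce_custom(x):
  return lax.reduce(x, jnp.zeros((), x.dtype), lambda a, b: a + b + a * b, (0,))

x = jnp.array([1.0, 2.0, 3.0])
print(reduce_custom(x))            # expected: 23.0
print(jax.grad(reduce_custom)(x))  # expected: [12. 8. 6.], the product of (1+x_j) over j != i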
@@ -4993,6 +4993,52 @@ def _reduction_computation(c, jaxpr, consts, init_values, singleton=True):
     out_nodes = xops.Tuple(subc, out_nodes)
   return subc.build(out_nodes)
 
+def _reduce_jvp(reducer, init_values, primals, tangents, axes):
+  input_shape = np.array(primals[0].shape)
+
+  n = np.prod(input_shape[list(axes)])
+  non_axes = np.delete(np.arange(len(input_shape)), axes)
+
+  # Move the reduced axes to the front, and flatten them to 1D.
+  permutation = axes + tuple(non_axes)
+  new_shape = (n,) + tuple(input_shape[non_axes])
+  primals = tuple(reshape(x, new_shape, permutation) for x in primals)
+  tangents = tuple(reshape(t, new_shape, permutation) for t in tangents)
+
+  for d in range(len(non_axes) + 1):
+    reducer = api.vmap(reducer)
+  def _reduce_tree(*xs, axis=0):
+    """Reduce by repeatedly splitting the array and multiplying."""
+    while xs[0].shape[axis] > 1:
+      n = xs[0].shape[axis]
+      n1 = (n + 1) // 2
+      n2 = n - n1
+      xs1 = [slice_in_dim(x, 0, n1) for x in xs]
+      xs2 = [slice_in_dim(x, n1, None) for x in xs]
+      if n2 != n1:
+        paddings = [(0, 0, 0)] * len(xs[0].shape)
+        paddings[axis] = (0, 1, 0)
+        xs2 = [pad(x2, i, paddings) for x2, i in zip(xs2, init_values)]
+      xs = reducer(*(xs1 + xs2))
+    if xs[0].shape[axis] == 0:
+      return [full(input_shape[non_axes], i) for i in init_values]
+    return tuple(squeeze(x, (axis,)) for x in xs)
+
+  return api.jvp(_reduce_tree, primals, tangents)
+
+def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr,
+                     consts, dimensions):
+  primal_xs, init_values = split_list(primals, [len(primals) // 2])
+  tangent_xs, tangent_init = split_list(tangents, [len(tangents) // 2])
+  # This test may be too strict, if a value is actually zero but we cannot prove
+  # it is symbolically zero.
+  if any(type(t) is not ad_util.Zero for t in tangent_init):
+    raise NotImplementedError(
+      "Gradient of general lax.reduce with non-zero tangents for "
+      "initial values to reduction not implemented")
+  reducer = core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts))
+  return _reduce_jvp(reducer, init_values, primal_xs, tangent_xs, dimensions)
+
 def _masking_defreducer(prim, identity):
   masking.masking_rules[prim] = partial(_reducer_masking_rule, prim, identity)
 
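The helper above is the heart of the change. A reduction with an arbitrary computation has no closed-form derivative rule, so `_reduce_jvp` re-expresses the reduction as roughly log2(n) elementwise applications of the vmapped binary reducer: the reduced axes are flattened to the front, the array is repeatedly split in half along that axis (padding the shorter half with the init values when the length is odd), and the halves are combined; `api.jvp` then differentiates that re-expressed function, which is built entirely from primitives JAX already knows how to differentiate. Below is a standalone sketch of the same pairwise-halving idea for a single 1-D operand; `tree_reduce_1d` and the surrounding names are illustrative, not part of the library.

import jax
import jax.numpy as jnp

def tree_reduce_1d(x, combine, identity):
  # Repeatedly split the array in half and combine the halves elementwise,
  # padding the shorter half with the identity when the length is odd.
  while x.shape[0] > 1:
    n1 = (x.shape[0] + 1) // 2
    x1, x2 = x[:n1], x[n1:]
    if x2.shape[0] < n1:
      x2 = jnp.concatenate([x2, jnp.full((n1 - x2.shape[0],), identity, x.dtype)])
    x = combine(x1, x2)   # one level of the reduction tree
  return x[0]

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
t = jnp.ones_like(x)
# Differentiate the re-expressed reduction with jax.jvp; with `+` as the
# combiner this should agree with the JVP of jnp.sum.
primal, tangent = jax.jvp(lambda v: tree_reduce_1d(v, jnp.add, 0.0), (x,), (t,))
print(primal, tangent)   # expected: 15.0 5.0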
@@ -5030,7 +5076,7 @@ reduce_p.def_abstract_eval(
     _reduce_named_shape_rule))
 xla.translations[reduce_p] = _reduce_translation_rule
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
+ad.primitive_jvps[reduce_p] = _reduce_jvp_rule
 
 def _reduce_number_dtype_rule(name, operand, *args, **kw):
   if not dtypes.issubdtype(operand.dtype, np.number):
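This hunk wires the new rule into forward-mode AD: `ad.primitive_jvps` is the table the JVP interpreter consults when it encounters a primitive. As a rough sketch of the same registration pattern, here is a hypothetical `toy_square` primitive; `ad.primitive_jvps` is the same table the diff writes to, while the `core.Primitive` machinery (`def_impl`, `def_abstract_eval`, `bind`) is standard jax.core API used here from memory, so treat this as illustrative only. It omits everything the real `reduce_p` rule has to handle (multiple operands, reduction parameters, symbolic-zero tangents).

import jax
from jax import core
from jax.interpreters import ad

# Hypothetical primitive used only to illustrate the registration pattern.
toy_square_p = core.Primitive("toy_square")
toy_square_p.def_impl(lambda x: x * x)
toy_square_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

def _toy_square_jvp(primals, tangents):
  x, = primals
  t, = tangents
  # d(x**2) = 2 * x * dx
  return toy_square_p.bind(x), 2 * x * t

ad.primitive_jvps[toy_square_p] = _toy_square_jvp

print(jax.jvp(toy_square_p.bind, (3.0,), (1.0,)))  # expected: (9.0, 6.0)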
@@ -5081,38 +5127,10 @@ def _reduce_prod_translation_rule(c, operand, *, axes):
                      xla.primitive_subcomputation(mul_p, scalar, scalar), axes)
 
 def _reduce_prod_jvp_rule(primals, tangents, *, axes):
-  operand, = primals
-  tangent, = tangents
-  input_shape = np.array(operand.shape)
-
-  n = np.prod(input_shape[list(axes)])
-  non_axes = np.delete(np.arange(len(input_shape)), axes)
-
-  # Move the reduced axes to the front, and flatten them to 1D.
-  permutation = axes + tuple(non_axes)
-  new_shape = (n,) + tuple(input_shape[non_axes])
-  operand = reshape(operand, new_shape, permutation)
-  tangent = reshape(tangent, new_shape, permutation)
-
-  def _reduce_prod_tree(x, axis=0):
-    """Reduce by repeatedly splitting the array and multiplying."""
-    while x.shape[axis] > 1:
-      n = x.shape[axis]
-      n1 = (n + 1) // 2
-      n2 = n - n1
-      x1 = slice_in_dim(x, 0, n1)
-      x2 = slice_in_dim(x, n1, None)
-      if n2 != n1:
-        paddings = [(0, 0, 0)] * len(x.shape)
-        paddings[axis] = (0, 1, 0)
-        x2 = pad(x2, _const(x, 1), paddings)
-      x = x1 * x2
-    if x.shape[axis] == 0:
-      return full(input_shape[non_axes], _one(x))
-    return squeeze(x, (axis,))
-
-  return api.jvp(_reduce_prod_tree, (operand,), (tangent,))
+  reducer = lambda x, y: [mul(x, y)]
+  primals_out, tangents_out = _reduce_jvp(reducer, [_const(primals[0], 1)],
+                                          primals, tangents, axes)
+  return primals_out[0], tangents_out[0]
 
 reduce_prod_p = standard_primitive(
   _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),
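With the general helper available, the product-specific JVP above collapses to a call to `_reduce_jvp` with multiplication as the binary reducer and 1 as the identity. A quick sanity check of the behavior this tree-of-partial-products formulation preserves: the gradient of a product is the exact cofactor even when the operand contains zeros, where a naive prod(x)/x rule would produce NaNs. Since `jnp.prod` reduces with multiplication, it is served by the rule above; the expected output below follows from the cofactor formula.

import jax
import jax.numpy as jnp

x = jnp.array([2.0, 0.0, 3.0])
# Gradient of prod(x): each entry is the product of all *other* entries,
# which stays finite even though prod(x) == 0 here.
print(jax.grad(jnp.prod)(x))   # expected: [0. 6. 0.]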
@@ -654,6 +654,30 @@ class LaxAutodiffTest(jtu.JaxTestCase):
       if op not in (lax.max, lax.min) or all(d > 0 for d in shape):
         check_grads(reduce, (operand,), 2, ["fwd", "rev"], tol, tol, eps)
 
+  @parameterized.named_parameters(jtu.cases_from_list(
+      {"testcase_name": "_inshape={}_reducedims={}"
+       .format(jtu.format_shape_dtype_string(shape, dtype), dims),
+       "shape": shape, "dtype": dtype, "dims": dims}
+      for dtype in grad_float_dtypes
+      for shape, dims in [
+          [(3, 4, 5), ()],
+          [(3, 4, 5), (0,)],
+          [(3, 4, 5), (1, 2)],
+          [(3, 4, 5), (0, 2)],
+          [(3, 4, 5), (0, 1, 2)],
+          [(3, 1), (1,)],
+          [(3, 0, 5), (1,)],
+      ]))
+  def testReducePairGrad(self, shape, dtype, dims):
+    rng = jtu.rand_default(self.rng(), scale=1)
+    tol = {np.float32: 1e-2, np.float64: 1e-4}
+    operands = (rng(shape, dtype), rng(shape, dtype))
+    init_vals = (np.array(0, dtype), np.array(1, dtype))
+    def op(xs, ys):
+      return (xs[0] + ys[0], xs[1] * ys[1])
+    reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
+    check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)
+
   @parameterized.named_parameters(jtu.cases_from_list(
       {"testcase_name": ("_op={}_shape={}_dims={}_strides={}_padding={}"
                          "_basedilation={}_windowdilation={}")
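The new test exercises `lax.reduce` with a pair of operands and a pair of init values: a single computation sums the first operand and multiplies the second, and `check_grads` verifies forward- and reverse-mode derivatives to second order. A small sketch of the same call pattern outside the test harness, assuming a jax build that includes this change (`sum_and_prod` is an illustrative name):

import jax
import jax.numpy as jnp
from jax import lax

def sum_and_prod(xs, ys):
  # The computation combines an accumulator pair with a value pair:
  # add in the first slot, multiply in the second.
  def op(acc, val):
    return (acc[0] + val[0], acc[1] * val[1])
  init = (jnp.zeros((), xs.dtype), jnp.ones((), ys.dtype))
  return lax.reduce((xs, ys), init, op, (0,))

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([4.0, 5.0, 6.0])
print(sum_and_prod(xs, ys))   # expected: (6.0, 120.0)
(p_sum, p_prod), (t_sum, t_prod) = jax.jvp(
    sum_and_prod, (xs, ys), (jnp.ones_like(xs), jnp.ones_like(ys)))
print(t_sum, t_prod)          # expected: 3.0 and 5*6 + 4*6 + 4*5 = 74.0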