Add default jvp and transpose rule for jax.lax.reduce_precision.

PiperOrigin-RevId: 564536160
This commit is contained in:
Qiao Zhang 2023-09-11 16:35:00 -07:00 committed by jax authors
parent 997b35e1d9
commit d4adf0095f
2 changed files with 9 additions and 0 deletions

View File

@ -3991,6 +3991,7 @@ reduce_precision_p = standard_primitive(
_reduce_precision_shape_rule,
partial(unop_dtype_rule, _identity, _float, 'reduce_precision'),
name='reduce_precision')
ad.deflinear(reduce_precision_p, lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)])
batching.defvectorized(reduce_precision_p)
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):

View File

@ -1978,6 +1978,14 @@ class LaxTest(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, fun, args_maker)
self._CompileAndCheck(fun, args_maker)
def testReducePrecisionGrad(self):
info = dtypes.finfo(jnp.dtype('bfloat16'))
y, f_vjp = jax.vjp(lambda x: lax.reduce_precision(x, info.nexp, info.nmant), jnp.pi)
y2 = f_vjp(jnp.pi)
y3 = lax.reduce_precision(jnp.pi, info.nexp, info.nmant)
self.assertArraysEqual(y, y2)
self.assertArraysEqual(y, y3)
@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in [(5,), (5, 7)] for axis in [-1, len(shape) - 1]],