mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add default jvp and transpose rule for jax.lax.reduce_precision.
PiperOrigin-RevId: 564536160
This commit is contained in:
parent
997b35e1d9
commit
d4adf0095f
@ -3991,6 +3991,7 @@ reduce_precision_p = standard_primitive(
|
||||
_reduce_precision_shape_rule,
|
||||
partial(unop_dtype_rule, _identity, _float, 'reduce_precision'),
|
||||
name='reduce_precision')
|
||||
ad.deflinear(reduce_precision_p, lambda t, **kwargs: [reduce_precision_p.bind(t, **kwargs)])
|
||||
batching.defvectorized(reduce_precision_p)
|
||||
|
||||
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
|
||||
|
@ -1978,6 +1978,14 @@ class LaxTest(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, fun, args_maker)
|
||||
self._CompileAndCheck(fun, args_maker)
|
||||
|
||||
def testReducePrecisionGrad(self):
|
||||
info = dtypes.finfo(jnp.dtype('bfloat16'))
|
||||
y, f_vjp = jax.vjp(lambda x: lax.reduce_precision(x, info.nexp, info.nmant), jnp.pi)
|
||||
y2 = f_vjp(jnp.pi)
|
||||
y3 = lax.reduce_precision(jnp.pi, info.nexp, info.nmant)
|
||||
self.assertArraysEqual(y, y2)
|
||||
self.assertArraysEqual(y, y3)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shape=shape, axis=axis)
|
||||
for shape in [(5,), (5, 7)] for axis in [-1, len(shape) - 1]],
|
||||
|
Loading…
x
Reference in New Issue
Block a user