mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add decorator for performing broadcasting inside translation rules (#2468)
* Add decorator for broadcasting at the translation rule layer. * Fix broadcasting in igamma gradients. Co-authored-by: Peter Hawkins <phawkins@google.com>
This commit is contained in:
parent
aedf346c8e
commit
969ed8085c
@ -148,7 +148,7 @@ def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
|
||||
ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
|
||||
if not core.skip_checks:
|
||||
ct_aval = core.get_aval(ct_env[v])
|
||||
assert v.aval == core.lattice_join(v.aval, ct_aval)
|
||||
assert v.aval == core.lattice_join(v.aval, ct_aval), (v.aval, ct_aval)
|
||||
|
||||
def read_cotangent(v):
|
||||
return ct_env.get(v, zero)
|
||||
|
@ -175,9 +175,6 @@ def atan2(x: Array, y: Array) -> Array:
|
||||
|
||||
def betainc(a: Array, b: Array, x: Array) -> Array:
|
||||
r"""Elementwise regularized incomplete beta integral."""
|
||||
a = _brcast(_brcast(a, b), x)
|
||||
b = _brcast(b, a)
|
||||
x = _brcast(x, a)
|
||||
return regularized_incomplete_beta_p.bind(a, b, x)
|
||||
|
||||
def lgamma(x: Array) -> Array:
|
||||
@ -190,15 +187,15 @@ def digamma(x: Array) -> Array:
|
||||
|
||||
def igamma(a: Array, x: Array) -> Array:
|
||||
r"""Elementwise regularized incomplete gamma function."""
|
||||
return igamma_p.bind(_brcast(a, x), _brcast(x, a))
|
||||
return igamma_p.bind(a, x)
|
||||
|
||||
def igammac(a: Array, x: Array) -> Array:
|
||||
r"""Elementwise complementary regularized incomplete gamma function."""
|
||||
return igammac_p.bind(_brcast(a, x), _brcast(x, a))
|
||||
return igammac_p.bind(a, x)
|
||||
|
||||
def igamma_grad_a(a: Array, x: Array) -> Array:
|
||||
r"""Elementwise derivative of the regularized incomplete gamma function."""
|
||||
return igamma_grad_a_p.bind(_brcast(a, x), _brcast(x, a))
|
||||
return igamma_grad_a_p.bind(a, x)
|
||||
|
||||
def bessel_i0e(x: Array) -> Array:
|
||||
r"""Exponentially scaled modified Bessel function of order 0:
|
||||
@ -1784,6 +1781,26 @@ def naryop(result_dtype, accepted_dtypes, name, translation_rule=None):
|
||||
standard_naryop = partial(naryop, _input_dtype)
|
||||
|
||||
|
||||
def _broadcast_translate(translate: Callable):
  """Wrap a translation rule so its positional operands are broadcast first.

  This is necessary only for a handful of primitives whose XLA
  implementations do not support broadcasting.

  Args:
    translate: the translation rule to wrap; it is invoked as
      ``translate(c, *operands, **kwargs)`` with the broadcast operands.

  Returns:
    A translation rule with the same calling convention as ``translate``.
  """
  def _expand(operand, operand_shape, target_shape):
    # Operands that already have the target shape pass through untouched.
    if operand_shape == target_shape:
      return operand
    # Map the operand's axes onto the trailing axes of the target shape.
    rank_gap = len(target_shape) - len(operand_shape)
    trailing_dims = tuple(range(rank_gap, len(target_shape)))
    return xops.BroadcastInDim(operand, target_shape, trailing_dims)

  def _broadcasted_translation_rule(c, *args, **kwargs):
    operand_shapes = [c.GetShape(arg).dimensions() for arg in args]
    target_shape = broadcast_shapes(*operand_shapes)
    broadcasted = []
    for operand, operand_shape in zip(args, operand_shapes):
      broadcasted.append(_expand(operand, operand_shape, target_shape))
    return translate(c, *broadcasted, **kwargs)

  return _broadcasted_translation_rule
|
||||
|
||||
# NOTE(mattjj): this isn't great for orders higher than first because it means JVPs
|
||||
# get two extra ops in them: a reshape and a broadcast_in_dim (or sometimes just
|
||||
# a broadcast). but saving the shape info with the primitives isn't great either
|
||||
@ -1901,7 +1918,9 @@ ad.defjvp(atanh_p,
|
||||
lambda g, x: mul(g, reciprocal((_one(x) - x) * (_one(x) + x))))
|
||||
|
||||
regularized_incomplete_beta_p = standard_naryop(
|
||||
[_float, _float, _float], 'regularized_incomplete_beta')
|
||||
[_float, _float, _float], 'regularized_incomplete_beta',
|
||||
translation_rule=_broadcast_translate(
|
||||
partial(standard_translate, 'regularized_incomplete_beta')))
|
||||
|
||||
def betainc_gradx(g, a, b, x):
|
||||
lbeta = lgamma(a) + lgamma(b) - lgamma(a + b)
|
||||
@ -1922,18 +1941,24 @@ ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
|
||||
|
||||
digamma_p = standard_unop(_float, 'digamma')
|
||||
|
||||
igamma_p = standard_naryop([_float, _float], 'igamma')
|
||||
igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a')
|
||||
igamma_p = standard_naryop(
|
||||
[_float, _float], 'igamma',
|
||||
translation_rule=_broadcast_translate(partial(standard_translate, 'igamma')))
|
||||
igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a',
|
||||
translation_rule=_broadcast_translate(partial(standard_translate,
|
||||
'igamma_grad_a')))
|
||||
|
||||
def igamma_gradx(g, a, x):
|
||||
return g * exp(-x + (a - _ones(a)) * log(x) - lgamma(a))
|
||||
return _brcast(g, a, x) * exp(-x + (a - _ones(a)) * log(x) - lgamma(a))
|
||||
|
||||
def igamma_grada(g, a, x):
|
||||
return g * igamma_grad_a(a, x)
|
||||
return _brcast(g, a, x) * igamma_grad_a(a, x)
|
||||
|
||||
ad.defjvp(igamma_p, igamma_grada, igamma_gradx)
|
||||
|
||||
igammac_p = standard_naryop([_float, _float], 'igammac')
|
||||
igammac_p = standard_naryop(
|
||||
[_float, _float], 'igammac',
|
||||
translation_rule=_broadcast_translate(partial(standard_translate, 'igammac')))
|
||||
|
||||
def igammac_gradx(g, a, x):
|
||||
return -igamma_gradx(g, a, x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user