Add decorator for performing broadcasting inside translation rules (#2468)

* Add decorator for broadcasting at the translation rule layer.

* Fix broadcasting in igamma gradients.

Co-authored-by: Peter Hawkins <phawkins@google.com>
commit 969ed8085c
parent aedf346c8e
Author: notEvil
Date:   2020-05-06 16:15:17 +02:00 (committed by GitHub)
2 changed files with 38 additions and 13 deletions
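
For context, a minimal, illustrative sketch (not part of the commit) of the call pattern this change targets, assuming a JAX build that includes it; the shapes and names below are arbitrary. The primitives are now bound on unbroadcast operands, the new translation-rule decorator inserts the BroadcastInDim ops that XLA's non-broadcasting implementations need, and the fixed igamma JVP rules broadcast the tangent as well.

import jax.numpy as jnp
from jax import grad, jit, lax

a = jnp.full((3, 1), 2.0)                      # shape (3, 1)
x = jnp.linspace(0.5, 2.0, 4).reshape(1, 4)    # shape (1, 4)

# igamma_p is bound on the unbroadcast operands; the translation rule
# broadcasts both to (3, 4) before emitting XLA's Igamma op.
print(lax.igamma(a, x).shape)                  # (3, 4)

# The corrected JVP rules broadcast the tangent too, so gradients of
# mixed-shape operands work (here w.r.t. `a`).
g = jit(grad(lambda a_: jnp.sum(lax.igamma(a_, x))))(a)
print(g.shape)                                 # (3, 1)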

@@ -148,7 +148,7 @@ def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
       ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
       if not core.skip_checks:
         ct_aval = core.get_aval(ct_env[v])
-        assert v.aval == core.lattice_join(v.aval, ct_aval)
+        assert v.aval == core.lattice_join(v.aval, ct_aval), (v.aval, ct_aval)

   def read_cotangent(v):
     return ct_env.get(v, zero)

@@ -175,9 +175,6 @@ def atan2(x: Array, y: Array) -> Array:

 def betainc(a: Array, b: Array, x: Array) -> Array:
   r"""Elementwise regularized incomplete beta integral."""
-  a = _brcast(_brcast(a, b), x)
-  b = _brcast(b, a)
-  x = _brcast(x, a)
   return regularized_incomplete_beta_p.bind(a, b, x)

 def lgamma(x: Array) -> Array:
@@ -190,15 +187,15 @@ def digamma(x: Array) -> Array:

 def igamma(a: Array, x: Array) -> Array:
   r"""Elementwise regularized incomplete gamma function."""
-  return igamma_p.bind(_brcast(a, x), _brcast(x, a))
+  return igamma_p.bind(a, x)

 def igammac(a: Array, x: Array) -> Array:
   r"""Elementwise complementary regularized incomplete gamma function."""
-  return igammac_p.bind(_brcast(a, x), _brcast(x, a))
+  return igammac_p.bind(a, x)

 def igamma_grad_a(a: Array, x: Array) -> Array:
   r"""Elementwise derivative of the regularized incomplete gamma function."""
-  return igamma_grad_a_p.bind(_brcast(a, x), _brcast(x, a))
+  return igamma_grad_a_p.bind(a, x)

 def bessel_i0e(x: Array) -> Array:
   r"""Exponentially scaled modified Bessel function of order 0:
@@ -1784,6 +1781,26 @@ def naryop(result_dtype, accepted_dtypes, name, translation_rule=None):

 standard_naryop = partial(naryop, _input_dtype)

+def _broadcast_translate(translate: Callable):
+  # Decorator for translation rules which adds explicit broadcasting of
+  # positional arguments. This is necessary only for a handful of primitives
+  # whose XLA implementations do not support broadcasting.
+  def _broadcast_array(array, array_shape, result_shape):
+    if array_shape == result_shape:
+      return array
+    bcast_dims = tuple(range(len(result_shape) - len(array_shape),
+                             len(result_shape)))
+    result = xops.BroadcastInDim(array, result_shape, bcast_dims)
+    return result
+
+  def _broadcasted_translation_rule(c, *args, **kwargs):
+    shapes = [c.GetShape(arg).dimensions() for arg in args]
+    result_shape = broadcast_shapes(*shapes)
+    args = [_broadcast_array(arg, arg_shape, result_shape)
+            for arg, arg_shape in zip(args, shapes)]
+    return translate(c, *args, **kwargs)
+  return _broadcasted_translation_rule
+
 # NOTE(mattjj): this isn't great for orchestrate fwd mode because it means JVPs
 # get two extra ops in them: a reshape and a broadcast_in_dim (or sometimes just
 # a broadcast). but saving the shape info with the primitives isn't great either
@@ -1901,7 +1918,9 @@ ad.defjvp(atanh_p,
           lambda g, x: mul(g, reciprocal((_one(x) - x) * (_one(x) + x))))

 regularized_incomplete_beta_p = standard_naryop(
-    [_float, _float, _float], 'regularized_incomplete_beta')
+    [_float, _float, _float], 'regularized_incomplete_beta',
+    translation_rule=_broadcast_translate(
+        partial(standard_translate, 'regularized_incomplete_beta')))

 def betainc_gradx(g, a, b, x):
   lbeta = lgamma(a) + lgamma(b) - lgamma(a + b)
@@ -1922,18 +1941,24 @@ ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))

 digamma_p = standard_unop(_float, 'digamma')

-igamma_p = standard_naryop([_float, _float], 'igamma')
-igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a')
+igamma_p = standard_naryop(
+    [_float, _float], 'igamma',
+    translation_rule=_broadcast_translate(partial(standard_translate, 'igamma')))
+igamma_grad_a_p = standard_naryop([_float, _float], 'igamma_grad_a',
+  translation_rule=_broadcast_translate(partial(standard_translate,
+                                                'igamma_grad_a')))

 def igamma_gradx(g, a, x):
-  return g * exp(-x + (a - _ones(a)) * log(x) - lgamma(a))
+  return _brcast(g, a, x) * exp(-x + (a - _ones(a)) * log(x) - lgamma(a))

 def igamma_grada(g, a, x):
-  return g * igamma_grad_a(a, x)
+  return _brcast(g, a, x) * igamma_grad_a(a, x)

 ad.defjvp(igamma_p, igamma_grada, igamma_gradx)

-igammac_p = standard_naryop([_float, _float], 'igammac')
+igammac_p = standard_naryop(
+    [_float, _float], 'igammac',
+    translation_rule=_broadcast_translate(partial(standard_translate, 'igammac')))

 def igammac_gradx(g, a, x):
   return -igamma_gradx(g, a, x)
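
A side note on the decorator's indexing: _broadcast_array right-aligns the operand's dimensions against the result shape, which is the convention xops.BroadcastInDim expects. The standalone sketch below (not part of the commit; the helper name is hypothetical) mirrors that bookkeeping using NumPy only.

import numpy as np

def trailing_broadcast_dims(operand_shape, result_shape):
  # Operand dimension i maps to result dimension rank_delta + i, i.e. the
  # operand's dims line up with the trailing dims of the result.
  rank_delta = len(result_shape) - len(operand_shape)
  return tuple(range(rank_delta, len(result_shape)))

print(trailing_broadcast_dims((4,), (3, 4)))    # (1,)
print(trailing_broadcast_dims((3, 1), (3, 4)))  # (0, 1); the size-1 dim is
                                                # expanded in place
# np.broadcast_to follows the same trailing-dimension rule, so it serves as a
# reference for the result the emitted BroadcastInDim should produce.
print(np.broadcast_to(np.arange(4.0), (3, 4)).shape)  # (3, 4)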