add lax._safe_mul with 0*inf=0, used in pow jvp

This commit is contained in:
Matthew Johnson 2019-02-15 18:32:50 -08:00
parent 1cbf49a404
commit 58749c0a13
2 changed files with 20 additions and 8 deletions

View File

@ -493,6 +493,9 @@ def stop_gradient(x):
return stop_gradient_p.bind(x)
def _safe_mul(x, y):
  """Multiply x and y by binding the safe_mul primitive.

  The primitive's translation rule yields zero wherever either operand equals
  zero, instead of the plain product.
  """
  return safe_mul_p.bind(x, y)
def psum(x, axis_name):
  """Bind the psum primitive to x, forwarding axis_name as a keyword param."""
  bound = psum_p.bind(x, axis_name=axis_name)
  return bound
@ -959,13 +962,8 @@ _maybe_real = lambda x: real(x) if _iscomplex(x) else x
pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')
def pow_jvp_lhs(g, x, y):
# Tangent of pow(x, y) with respect to x: g * y * x ** (y - 1).
# NOTE(review): this text is a diff rendered without +/- markers. The hunk
# header (13 old lines -> 8 new lines) suggests everything from
# `exponent = ...` through the first `return` is the removed body, and the
# trailing `jac = ...` / `return _safe_mul(...)` pair is its replacement --
# confirm against the checked-in file.
# Exponent is y - 1, except that it is 1 where y == 0.
exponent = select(eq(y, _zero(y)), _ones(y), sub(y, _one(y)))
x_pow_ym1 = pow(x, exponent) # x ** (y-1), except where x==0 or y==0
x_pow_ym1 = select(_brcast(eq(x, _zero(y)), x_pow_ym1), # pow(0, a) is 0
_zeros(x_pow_ym1), x_pow_ym1) # unless a == 0
x_pow_ym1 = select(_brcast(eq(y, _zero(y)), x_pow_ym1), # where y == 0 use ones;
_ones(x_pow_ym1), x_pow_ym1) # y * x_pow_ym1 below is then zero there
return mul(_brcast(g, y), mul(y, x_pow_ym1))
# Factor y * x ** (y - 1), with exponent 1 where y == 0, as one expression;
# the tangent g is multiplied in with _safe_mul rather than mul.
jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))
return _safe_mul(_brcast(g, y), jac)
def pow_jvp_rhs(g, x, y):
  """Tangent of pow(x, y) with respect to y: g * log(x) * x ** y.

  x is passed through _replace_zero before the logarithm is taken.
  """
  tangent = _brcast(g, x)
  log_x = log(_replace_zero(x))
  x_pow_y = pow(x, y)
  return mul(tangent, mul(log_x, x_pow_y))
@ -1001,6 +999,20 @@ mul_p = standard_binop([_num, _num], 'mul')
ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul) # TODO
def _safe_mul_translation_rule(c, x, y):
  """Translation rule for safe_mul: zero where either operand is zero.

  Args:
    c: computation builder used to emit the operations.
    x: left operand handle; its dtype determines the dtype of the zero constant.
    y: right operand handle.

  Returns:
    A select that takes a broadcast zero wherever x == 0 or y == 0 and the
    ordinary product x * y everywhere else.
  """
  zero = c.Constant(onp.array(0, dtype=c.GetShape(x).numpy_dtype()))
  x_dims = c.GetShape(x).dimensions()
  y_dims = c.GetShape(y).dimensions()
  # Elementwise maximum of the two dimension tuples gives the output shape.
  out_shape = tuple(onp.maximum(x_dims, y_dims))
  either_is_zero = c.Or(c.Eq(x, zero), c.Eq(y, zero))
  zeros = c.Broadcast(zero, out_shape)
  product = c.Mul(x, y)
  return c.Select(either_is_zero, zeros, product)
# Define the 'safe_mul' binary primitive, using the custom translation rule
# instead of whatever default standard_binop would otherwise supply.
safe_mul_p = standard_binop([_num, _num], 'safe_mul',
translation_rule=_safe_mul_translation_rule)
# Register safe_mul with the bilinear-broadcasting helper in the same way as
# mul_p, but with _safe_mul in both function slots.
ad.defbilinear_broadcasting(_brcast, safe_mul_p, _safe_mul, _safe_mul)
def _div_transpose_rule(cotangent, x, y):
assert x is None and y is not None
res = ad_util.zero if cotangent is ad_util.zero else div(cotangent, y)

View File

@ -348,7 +348,7 @@ class _JaxComputationBuilderBase(object):
def ConstantLike(self, example_value, value, canonicalize_types=True):
# Build a constant holding `value` with the dtype of `example_value`.
# NOTE(review): canonicalize_types is accepted but not forwarded to
# self.Constant in the lines shown -- confirm whether that is intended.
example_value = onp.asarray(example_value)
# NOTE(review): diff rendered without +/- markers; the hunk header (7 -> 7
# lines) suggests the `.astype(...)` return is the removed line and the
# `dtype=` return below it is the replacement -- confirm.
return self.Constant(onp.array(value).astype(example_value.dtype))
return self.Constant(onp.array(value, dtype=example_value.dtype))
def Constant(self, py_val, canonicalize_types=True):
"""Translate constant `py_val` to a constant for this ComputationBuilder.