mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
add lax._safe_mul with 0*inf=0, used in pow jvp
This commit is contained in:
parent
1cbf49a404
commit
58749c0a13
26
jax/lax.py
26
jax/lax.py
@ -493,6 +493,9 @@ def stop_gradient(x):
|
||||
return stop_gradient_p.bind(x)
|
||||
|
||||
|
||||
def _safe_mul(x, y): return safe_mul_p.bind(x, y)
|
||||
|
||||
|
||||
def psum(x, axis_name):
|
||||
return psum_p.bind(x, axis_name=axis_name)
|
||||
|
||||
@ -959,13 +962,8 @@ _maybe_real = lambda x: real(x) if _iscomplex(x) else x
|
||||
pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')
|
||||
|
||||
def pow_jvp_lhs(g, x, y):
|
||||
exponent = select(eq(y, _zero(y)), _ones(y), sub(y, _one(y)))
|
||||
x_pow_ym1 = pow(x, exponent) # x ** (y-1), except where x==0 or y==0
|
||||
x_pow_ym1 = select(_brcast(eq(x, _zero(y)), x_pow_ym1), # pow(0, a) is 0
|
||||
_zeros(x_pow_ym1), x_pow_ym1) # unless a == 0
|
||||
x_pow_ym1 = select(_brcast(eq(y, _zero(y)), x_pow_ym1), # pow(a, 0) is 0
|
||||
_ones(x_pow_ym1), x_pow_ym1)
|
||||
return mul(_brcast(g, y), mul(y, x_pow_ym1))
|
||||
jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))
|
||||
return _safe_mul(_brcast(g, y), jac)
|
||||
|
||||
def pow_jvp_rhs(g, x, y):
|
||||
return mul(_brcast(g, x), mul(log(_replace_zero(x)), pow(x, y)))
|
||||
@ -1001,6 +999,20 @@ mul_p = standard_binop([_num, _num], 'mul')
|
||||
ad.defbilinear_broadcasting(_brcast, mul_p, mul, mul) # TODO
|
||||
|
||||
|
||||
def _safe_mul_translation_rule(c, x, y):
|
||||
dtype = c.GetShape(x).numpy_dtype()
|
||||
zero = c.Constant(onp.array(0, dtype=dtype))
|
||||
out_shape = tuple(onp.maximum(c.GetShape(x).dimensions(),
|
||||
c.GetShape(y).dimensions()))
|
||||
return c.Select(c.Or(c.Eq(x, zero), c.Eq(y, zero)),
|
||||
c.Broadcast(zero, out_shape),
|
||||
c.Mul(x, y))
|
||||
|
||||
safe_mul_p = standard_binop([_num, _num], 'safe_mul',
|
||||
translation_rule=_safe_mul_translation_rule)
|
||||
ad.defbilinear_broadcasting(_brcast, safe_mul_p, _safe_mul, _safe_mul)
|
||||
|
||||
|
||||
def _div_transpose_rule(cotangent, x, y):
|
||||
assert x is None and y is not None
|
||||
res = ad_util.zero if cotangent is ad_util.zero else div(cotangent, y)
|
||||
|
@ -348,7 +348,7 @@ class _JaxComputationBuilderBase(object):
|
||||
|
||||
def ConstantLike(self, example_value, value, canonicalize_types=True):
|
||||
example_value = onp.asarray(example_value)
|
||||
return self.Constant(onp.array(value).astype(example_value.dtype))
|
||||
return self.Constant(onp.array(value, dtype=example_value.dtype))
|
||||
|
||||
def Constant(self, py_val, canonicalize_types=True):
|
||||
"""Translate constant `py_val` to a constant for this ComputationBuilder.
|
||||
|
Loading…
x
Reference in New Issue
Block a user