jnp.power: fix overflow case for x1=0

This commit is contained in:
Jake VanderPlas 2021-03-09 09:36:41 -08:00
parent 0b88b0ea9b
commit 0c86c1fd11
2 changed files with 16 additions and 4 deletions

View File

@ -607,12 +607,14 @@ def power(x1, x2):
# TODO(phawkins): add integer pow support to XLA.
bits = 6 # Anything more would overflow for any x1 > 1
acc = ones(shape(x1), dtype=dtype)
zero = _constant_like(x2, 0)
one = _constant_like(x2, 1)
# Initialize acc carefully such that pow(0, x2) is zero for x2 != 0
acc = where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one)
for _ in range(bits):
acc = where(lax.bitwise_and(x2, _constant_like(x2, 1)),
lax.mul(acc, x1), acc)
acc = where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc)
x1 = lax.mul(x1, x1)
x2 = lax.shift_right_logical(x2, _constant_like(x2, 1))
x2 = lax.shift_right_logical(x2, one)
return acc

View File

@ -1754,6 +1754,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertLen(eqns, 1)
self.assertEqual(eqns[0].primitive, lax.integer_pow_p)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_x={}_y={}".format(x, y), "x": x, "y": y}
for x in [-1, 0, 1]
for y in [0, 32, 64, 128]))
def testIntegerPowerOverflow(self, x, y):
# Regression test for https://github.com/google/jax/issues/5987
args_maker = lambda: [x, y]
self._CheckAgainstNumpy(np.power, jnp.power, args_maker)
self._CompileAndCheck(jnp.power, args_maker)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),