mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jnp.power: fix overflow case for x1=0
This commit is contained in:
parent
0b88b0ea9b
commit
0c86c1fd11
@ -607,12 +607,14 @@ def power(x1, x2):
|
||||
|
||||
# TODO(phawkins): add integer pow support to XLA.
|
||||
bits = 6 # Anything more would overflow for any x1 > 1
|
||||
acc = ones(shape(x1), dtype=dtype)
|
||||
zero = _constant_like(x2, 0)
|
||||
one = _constant_like(x2, 1)
|
||||
# Initialize acc carefully such that pow(0, x2) is zero for x2 != 0
|
||||
acc = where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one)
|
||||
for _ in range(bits):
|
||||
acc = where(lax.bitwise_and(x2, _constant_like(x2, 1)),
|
||||
lax.mul(acc, x1), acc)
|
||||
acc = where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc)
|
||||
x1 = lax.mul(x1, x1)
|
||||
x2 = lax.shift_right_logical(x2, _constant_like(x2, 1))
|
||||
x2 = lax.shift_right_logical(x2, one)
|
||||
return acc
|
||||
|
||||
|
||||
|
@ -1754,6 +1754,16 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertLen(eqns, 1)
|
||||
self.assertEqual(eqns[0].primitive, lax.integer_pow_p)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_x={}_y={}".format(x, y), "x": x, "y": y}
|
||||
for x in [-1, 0, 1]
|
||||
for y in [0, 32, 64, 128]))
|
||||
def testIntegerPowerOverflow(self, x, y):
|
||||
# Regression test for https://github.com/google/jax/issues/5987
|
||||
args_maker = lambda: [x, y]
|
||||
self._CheckAgainstNumpy(np.power, jnp.power, args_maker)
|
||||
self._CompileAndCheck(jnp.power, args_maker)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_shape={}_axis={}".format(
|
||||
jtu.format_shape_dtype_string(shape, dtype), axis),
|
||||
|
Loading…
x
Reference in New Issue
Block a user