mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16419 from mattjj:pow-jvp
PiperOrigin-RevId: 559266945
This commit is contained in:
commit
af42359433
@ -57,13 +57,9 @@ from jax._src.interpreters import xla
|
||||
from jax._src.interpreters.batching import RaggedAxis
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.lax.utils import (
|
||||
_input_dtype,
|
||||
dtype_to_string,
|
||||
standard_abstract_eval,
|
||||
standard_multi_result_abstract_eval,
|
||||
standard_named_shape_rule,
|
||||
standard_primitive,
|
||||
)
|
||||
_input_dtype, dtype_to_string, standard_abstract_eval,
|
||||
standard_multi_result_abstract_eval, standard_named_shape_rule,
|
||||
standard_primitive)
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir import ir
|
||||
@ -1515,7 +1511,8 @@ ad_util.jaxval_zeros_likers[array.ArrayImpl] = zeros_like_array
|
||||
### primitives
|
||||
|
||||
|
||||
_fixed_dtype = lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
|
||||
_fixed_dtype = \
|
||||
lambda dtype: lambda *args, **kwargs: dtypes.canonicalize_dtype(dtype)
|
||||
_complex_basetype = lambda dtype: np.abs(np.zeros((), dtype)).dtype
|
||||
|
||||
_strip_weak_type = lambda *args, **_: False
|
||||
@ -1543,7 +1540,7 @@ _attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
|
||||
|
||||
|
||||
def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals,
|
||||
allow_extended_dtype=False, **kwargs):
|
||||
require_same=True, allow_extended_dtype=False, **kwargs):
|
||||
del kwargs
|
||||
assert len(avals) == len(accepted_dtypes), (avals, accepted_dtypes)
|
||||
for i, aval in enumerate(avals):
|
||||
@ -1566,7 +1563,7 @@ def naryop_dtype_rule(result_dtype, accepted_dtypes, name, *avals,
|
||||
typename = dtype_to_string(aval.dtype)
|
||||
typenames = ', '.join(t.__name__ for t in types)
|
||||
raise TypeError(msg.format(name, typename, i, i, typenames))
|
||||
check_same_dtypes(name, *avals)
|
||||
if require_same: check_same_dtypes(name, *avals)
|
||||
return result_dtype(*avals)
|
||||
|
||||
|
||||
@ -1609,9 +1606,11 @@ def _naryop_weak_type_rule(name, *avals, **kwargs):
|
||||
"taken a gradient with respect to an integer argument.")
|
||||
return all(aval.weak_type for aval in avals)
|
||||
|
||||
def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False):
|
||||
def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False,
|
||||
require_same_dtypes=False):
|
||||
dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name,
|
||||
allow_extended_dtype=allow_extended_dtype)
|
||||
allow_extended_dtype=allow_extended_dtype,
|
||||
require_same=require_same_dtypes)
|
||||
shape_rule = partial(broadcasting_shape_rule, name)
|
||||
weak_type_rule = partial(_naryop_weak_type_rule, name)
|
||||
prim = standard_primitive(shape_rule, dtype_rule, name,
|
||||
@ -1973,16 +1972,48 @@ ad.defjvp2(cbrt_p,
|
||||
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
|
||||
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp))
|
||||
|
||||
pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')
|
||||
def _pow_dtype_rule(x, y):
|
||||
if (dtypes.issubdtype(x.dtype, np.inexact) and
|
||||
dtypes.issubdtype(y.dtype, np.integer)):
|
||||
return x.dtype
|
||||
if x.dtype == y.dtype:
|
||||
return x.dtype
|
||||
raise TypeError("the first argument to pow must have an inexact dtype (float "
|
||||
"or complex), and the second argument must have an inexact or"
|
||||
" integer dtype, and two inexact dtypes must match, but got "
|
||||
f"{x.dtype} and {y.dtype} respectively.")
|
||||
pow_p = naryop(_pow_dtype_rule, [_float | _complex, _int | _float | _complex],
|
||||
'pow', require_same_dtypes=False)
|
||||
|
||||
def _pow_jvp_lhs(g, ans, x, y):
|
||||
return mul(g, mul(y, pow(x, sub(y, _ones(y)))))
|
||||
y_dtype = dtypes.dtype(y)
|
||||
x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this
|
||||
if dtypes.issubdtype(y_dtype, np.integer):
|
||||
jac = select(eq(y, _const(y, 0)), _ones(y),
|
||||
mul(_replace_zero(y), pow(x, sub(y, _ones(y)))))
|
||||
else:
|
||||
jac = mul(y, pow(x, sub(y, _ones(y))))
|
||||
return mul(g, jac)
|
||||
|
||||
def _pow_jvp_rhs(g, ans, x, y):
|
||||
return mul(g, mul(log(_replace_zero(x)), ans))
|
||||
|
||||
y_dtype = dtypes.dtype(y)
|
||||
assert dtypes.issubdtype(y_dtype, np.inexact)
|
||||
return convert_element_type(mul(g, mul(log(_replace_zero(x)), ans)), y_dtype)
|
||||
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
|
||||
mlir.register_lowering(pow_p, partial(_nary_lower_hlo, hlo.PowOp))
|
||||
|
||||
def _pow_lower(ctx, x, y):
|
||||
x_aval, y_aval = ctx.avals_in
|
||||
out_aval, = ctx.avals_out
|
||||
convert = mlir.lower_fun(
|
||||
partial(convert_element_type, new_dtype=out_aval.dtype), False)
|
||||
x_aval_ = x_aval.update(dtype=out_aval.dtype)
|
||||
y_aval_ = y_aval.update(dtype=out_aval.dtype)
|
||||
[(x_,)] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
|
||||
[(y_,)] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
|
||||
ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_])
|
||||
return _nary_lower_hlo(hlo.PowOp, ctx_, x_, y_)
|
||||
mlir.register_lowering(pow_p, _pow_lower)
|
||||
|
||||
|
||||
|
||||
def _integer_pow_dtype_rule(x, *, y):
|
||||
|
@ -18,7 +18,6 @@
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
from typing import Callable
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
@ -30,7 +29,8 @@ import numpy as np
|
||||
|
||||
xops = xla_client.ops
|
||||
|
||||
_input_dtype: Callable = lambda *args, **_: dtypes.canonicalize_dtype(args[0].dtype, allow_extended_dtype=True)
|
||||
def _input_dtype(x, *_, **__):
|
||||
return dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
|
||||
|
||||
def _argnum_weak_type(*argnums):
|
||||
return lambda *args, **_: all(args[i].weak_type for i in argnums)
|
||||
|
@ -32,7 +32,7 @@ from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.numpy.util import (
|
||||
check_arraylike, promote_args, promote_args_inexact,
|
||||
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
|
||||
promote_shapes, _where, _wraps)
|
||||
promote_shapes, _where, _wraps, check_no_float0s)
|
||||
|
||||
_lax_const = lax._const
|
||||
|
||||
@ -305,16 +305,60 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]:
|
||||
return lax.round(div), mod
|
||||
|
||||
|
||||
@_wraps(np.power, module='numpy')
|
||||
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
check_arraylike("power", x1, x2)
|
||||
check_no_float0s("power", x1, x2)
|
||||
|
||||
# We apply special cases, both for algorithmic and autodiff reasons:
|
||||
# 1. for *concrete* integer scalar powers (and arbitrary bases), we use
|
||||
# unrolled binary exponentiation specialized on the exponent, which is
|
||||
# more precise for e.g. x ** 2 when x is a float (algorithmic reason!);
|
||||
# 2. for integer bases and integer powers, use unrolled binary exponentiation
|
||||
# where the number of steps is determined by a max bit width of 64
|
||||
# (algorithmic reason!);
|
||||
# 3. for integer powers and float/complex bases, we apply the lax primitive
|
||||
# without any promotion of input types because in this case we want the
|
||||
# function to be differentiable wrt its first argument at 0;
|
||||
# 3. for other cases, perform jnp dtype promotion on the arguments then apply
|
||||
# lax.pow.
|
||||
|
||||
# Case 1: concrete integer scalar powers:
|
||||
if isinstance(core.get_aval(x2), core.ConcreteArray):
|
||||
try:
|
||||
x2 = operator.index(x2) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
x1, = promote_dtypes_numeric(x1)
|
||||
return lax.integer_pow(x1, x2)
|
||||
|
||||
# Handle cases #2 and #3 under a jit:
|
||||
return _power(x1, x2)
|
||||
|
||||
@partial(jit, inline=True)
|
||||
def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
x1, x2 = promote_args_numeric("power", x1, x2)
|
||||
dtype = dtypes.dtype(x1)
|
||||
if not dtypes.issubdtype(dtype, np.integer):
|
||||
x1, x2 = promote_shapes("power", x1, x2) # not dtypes
|
||||
|
||||
# Case 2: bool/integer result
|
||||
x1_, x2_ = promote_args_numeric("power", x1, x2)
|
||||
if (dtypes.issubdtype(dtypes.dtype(x1_), np.integer) or
|
||||
dtypes.issubdtype(dtypes.dtype(x1_), np.bool_)):
|
||||
assert np.iinfo(dtypes.dtype(x1_)).bits <= 64 # _pow_int_int assumes <=64bit
|
||||
return _pow_int_int(x1_, x2_)
|
||||
|
||||
# Case 3: float/complex base with integer power (special autodiff behavior)
|
||||
d1, d2 = dtypes.dtype(x1), dtypes.dtype(x2)
|
||||
if dtypes.issubdtype(d1, np.inexact) and dtypes.issubdtype(d2, np.integer):
|
||||
return lax.pow(x1, x2)
|
||||
|
||||
# Integer power => use binary exponentiation.
|
||||
|
||||
# TODO(phawkins): add integer pow support to XLA.
|
||||
# Case 4: do promotion first
|
||||
return lax.pow(x1_, x2_)
|
||||
|
||||
# TODO(phawkins): add integer pow support to XLA.
|
||||
def _pow_int_int(x1, x2):
|
||||
# Integer power => use binary exponentiation.
|
||||
bits = 6 # Anything more would overflow for any x1 > 1
|
||||
zero = _constant_like(x2, 0)
|
||||
one = _constant_like(x2, 1)
|
||||
@ -327,24 +371,6 @@ def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
return acc
|
||||
|
||||
|
||||
@_wraps(np.power, module='numpy')
|
||||
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
check_arraylike("power", x1, x2)
|
||||
# Special case for concrete integer scalars: use binary exponentiation.
|
||||
# Using lax.pow may be imprecise for floating-point values; the goal of this
|
||||
# code path is to make sure we end up with a precise output for the common
|
||||
# pattern ``x ** 2`` or similar.
|
||||
if isinstance(core.get_aval(x2), core.ConcreteArray):
|
||||
try:
|
||||
x2 = operator.index(x2) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
x1, = promote_dtypes_numeric(x1)
|
||||
return lax.integer_pow(x1, x2)
|
||||
return _power(x1, x2)
|
||||
|
||||
|
||||
@custom_jvp
|
||||
@_wraps(np.logaddexp, module='numpy')
|
||||
@jit
|
||||
|
@ -337,7 +337,7 @@ def check_arraylike_or_none(fun_name: str, *args: Any):
|
||||
raise TypeError(msg.format(fun_name, type(arg), pos))
|
||||
|
||||
|
||||
def _check_no_float0s(fun_name: str, *args: Any):
|
||||
def check_no_float0s(fun_name: str, *args: Any):
|
||||
"""Check if none of the args have dtype float0."""
|
||||
if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
|
||||
raise TypeError(
|
||||
@ -348,6 +348,7 @@ def _check_no_float0s(fun_name: str, *args: Any):
|
||||
"to cast a float0 array to a regular zeros array. \n"
|
||||
"If you didn't expect to get a float0 you might have accidentally "
|
||||
"taken a gradient with respect to an integer argument.")
|
||||
_check_no_float0s = check_no_float0s
|
||||
|
||||
|
||||
def promote_args(fun_name: str, *args: ArrayLike) -> list[Array]:
|
||||
|
@ -1519,10 +1519,10 @@ class JumbleTest(jtu.JaxTestCase):
|
||||
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
|
||||
p = jax.vmap(partial(jnp.arange, dtype='int32'),
|
||||
out_axes=batching.jumble_axis)(ins)
|
||||
p = jumble_map(jax.jit(lambda x: x ** 2))(p)
|
||||
p = jumble_map(jax.jit(lambda x: x * 3))(p)
|
||||
self.assertIsInstance(p, batching.Jumble)
|
||||
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
|
||||
data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2
|
||||
data = jax.lax.broadcasted_iota('int32', (3, 5), 1) * 3
|
||||
self.assertAllClose(p.data, data, check_dtypes=False)
|
||||
|
||||
def test_jumble_map_vector_dot(self):
|
||||
|
@ -538,6 +538,11 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
# self.assertEqual(result, 0.0)
|
||||
self.assertAllClose(result, np.nan)
|
||||
|
||||
def testPowIntPowerAtZero(self):
|
||||
# https://github.com/google/jax/issues/14397
|
||||
ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0)
|
||||
self.assertAllClose(ans, 1., check_dtypes=False)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(arg_shape=arg_shape, pred_shape=pred_shape)
|
||||
for arg_shape in [(), (3,), (2, 3)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user