mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #24874 from pearu:pearu/square_p
PiperOrigin-RevId: 696251565
This commit is contained in:
commit
14e08aa271
@ -1915,7 +1915,7 @@ def batch_matmul(lhs: Array, rhs: Array,
|
||||
|
||||
def square(x: ArrayLike) -> Array:
|
||||
r"""Elementwise square: :math:`x^2`."""
|
||||
return integer_pow(x, 2)
|
||||
return square_p.bind(x)
|
||||
|
||||
def reciprocal(x: ArrayLike) -> Array:
|
||||
r"""Elementwise reciprocal: :math:`1 \over x`."""
|
||||
@ -2524,6 +2524,27 @@ ad.defjvp2(cbrt_p,
|
||||
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
|
||||
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt))
|
||||
|
||||
# Unary primitive named 'square'. The first argument is presumably the set of
# accepted dtypes (`_int | _float | _complex`) -- confirm against
# `standard_unop`.
square_p = standard_unop(_int | _float | _complex, 'square')
|
||||
|
||||
def _square_complex(x):
  """Square a complex operand from its real and imaginary parts.

  Writing ``x = a + b*i``, the result has real part ``(a - b) * (a + b)`` and
  imaginary part ``a * b * 2``. The factored real part is equivalent to
  ``a**2 - b**2`` but avoids overflow errors when ``a`` and ``b`` are large.
  """
  re_in, im_in = real(x), imag(x)
  # For abs(a) == abs(b) the real part of the square is zero. It is selected
  # explicitly because, for finite a, 2 * a can be non-finite.
  re_is_finite = is_finite(re_in)
  equal_magnitude = eq(re_in, im_in) | eq(re_in, -im_in)
  force_zero_re = re_is_finite & equal_magnitude
  # Difference-of-squares form; operand order is kept as (a - b) * (a + b).
  sq_re = (re_in - im_in) * (re_in + im_in)
  sq_im = re_in * im_in * 2
  zero_re_result = complex(_const(re_in, 0), sq_im)
  general_result = complex(sq_re, sq_im)
  return select(force_zero_re, zero_re_result, general_result)
|
||||
|
||||
def _square_lower_hlo(ctx, x):
  """Lowering rule for ``square_p``.

  Complex operands are lowered through ``_square_complex``; every other dtype
  lowers to a single ``hlo.multiply(x, x)``.

  Args:
    ctx: lowering context; only ``ctx.avals_in[0].dtype`` is inspected here.
    x: the operand value being lowered.

  Returns:
    A one-element list with the multiply result for non-complex dtypes;
    otherwise whatever the ``mlir.lower_fun`` wrapper returns.
  """
  operand_dtype = ctx.avals_in[0].dtype
  if not dtypes.issubdtype(operand_dtype, np.complexfloating):
    return [hlo.multiply(x, x)]
  lower_complex = mlir.lower_fun(_square_complex, multiple_results=False)
  return lower_complex(ctx, x)
|
||||
|
||||
# JVP rule for square: the incoming tangent `g` is multiplied by `2 * x`.
# The primal output `ans` is accepted but not used.
ad.defjvp2(square_p, lambda g, ans, x: mul(g, mul(_const(x, 2), x)))
|
||||
mlir.register_lowering(square_p, _square_lower_hlo) # TODO(pearu): use chlo.square
|
||||
|
||||
def _pow_dtype_rule(x, y):
|
||||
if (dtypes.issubdtype(x.dtype, np.inexact) and
|
||||
dtypes.issubdtype(y.dtype, np.integer)):
|
||||
|
@ -3107,7 +3107,7 @@ def square(x: ArrayLike, /) -> Array:
|
||||
"""
|
||||
check_arraylike("square", x)
|
||||
x, = promote_dtypes_numeric(x)
|
||||
return lax.integer_pow(x, 2)
|
||||
return lax.square(x)
|
||||
|
||||
|
||||
@partial(jit, inline=True)
|
||||
|
@ -2084,6 +2084,15 @@ def _sqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule
|
||||
|
||||
|
||||
def _square_lowering_rule(ctx: LoweringRuleContext, x):
  """Lower ``square`` as ``x`` multiplied by itself.

  The multiply op is chosen from the dtype of the first input aval:
  ``arith.muli`` for integer dtypes, ``arith.mulf`` for everything else.
  """
  operand_is_integer = jnp.issubdtype(ctx.avals_in[0].dtype, jnp.integer)
  multiply = arith.muli if operand_is_integer else arith.mulf
  return multiply(x, x)
|
||||
|
||||
|
||||
lowering_rules[lax.square_p] = _square_lowering_rule
|
||||
|
||||
|
||||
def _exp_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
return math.exp(x)
|
||||
|
||||
|
@ -1160,6 +1160,11 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y):
|
||||
return x * x
|
||||
return NotImplementedError
|
||||
|
||||
@register_lowering_rule(lax.square_p)
def _square_lowering_rule(ctx: LoweringRuleContext, x):
  """Lower ``lax.square_p`` by multiplying the operand with itself.

  The operand is first passed through ``_ensure_fa`` together with the dtype
  of the single input aval, and the returned value is multiplied by itself.
  """
  (operand_aval,) = ctx.avals_in
  operand = _ensure_fa(x, operand_aval.dtype)
  return operand * operand
|
||||
|
||||
@register_lowering_rule(lax.rsqrt_p)
|
||||
def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
|
||||
|
@ -780,6 +780,7 @@ triton_lowering_rules.update({
|
||||
_Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)),
|
||||
],
|
||||
),
|
||||
lax.square_p: lambda ctx, x: _mul(x, x),
|
||||
lax.pow_p: _make_dispatch_table(
|
||||
"pow",
|
||||
cuda=[
|
||||
|
@ -1726,6 +1726,7 @@ tf_impl[lax.atanh_p] = tf.math.atanh
|
||||
tf_impl[lax.asinh_p] = tf.math.asinh
|
||||
|
||||
tf_impl[lax.sqrt_p] = tf.math.sqrt
|
||||
tf_impl[lax.square_p] = tf.math.square
|
||||
tf_impl[lax.rsqrt_p] = tf.math.rsqrt
|
||||
|
||||
def _cbrt(x):
|
||||
|
@ -405,6 +405,7 @@ def def_comp(prim, comp):
|
||||
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
|
||||
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
|
||||
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
|
||||
def_comp(lax.square_p, lambda x: x * x)
|
||||
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
|
||||
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
|
||||
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
|
||||
|
@ -97,6 +97,7 @@ _zero_preserving_unary_primitives = [
|
||||
lax.sin_p,
|
||||
lax.sinh_p,
|
||||
lax.sqrt_p,
|
||||
lax.square_p,
|
||||
lax.tan_p,
|
||||
lax.tanh_p,
|
||||
lax.convert_element_type_p,
|
||||
|
@ -127,6 +127,7 @@ from jax._src.lax.lax import (
|
||||
sinh_p as sinh_p,
|
||||
sort_p as sort_p,
|
||||
sqrt_p as sqrt_p,
|
||||
square_p as square_p,
|
||||
squeeze_p as squeeze_p,
|
||||
sub_p as sub_p,
|
||||
tan_p as tan_p,
|
||||
|
@ -206,6 +206,7 @@ from jax._src.lax.lax import (
|
||||
sqrt as sqrt,
|
||||
sqrt_p as sqrt_p,
|
||||
square as square,
|
||||
square_p as square_p,
|
||||
squeeze as squeeze,
|
||||
squeeze_p as squeeze_p,
|
||||
stop_gradient as stop_gradient,
|
||||
|
@ -4362,12 +4362,6 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
|
||||
elif name == 'sign':
|
||||
regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4')
|
||||
|
||||
elif name == 'square':
|
||||
if is_cuda:
|
||||
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.real', 'pinf.real', 'ninfj.real', 'pinfj.real')
|
||||
if is_cpu:
|
||||
regions_with_inaccuracies_keep('ninf.real', 'pinf.real', 'q1.real', 'q2.real', 'q3.real', 'q4.real')
|
||||
|
||||
elif name == 'log':
|
||||
regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag')
|
||||
|
||||
@ -4411,7 +4405,7 @@ class FunctionAccuracyTest(jtu.JaxTestCase):
|
||||
regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag')
|
||||
|
||||
elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p', 'tan',
|
||||
'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh'}:
|
||||
'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh', 'square'}:
|
||||
regions_with_inaccuracies.clear()
|
||||
else:
|
||||
assert 0 # unreachable
|
||||
|
Loading…
x
Reference in New Issue
Block a user