"""Calculates an implementation-dependent approximation of the inverse tangent of the quotient x1/x2, having domain [-infinity, +infinity] x [-infinity, +infinity] (where the x notation denotes the set of ordered pairs of elements (x1_i, x2_i)) and codomain [-π, +π], for each pair of elements (x1_i, x2_i) of the input arrays x1 and x2, respectively."""
x1,x2=_promote_dtypes("atan2",x1,x2)
returnjax.numpy.arctan2(x1,x2)
def atanh(x, /):
    """Calculates an implementation-dependent approximation to the inverse hyperbolic tangent for each element x_i of the input array x."""
    # NOTE(review): this body was missing in the source; rebuilt after the
    # file's other unary wrappers (e.g. logical_not) — confirm upstream.
    (x,) = _promote_dtypes("atanh", x)
    return jax.numpy.arctanh(x)


# NOTE(review): the def line was missing in the source; the signature is
# inferred from the body's use of x1/x2 — confirm upstream.
def bitwise_and(x1, x2, /):
    """Computes the bitwise AND of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    x1, x2 = _promote_dtypes("bitwise_and", x1, x2)
    return jax.numpy.bitwise_and(x1, x2)
def bitwise_left_shift(x1, x2, /):
    """Shifts the bits of each element x1_i of the input array x1 to the left by appending x2_i (i.e., the respective element in the input array x2) zeros to the right of x1_i."""
    # NOTE(review): this body was missing in the source; rebuilt after the
    # file's other binary wrappers (e.g. logical_and) — confirm upstream.
    x1, x2 = _promote_dtypes("bitwise_left_shift", x1, x2)
    return jax.numpy.left_shift(x1, x2)


# NOTE(review): the def line was missing in the source; the signature is
# inferred from the body's use of x1/x2 — confirm upstream.
def bitwise_or(x1, x2, /):
    """Computes the bitwise OR of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    x1, x2 = _promote_dtypes("bitwise_or", x1, x2)
    return jax.numpy.bitwise_or(x1, x2)
def bitwise_right_shift(x1, x2, /):
    """Shifts the bits of each element x1_i of the input array x1 to the right according to the respective element x2_i of the input array x2."""
    # NOTE(review): this body was missing in the source; rebuilt after the
    # file's other binary wrappers (e.g. logical_and) — confirm upstream.
    x1, x2 = _promote_dtypes("bitwise_right_shift", x1, x2)
    return jax.numpy.right_shift(x1, x2)


# NOTE(review): the def line was missing in the source; the signature is
# inferred from the body's use of x1/x2 — confirm upstream.
def bitwise_xor(x1, x2, /):
    """Computes the bitwise XOR of the underlying binary representation of each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    x1, x2 = _promote_dtypes("bitwise_xor", x1, x2)
    return jax.numpy.bitwise_xor(x1, x2)
def ceil(x, /):
    """Rounds each element x_i of the input array x to the smallest (i.e., closest to -infinity) integer-valued number that is not less than x_i."""
    # NOTE(review): this body was missing in the source; rebuilt after the
    # file's other unary wrappers (e.g. logical_not) — confirm upstream.
    (x,) = _promote_dtypes("ceil", x)
    # Integer elements are already integer-valued: return them unchanged so
    # the result keeps the input's integer dtype.
    if issubdtype(x.dtype, jax.numpy.integer):
        return x
    return jax.numpy.ceil(x)


# NOTE(review): the def line was missing in the source; the signature is
# inferred from the body's use of x1/x2 — confirm upstream.
def divide(x1, x2, /):
    """Calculates the division of each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    x1, x2 = _promote_dtypes("divide", x1, x2)
    return jax.numpy.divide(x1, x2)
def equal(x1, x2, /):
    """Computes the truth value of x1_i == x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    # Promote both operands together before delegating to jax.numpy.
    x1, x2 = _promote_dtypes("equal", x1, x2)
    return jax.numpy.equal(x1, x2)
def exp(x, /):
    """Calculates an implementation-dependent approximation to the exponential function for each element x_i of the input array x (e raised to the power of x_i, where e is the base of the natural logarithm)."""
    # NOTE(review): this body was missing in the source; rebuilt after the
    # file's other unary wrappers (e.g. logical_not) — confirm upstream.
    (x,) = _promote_dtypes("exp", x)
    return jax.numpy.exp(x)


# NOTE(review): the def line was missing in the source; the signature is
# inferred from the body's use of x1/x2 — confirm upstream.
def floor_divide(x1, x2, /):
    """Rounds the result of dividing each element x1_i of the input array x1 by the respective element x2_i of the input array x2 to the greatest (i.e., closest to +infinity) integer-valued number that is not greater than the division result."""
    x1, x2 = _promote_dtypes("floor_divide", x1, x2)
    return jax.numpy.floor_divide(x1, x2)
def greater(x1, x2, /):
    """Computes the truth value of x1_i > x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    # Promote both operands together before delegating to jax.numpy.
    x1, x2 = _promote_dtypes("greater", x1, x2)
    return jax.numpy.greater(x1, x2)
def greater_equal(x1, x2, /):
    """Computes the truth value of x1_i >= x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    # NOTE(review): this body was missing in the source; rebuilt after the
    # file's other binary wrappers (e.g. greater) — confirm upstream.
    x1, x2 = _promote_dtypes("greater_equal", x1, x2)
    return jax.numpy.greater_equal(x1, x2)


# NOTE(review): the def line was missing in the source; the signature is
# inferred from the body's use of x1/x2 — confirm upstream.
def hypot(x1, x2, /):
    """Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.

    Raises:
      ValueError: if x1's dtype after promotion is a complex floating type.
    """
    x1, x2 = _promote_dtypes("hypot", x1, x2)
    # TODO(micky774): Remove when jnp.hypot deprecation is completed
    # (began 2024-4-14) and default behavior is Array API 2023 compliant
    if issubdtype(x1.dtype, jax.numpy.complexfloating):
        # NOTE(review): the message tail and the closing parenthesis were
        # truncated in the source; completed here — confirm upstream wording.
        raise ValueError(
            "hypot does not support complex-valued inputs. Please convert to real "
            "values first, such as by using jnp.real or jnp.imag to take the real "
            "or imaginary components respectively."
        )
    return jax.numpy.hypot(x1, x2)


# NOTE(review): the def line and body were missing in the source; rebuilt
# after the file's other unary wrappers (e.g. logical_not) — confirm upstream.
def log1p(x, /):
    """Calculates an implementation-dependent approximation to log(1+x), where log refers to the natural (base e) logarithm, for each element x_i of the input array x."""
    (x,) = _promote_dtypes("log1p", x)
    return jax.numpy.log1p(x)


# NOTE(review): the def line was missing in the source; the signature is
# inferred from the body's use of x1/x2 — confirm upstream.
def logaddexp(x1, x2, /):
    """Calculates the logarithm of the sum of exponentiations log(exp(x1) + exp(x2)) for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    x1, x2 = _promote_dtypes("logaddexp", x1, x2)
    return jax.numpy.logaddexp(x1, x2)
def logical_and(x1, x2, /):
    """Computes the logical AND for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    # Promote both operands together before delegating to jax.numpy.
    x1, x2 = _promote_dtypes("logical_and", x1, x2)
    return jax.numpy.logical_and(x1, x2)
def logical_not(x, /):
    """Computes the logical NOT for each element x_i of the input array x."""
    # _promote_dtypes returns a sequence even for one operand; unpack it.
    (x,) = _promote_dtypes("logical_not", x)
    return jax.numpy.logical_not(x)
def logical_or(x1, x2, /):
    """Computes the logical OR for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    # Promote both operands together before delegating to jax.numpy.
    x1, x2 = _promote_dtypes("logical_or", x1, x2)
    return jax.numpy.logical_or(x1, x2)
def logical_xor(x1, x2, /):
    """Computes the logical XOR for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    # Promote both operands together before delegating to jax.numpy.
    x1, x2 = _promote_dtypes("logical_xor", x1, x2)
    return jax.numpy.logical_xor(x1, x2)
def multiply(x1, x2, /):
    """Calculates the product for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    # Promote both operands together before delegating to jax.numpy.
    x1, x2 = _promote_dtypes("multiply", x1, x2)
    return jax.numpy.multiply(x1, x2)
def negative(x, /):
    """Computes the numerical negative of each element x_i (i.e., y_i = -x_i) of the input array x."""
    # NOTE(review): this body was missing in the source; rebuilt after the
    # file's other unary wrappers (e.g. logical_not) — confirm upstream.
    (x,) = _promote_dtypes("negative", x)
    return jax.numpy.negative(x)


# NOTE(review): the def line was missing in the source; the signature is
# inferred from the body's use of x1/x2 — confirm upstream.
def not_equal(x1, x2, /):
    """Computes the truth value of x1_i != x2_i for each element x1_i of the input array x1 with the respective element x2_i of the input array x2."""
    x1, x2 = _promote_dtypes("not_equal", x1, x2)
    return jax.numpy.not_equal(x1, x2)
def positive(x, /):
    """Computes the numerical positive of each element x_i (i.e., y_i = +x_i) of the input array x."""
    # Unary plus is the identity on values; only the dtype promotion applies.
    (x,) = _promote_dtypes("positive", x)
    return x
defpow(x1,x2,/):
"""Calculates an implementation-dependent approximation of exponentiation by raising each element x1_i (the base) of the input array x1 to the power of x2_i (the exponent), where x2_i is the corresponding element of the input array x2."""