mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Mark arguments to ufuncs as positional-only.
PiperOrigin-RevId: 493311821
This commit is contained in:
parent
a7900166d1
commit
33a1b8866a
13
CHANGELOG.md
13
CHANGELOG.md
@ -6,13 +6,6 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
|
||||
Remember to align the itemized text with the first line of an item within a list.
|
||||
-->
|
||||
|
||||
## Next jax
|
||||
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.4.0...main).
|
||||
* Changes
|
||||
* The jax2tf.call_tf function now uses for TF lowering the first TF
|
||||
device of the same platform as used by the embedding JAX computation.
|
||||
Before, it was using the 0th device for the JAX-default backend.
|
||||
|
||||
## jax 0.4.0
|
||||
* Changes
|
||||
* Support for Python 3.7 has been dropped, in accordance with JAX's
|
||||
@ -36,7 +29,11 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
longer read or written after the JAX configuration options are initially
|
||||
populated from the ABSL flags. This change improves performance of reading
|
||||
`jax.config` options, which are used pervasively in JAX.
|
||||
|
||||
* The jax2tf.call_tf function now uses for TF lowering the first TF
|
||||
device of the same platform as used by the embedding JAX computation.
|
||||
Before, it was using the 0th device for the JAX-default backend.
|
||||
* A number of `jax.numpy` functions now have their arguments marked as
|
||||
positional-only, matching NumPy.
|
||||
|
||||
## jaxlib 0.4.0
|
||||
* Changes
|
||||
|
@ -59,9 +59,9 @@ def _one_to_one_unop(
|
||||
numpy_fn: Callable[..., Any], lax_fn: UnOp,
|
||||
promote_to_inexact: bool = False, lax_doc: bool = False) -> UnOp:
|
||||
if promote_to_inexact:
|
||||
fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
|
||||
fn = lambda x, /: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
|
||||
else:
|
||||
fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x))
|
||||
fn = lambda x, /: lax_fn(*_promote_args(numpy_fn.__name__, x))
|
||||
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
|
||||
fn = jit(fn, inline=True)
|
||||
if lax_doc:
|
||||
@ -76,11 +76,11 @@ def _one_to_one_binop(
|
||||
promote_to_inexact: bool = False, lax_doc: bool = False,
|
||||
promote_to_numeric: bool = False) -> BinOp:
|
||||
if promote_to_inexact:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
|
||||
fn = lambda x1, x2, /: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
|
||||
elif promote_to_numeric:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args_numeric(numpy_fn.__name__, x1, x2))
|
||||
fn = lambda x1, x2, /: lax_fn(*_promote_args_numeric(numpy_fn.__name__, x1, x2))
|
||||
else:
|
||||
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
|
||||
fn = lambda x1, x2, /: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
|
||||
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
|
||||
fn = jit(fn, inline=True)
|
||||
if lax_doc:
|
||||
@ -93,7 +93,7 @@ def _one_to_one_binop(
|
||||
def _maybe_bool_binop(
|
||||
numpy_fn: Callable[..., Any], lax_fn: BinOp, bool_lax_fn: BinOp,
|
||||
lax_doc: bool = False) -> BinOp:
|
||||
def fn(x1, x2):
|
||||
def fn(x1, x2, /):
|
||||
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
|
||||
return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
|
||||
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
|
||||
@ -106,7 +106,7 @@ def _maybe_bool_binop(
|
||||
|
||||
|
||||
def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp:
|
||||
def fn(x1, x2):
|
||||
def fn(x1, x2, /):
|
||||
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
|
||||
# Comparison on complex types are defined as a lexicographic ordering on
|
||||
# the (real, imag) pair.
|
||||
@ -191,7 +191,7 @@ logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)
|
||||
|
||||
@_wraps(np.arccosh, module='numpy')
|
||||
@jit
|
||||
def arccosh(x: ArrayLike) -> Array:
|
||||
def arccosh(x: ArrayLike, /) -> Array:
|
||||
# Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
|
||||
# convention than np.arccosh.
|
||||
out = lax.acosh(*_promote_args_inexact("arccosh", x))
|
||||
@ -202,7 +202,7 @@ def arccosh(x: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.right_shift, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def right_shift(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = _promote_args_numeric(np.right_shift.__name__, x1, x2)
|
||||
lax_fn = lax.shift_right_logical if \
|
||||
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
|
||||
@ -211,7 +211,7 @@ def right_shift(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.absolute, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def absolute(x: ArrayLike) -> Array:
|
||||
def absolute(x: ArrayLike, /) -> Array:
|
||||
_check_arraylike('absolute', x)
|
||||
dt = dtypes.dtype(x)
|
||||
return _asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
|
||||
@ -220,7 +220,7 @@ abs = _wraps(np.abs, module='numpy')(absolute)
|
||||
|
||||
@_wraps(np.rint, module='numpy')
|
||||
@jit
|
||||
def rint(x: ArrayLike) -> Array:
|
||||
def rint(x: ArrayLike, /) -> Array:
|
||||
_check_arraylike('rint', x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtype == bool or dtypes.issubdtype(dtype, np.integer):
|
||||
@ -232,7 +232,7 @@ def rint(x: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.sign, module='numpy')
|
||||
@jit
|
||||
def sign(x: ArrayLike) -> Array:
|
||||
def sign(x: ArrayLike, /) -> Array:
|
||||
_check_arraylike('sign', x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.complexfloating):
|
||||
@ -244,7 +244,7 @@ def sign(x: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.copysign, module='numpy')
|
||||
@jit
|
||||
def copysign(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = _promote_args_inexact("copysign", x1, x2)
|
||||
if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
|
||||
raise TypeError("copysign does not support complex-valued inputs")
|
||||
@ -253,7 +253,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.true_divide, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def true_divide(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = _promote_args_inexact("true_divide", x1, x2)
|
||||
return lax.div(x1, x2)
|
||||
|
||||
@ -262,7 +262,7 @@ divide = true_divide
|
||||
|
||||
@_wraps(np.floor_divide, module='numpy')
|
||||
@jit
|
||||
def floor_divide(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = _promote_args_numeric("floor_divide", x1, x2)
|
||||
dtype = dtypes.dtype(x1)
|
||||
if dtypes.issubdtype(dtype, np.integer):
|
||||
@ -287,7 +287,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.divmod, module='numpy')
|
||||
@jit
|
||||
def divmod(x1: ArrayLike, x2: ArrayLike) -> Tuple[Array, Array]:
|
||||
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> Tuple[Array, Array]:
|
||||
x1, x2 = _promote_args_numeric("divmod", x1, x2)
|
||||
if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
|
||||
return floor_divide(x1, x2), remainder(x1, x2)
|
||||
@ -330,7 +330,7 @@ def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
|
||||
|
||||
@_wraps(np.power, module='numpy')
|
||||
def power(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
_check_arraylike("power", x1, x2)
|
||||
# Special case for concrete integer scalars: use binary exponentiation.
|
||||
# Using lax.pow may be imprecise for floating-point values; the goal of this
|
||||
@ -350,7 +350,7 @@ def power(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
@custom_jvp
|
||||
@_wraps(np.logaddexp, module='numpy')
|
||||
@jit
|
||||
def logaddexp(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
|
||||
amax = lax.max(x1, x2)
|
||||
if dtypes.issubdtype(x1.dtype, np.floating):
|
||||
@ -388,7 +388,7 @@ def _logaddexp_jvp(primals, tangents):
|
||||
@custom_jvp
|
||||
@_wraps(np.logaddexp2, module='numpy')
|
||||
@jit
|
||||
def logaddexp2(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = _promote_args_inexact("logaddexp2", x1, x2)
|
||||
amax = lax.max(x1, x2)
|
||||
if dtypes.issubdtype(x1.dtype, np.floating):
|
||||
@ -416,28 +416,28 @@ def _logaddexp2_jvp(primals, tangents):
|
||||
|
||||
@_wraps(np.log2, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def log2(x: ArrayLike) -> Array:
|
||||
def log2(x: ArrayLike, /) -> Array:
|
||||
x, = _promote_args_inexact("log2", x)
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
|
||||
|
||||
|
||||
@_wraps(np.log10, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def log10(x: ArrayLike) -> Array:
|
||||
def log10(x: ArrayLike, /) -> Array:
|
||||
x, = _promote_args_inexact("log10", x)
|
||||
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
|
||||
|
||||
|
||||
@_wraps(np.exp2, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def exp2(x: ArrayLike) -> Array:
|
||||
def exp2(x: ArrayLike, /) -> Array:
|
||||
x, = _promote_args_inexact("exp2", x)
|
||||
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
|
||||
|
||||
|
||||
@_wraps(np.signbit, module='numpy')
|
||||
@jit
|
||||
def signbit(x: ArrayLike) -> Array:
|
||||
def signbit(x: ArrayLike, /) -> Array:
|
||||
x, = _promote_args("signbit", x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.integer):
|
||||
@ -474,7 +474,7 @@ def _normalize_float(x):
|
||||
|
||||
@_wraps(np.ldexp, module='numpy')
|
||||
@jit
|
||||
def ldexp(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
_check_arraylike("ldexp", x1, x2)
|
||||
x1_dtype = dtypes.dtype(x1)
|
||||
x2_dtype = dtypes.dtype(x2)
|
||||
@ -523,7 +523,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.frexp, module='numpy')
|
||||
@jit
|
||||
def frexp(x: ArrayLike) -> Tuple[Array, Array]:
|
||||
def frexp(x: ArrayLike, /) -> Tuple[Array, Array]:
|
||||
_check_arraylike("frexp", x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||
@ -547,7 +547,7 @@ def frexp(x: ArrayLike) -> Tuple[Array, Array]:
|
||||
|
||||
@_wraps(np.remainder, module='numpy')
|
||||
@jit
|
||||
def remainder(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
x1, x2 = _promote_args_numeric("remainder", x1, x2)
|
||||
zero = _constant_like(x1, 0)
|
||||
trunc_mod = lax.rem(x1, x2)
|
||||
@ -560,7 +560,7 @@ mod = _wraps(np.mod, module='numpy')(remainder)
|
||||
|
||||
@_wraps(np.fmod, module='numpy')
|
||||
@jit
|
||||
def fmod(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
_check_arraylike("fmod", x1, x2)
|
||||
if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
|
||||
x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
|
||||
@ -569,7 +569,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.square, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def square(x: ArrayLike) -> Array:
|
||||
def square(x: ArrayLike, /) -> Array:
|
||||
_check_arraylike("square", x)
|
||||
x, = _promote_dtypes_numeric(x)
|
||||
return lax.integer_pow(x, 2)
|
||||
@ -577,14 +577,14 @@ def square(x: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.deg2rad, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def deg2rad(x: ArrayLike) -> Array:
|
||||
def deg2rad(x: ArrayLike, /) -> Array:
|
||||
x, = _promote_args_inexact("deg2rad", x)
|
||||
return lax.mul(x, _lax_const(x, np.pi / 180))
|
||||
|
||||
|
||||
@_wraps(np.rad2deg, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def rad2deg(x: ArrayLike) -> Array:
|
||||
def rad2deg(x: ArrayLike, /) -> Array:
|
||||
x, = _promote_args_inexact("rad2deg", x)
|
||||
return lax.mul(x, _lax_const(x, 180 / np.pi))
|
||||
|
||||
@ -595,7 +595,7 @@ radians = deg2rad
|
||||
|
||||
@_wraps(np.conjugate, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def conjugate(x: ArrayLike) -> Array:
|
||||
def conjugate(x: ArrayLike, /) -> Array:
|
||||
_check_arraylike("conjugate", x)
|
||||
return lax.conj(x) if np.iscomplexobj(x) else _asarray(x)
|
||||
conj = conjugate
|
||||
@ -603,20 +603,20 @@ conj = conjugate
|
||||
|
||||
@_wraps(np.imag)
|
||||
@partial(jit, inline=True)
|
||||
def imag(val: ArrayLike) -> Array:
|
||||
def imag(val: ArrayLike, /) -> Array:
|
||||
_check_arraylike("imag", val)
|
||||
return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
|
||||
|
||||
|
||||
@_wraps(np.real)
|
||||
@partial(jit, inline=True)
|
||||
def real(val: ArrayLike) -> Array:
|
||||
def real(val: ArrayLike, /) -> Array:
|
||||
_check_arraylike("real", val)
|
||||
return lax.real(val) if np.iscomplexobj(val) else _asarray(val)
|
||||
|
||||
@_wraps(np.modf, module='numpy', skip_params=['out'])
|
||||
@jit
|
||||
def modf(x: ArrayLike, out=None) -> Tuple[Array, Array]:
|
||||
def modf(x: ArrayLike, /, out=None) -> Tuple[Array, Array]:
|
||||
_check_arraylike("modf", x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
if out is not None:
|
||||
@ -627,7 +627,7 @@ def modf(x: ArrayLike, out=None) -> Tuple[Array, Array]:
|
||||
|
||||
@_wraps(np.isfinite, module='numpy')
|
||||
@jit
|
||||
def isfinite(x: ArrayLike) -> Array:
|
||||
def isfinite(x: ArrayLike, /) -> Array:
|
||||
_check_arraylike("isfinite", x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.floating):
|
||||
@ -640,7 +640,7 @@ def isfinite(x: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.isinf, module='numpy')
|
||||
@jit
|
||||
def isinf(x: ArrayLike) -> Array:
|
||||
def isinf(x: ArrayLike, /) -> Array:
|
||||
_check_arraylike("isinf", x)
|
||||
dtype = dtypes.dtype(x)
|
||||
if dtypes.issubdtype(dtype, np.floating):
|
||||
@ -667,25 +667,25 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array:
|
||||
|
||||
|
||||
isposinf: UnOp = _wraps(np.isposinf, skip_params=['out'])(
|
||||
lambda x, out=None: _isposneginf(np.inf, x, out)
|
||||
lambda x, /, out=None: _isposneginf(np.inf, x, out)
|
||||
)
|
||||
|
||||
|
||||
isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])(
|
||||
lambda x, out=None: _isposneginf(-np.inf, x, out)
|
||||
lambda x, /, out=None: _isposneginf(-np.inf, x, out)
|
||||
)
|
||||
|
||||
|
||||
@_wraps(np.isnan, module='numpy')
|
||||
@jit
|
||||
def isnan(x: ArrayLike) -> Array:
|
||||
def isnan(x: ArrayLike, /) -> Array:
|
||||
_check_arraylike("isnan", x)
|
||||
return lax.ne(x, x)
|
||||
|
||||
|
||||
@_wraps(np.heaviside, module='numpy')
|
||||
@jit
|
||||
def heaviside(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
_check_arraylike("heaviside", x1, x2)
|
||||
x1, x2 = _promote_dtypes_inexact(x1, x2)
|
||||
zero = _lax_const(x1, 0)
|
||||
@ -695,7 +695,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.hypot, module='numpy')
|
||||
@jit
|
||||
def hypot(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
_check_arraylike("hypot", x1, x2)
|
||||
x1, x2 = _promote_dtypes_inexact(x1, x2)
|
||||
x1 = lax.abs(x1)
|
||||
@ -706,7 +706,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.reciprocal, module='numpy')
|
||||
@partial(jit, inline=True)
|
||||
def reciprocal(x: ArrayLike) -> Array:
|
||||
def reciprocal(x: ArrayLike, /) -> Array:
|
||||
_check_arraylike("reciprocal", x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
return lax.integer_pow(x, -1)
|
||||
@ -714,7 +714,7 @@ def reciprocal(x: ArrayLike) -> Array:
|
||||
|
||||
@_wraps(np.sinc, update_doc=False)
|
||||
@jit
|
||||
def sinc(x: ArrayLike) -> Array:
|
||||
def sinc(x: ArrayLike, /) -> Array:
|
||||
_check_arraylike("sinc", x)
|
||||
x, = _promote_dtypes_inexact(x)
|
||||
eq_zero = lax.eq(x, _lax_const(x, 0))
|
||||
|
@ -1337,7 +1337,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
@parameterized.parameters(['int', 'np.int', 'jnp.int'])
|
||||
def testIntegerPower(self, ptype):
|
||||
p = {'int': 2, 'np.int': np.int32(2), 'jnp.int': jnp.int32(2)}[ptype]
|
||||
jaxpr = jax.make_jaxpr(partial(jnp.power, x2=p))(1)
|
||||
jaxpr = jax.make_jaxpr(lambda x1: jnp.power(x1, p))(1)
|
||||
eqns = jaxpr.jaxpr.eqns
|
||||
self.assertLen(eqns, 1)
|
||||
self.assertEqual(eqns[0].primitive, lax.integer_pow_p)
|
||||
|
Loading…
x
Reference in New Issue
Block a user