Mark arguments to ufuncs as positional-only.

PiperOrigin-RevId: 493311821
This commit is contained in:
Peter Hawkins 2022-12-06 08:23:40 -08:00 committed by jax authors
parent a7900166d1
commit 33a1b8866a
3 changed files with 49 additions and 52 deletions

View File

@ -6,13 +6,6 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
Remember to align the itemized text with the first line of an item within a list.
-->
## Next jax
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.4.0...main).
* Changes
* The jax2tf.call_tf function now uses for TF lowering the first TF
device of the same platform as used by the embedding JAX computation.
Before, it was using the 0th device for the JAX-default backend.
## jax 0.4.0
* Changes
* Support for Python 3.7 has been dropped, in accordance with JAX's
@ -36,7 +29,11 @@ Remember to align the itemized text with the first line of an item within a list
longer read or written after the JAX configuration options are initially
populated from the ABSL flags. This change improves performance of reading
`jax.config` options, which are used pervasively in JAX.
* The jax2tf.call_tf function now uses for TF lowering the first TF
device of the same platform as used by the embedding JAX computation.
Before, it was using the 0th device for the JAX-default backend.
* A number of `jax.numpy` functions now have their arguments marked as
positional-only, matching NumPy.
## jaxlib 0.4.0
* Changes

View File

@ -59,9 +59,9 @@ def _one_to_one_unop(
numpy_fn: Callable[..., Any], lax_fn: UnOp,
promote_to_inexact: bool = False, lax_doc: bool = False) -> UnOp:
if promote_to_inexact:
fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
fn = lambda x, /: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
else:
fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x))
fn = lambda x, /: lax_fn(*_promote_args(numpy_fn.__name__, x))
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
fn = jit(fn, inline=True)
if lax_doc:
@ -76,11 +76,11 @@ def _one_to_one_binop(
promote_to_inexact: bool = False, lax_doc: bool = False,
promote_to_numeric: bool = False) -> BinOp:
if promote_to_inexact:
fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
fn = lambda x1, x2, /: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2))
elif promote_to_numeric:
fn = lambda x1, x2: lax_fn(*_promote_args_numeric(numpy_fn.__name__, x1, x2))
fn = lambda x1, x2, /: lax_fn(*_promote_args_numeric(numpy_fn.__name__, x1, x2))
else:
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
fn = lambda x1, x2, /: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
fn = jit(fn, inline=True)
if lax_doc:
@ -93,7 +93,7 @@ def _one_to_one_binop(
def _maybe_bool_binop(
numpy_fn: Callable[..., Any], lax_fn: BinOp, bool_lax_fn: BinOp,
lax_doc: bool = False) -> BinOp:
def fn(x1, x2):
def fn(x1, x2, /):
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
fn.__qualname__ = f"jax.numpy.{numpy_fn.__name__}"
@ -106,7 +106,7 @@ def _maybe_bool_binop(
def _comparison_op(numpy_fn: Callable[..., Any], lax_fn: BinOp) -> BinOp:
def fn(x1, x2):
def fn(x1, x2, /):
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
# Comparison on complex types are defined as a lexicographic ordering on
# the (real, imag) pair.
@ -191,7 +191,7 @@ logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)
@_wraps(np.arccosh, module='numpy')
@jit
def arccosh(x: ArrayLike) -> Array:
def arccosh(x: ArrayLike, /) -> Array:
# Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
# convention than np.arccosh.
out = lax.acosh(*_promote_args_inexact("arccosh", x))
@ -202,7 +202,7 @@ def arccosh(x: ArrayLike) -> Array:
@_wraps(np.right_shift, module='numpy')
@partial(jit, inline=True)
def right_shift(x1: ArrayLike, x2: ArrayLike) -> Array:
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_numeric(np.right_shift.__name__, x1, x2)
lax_fn = lax.shift_right_logical if \
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic
@ -211,7 +211,7 @@ def right_shift(x1: ArrayLike, x2: ArrayLike) -> Array:
@_wraps(np.absolute, module='numpy')
@partial(jit, inline=True)
def absolute(x: ArrayLike) -> Array:
def absolute(x: ArrayLike, /) -> Array:
_check_arraylike('absolute', x)
dt = dtypes.dtype(x)
return _asarray(x) if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x)
@ -220,7 +220,7 @@ abs = _wraps(np.abs, module='numpy')(absolute)
@_wraps(np.rint, module='numpy')
@jit
def rint(x: ArrayLike) -> Array:
def rint(x: ArrayLike, /) -> Array:
_check_arraylike('rint', x)
dtype = dtypes.dtype(x)
if dtype == bool or dtypes.issubdtype(dtype, np.integer):
@ -232,7 +232,7 @@ def rint(x: ArrayLike) -> Array:
@_wraps(np.sign, module='numpy')
@jit
def sign(x: ArrayLike) -> Array:
def sign(x: ArrayLike, /) -> Array:
_check_arraylike('sign', x)
dtype = dtypes.dtype(x)
if dtypes.issubdtype(dtype, np.complexfloating):
@ -244,7 +244,7 @@ def sign(x: ArrayLike) -> Array:
@_wraps(np.copysign, module='numpy')
@jit
def copysign(x1: ArrayLike, x2: ArrayLike) -> Array:
def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_inexact("copysign", x1, x2)
if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
raise TypeError("copysign does not support complex-valued inputs")
@ -253,7 +253,7 @@ def copysign(x1: ArrayLike, x2: ArrayLike) -> Array:
@_wraps(np.true_divide, module='numpy')
@partial(jit, inline=True)
def true_divide(x1: ArrayLike, x2: ArrayLike) -> Array:
def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_inexact("true_divide", x1, x2)
return lax.div(x1, x2)
@ -262,7 +262,7 @@ divide = true_divide
@_wraps(np.floor_divide, module='numpy')
@jit
def floor_divide(x1: ArrayLike, x2: ArrayLike) -> Array:
def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_numeric("floor_divide", x1, x2)
dtype = dtypes.dtype(x1)
if dtypes.issubdtype(dtype, np.integer):
@ -287,7 +287,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike) -> Array:
@_wraps(np.divmod, module='numpy')
@jit
def divmod(x1: ArrayLike, x2: ArrayLike) -> Tuple[Array, Array]:
def divmod(x1: ArrayLike, x2: ArrayLike, /) -> Tuple[Array, Array]:
x1, x2 = _promote_args_numeric("divmod", x1, x2)
if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
return floor_divide(x1, x2), remainder(x1, x2)
@ -330,7 +330,7 @@ def _power(x1: ArrayLike, x2: ArrayLike) -> Array:
@_wraps(np.power, module='numpy')
def power(x1: ArrayLike, x2: ArrayLike) -> Array:
def power(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("power", x1, x2)
# Special case for concrete integer scalars: use binary exponentiation.
# Using lax.pow may be imprecise for floating-point values; the goal of this
@ -350,7 +350,7 @@ def power(x1: ArrayLike, x2: ArrayLike) -> Array:
@custom_jvp
@_wraps(np.logaddexp, module='numpy')
@jit
def logaddexp(x1: ArrayLike, x2: ArrayLike) -> Array:
def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
amax = lax.max(x1, x2)
if dtypes.issubdtype(x1.dtype, np.floating):
@ -388,7 +388,7 @@ def _logaddexp_jvp(primals, tangents):
@custom_jvp
@_wraps(np.logaddexp2, module='numpy')
@jit
def logaddexp2(x1: ArrayLike, x2: ArrayLike) -> Array:
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_inexact("logaddexp2", x1, x2)
amax = lax.max(x1, x2)
if dtypes.issubdtype(x1.dtype, np.floating):
@ -416,28 +416,28 @@ def _logaddexp2_jvp(primals, tangents):
@_wraps(np.log2, module='numpy')
@partial(jit, inline=True)
def log2(x: ArrayLike) -> Array:
def log2(x: ArrayLike, /) -> Array:
x, = _promote_args_inexact("log2", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 2)))
@_wraps(np.log10, module='numpy')
@partial(jit, inline=True)
def log10(x: ArrayLike) -> Array:
def log10(x: ArrayLike, /) -> Array:
x, = _promote_args_inexact("log10", x)
return lax.div(lax.log(x), lax.log(_constant_like(x, 10)))
@_wraps(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x: ArrayLike) -> Array:
def exp2(x: ArrayLike, /) -> Array:
x, = _promote_args_inexact("exp2", x)
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
@_wraps(np.signbit, module='numpy')
@jit
def signbit(x: ArrayLike) -> Array:
def signbit(x: ArrayLike, /) -> Array:
x, = _promote_args("signbit", x)
dtype = dtypes.dtype(x)
if dtypes.issubdtype(dtype, np.integer):
@ -474,7 +474,7 @@ def _normalize_float(x):
@_wraps(np.ldexp, module='numpy')
@jit
def ldexp(x1: ArrayLike, x2: ArrayLike) -> Array:
def ldexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("ldexp", x1, x2)
x1_dtype = dtypes.dtype(x1)
x2_dtype = dtypes.dtype(x2)
@ -523,7 +523,7 @@ def ldexp(x1: ArrayLike, x2: ArrayLike) -> Array:
@_wraps(np.frexp, module='numpy')
@jit
def frexp(x: ArrayLike) -> Tuple[Array, Array]:
def frexp(x: ArrayLike, /) -> Tuple[Array, Array]:
_check_arraylike("frexp", x)
x, = _promote_dtypes_inexact(x)
if dtypes.issubdtype(x.dtype, np.complexfloating):
@ -547,7 +547,7 @@ def frexp(x: ArrayLike) -> Tuple[Array, Array]:
@_wraps(np.remainder, module='numpy')
@jit
def remainder(x1: ArrayLike, x2: ArrayLike) -> Array:
def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array:
x1, x2 = _promote_args_numeric("remainder", x1, x2)
zero = _constant_like(x1, 0)
trunc_mod = lax.rem(x1, x2)
@ -560,7 +560,7 @@ mod = _wraps(np.mod, module='numpy')(remainder)
@_wraps(np.fmod, module='numpy')
@jit
def fmod(x1: ArrayLike, x2: ArrayLike) -> Array:
def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("fmod", x1, x2)
if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
@ -569,7 +569,7 @@ def fmod(x1: ArrayLike, x2: ArrayLike) -> Array:
@_wraps(np.square, module='numpy')
@partial(jit, inline=True)
def square(x: ArrayLike) -> Array:
def square(x: ArrayLike, /) -> Array:
_check_arraylike("square", x)
x, = _promote_dtypes_numeric(x)
return lax.integer_pow(x, 2)
@ -577,14 +577,14 @@ def square(x: ArrayLike) -> Array:
@_wraps(np.deg2rad, module='numpy')
@partial(jit, inline=True)
def deg2rad(x: ArrayLike) -> Array:
def deg2rad(x: ArrayLike, /) -> Array:
x, = _promote_args_inexact("deg2rad", x)
return lax.mul(x, _lax_const(x, np.pi / 180))
@_wraps(np.rad2deg, module='numpy')
@partial(jit, inline=True)
def rad2deg(x: ArrayLike) -> Array:
def rad2deg(x: ArrayLike, /) -> Array:
x, = _promote_args_inexact("rad2deg", x)
return lax.mul(x, _lax_const(x, 180 / np.pi))
@ -595,7 +595,7 @@ radians = deg2rad
@_wraps(np.conjugate, module='numpy')
@partial(jit, inline=True)
def conjugate(x: ArrayLike) -> Array:
def conjugate(x: ArrayLike, /) -> Array:
_check_arraylike("conjugate", x)
return lax.conj(x) if np.iscomplexobj(x) else _asarray(x)
conj = conjugate
@ -603,20 +603,20 @@ conj = conjugate
@_wraps(np.imag)
@partial(jit, inline=True)
def imag(val: ArrayLike) -> Array:
def imag(val: ArrayLike, /) -> Array:
_check_arraylike("imag", val)
return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0)
@_wraps(np.real)
@partial(jit, inline=True)
def real(val: ArrayLike) -> Array:
def real(val: ArrayLike, /) -> Array:
_check_arraylike("real", val)
return lax.real(val) if np.iscomplexobj(val) else _asarray(val)
@_wraps(np.modf, module='numpy', skip_params=['out'])
@jit
def modf(x: ArrayLike, out=None) -> Tuple[Array, Array]:
def modf(x: ArrayLike, /, out=None) -> Tuple[Array, Array]:
_check_arraylike("modf", x)
x, = _promote_dtypes_inexact(x)
if out is not None:
@ -627,7 +627,7 @@ def modf(x: ArrayLike, out=None) -> Tuple[Array, Array]:
@_wraps(np.isfinite, module='numpy')
@jit
def isfinite(x: ArrayLike) -> Array:
def isfinite(x: ArrayLike, /) -> Array:
_check_arraylike("isfinite", x)
dtype = dtypes.dtype(x)
if dtypes.issubdtype(dtype, np.floating):
@ -640,7 +640,7 @@ def isfinite(x: ArrayLike) -> Array:
@_wraps(np.isinf, module='numpy')
@jit
def isinf(x: ArrayLike) -> Array:
def isinf(x: ArrayLike, /) -> Array:
_check_arraylike("isinf", x)
dtype = dtypes.dtype(x)
if dtypes.issubdtype(dtype, np.floating):
@ -667,25 +667,25 @@ def _isposneginf(infinity: float, x: ArrayLike, out) -> Array:
isposinf: UnOp = _wraps(np.isposinf, skip_params=['out'])(
lambda x, out=None: _isposneginf(np.inf, x, out)
lambda x, /, out=None: _isposneginf(np.inf, x, out)
)
isneginf: UnOp = _wraps(np.isneginf, skip_params=['out'])(
lambda x, out=None: _isposneginf(-np.inf, x, out)
lambda x, /, out=None: _isposneginf(-np.inf, x, out)
)
@_wraps(np.isnan, module='numpy')
@jit
def isnan(x: ArrayLike) -> Array:
def isnan(x: ArrayLike, /) -> Array:
_check_arraylike("isnan", x)
return lax.ne(x, x)
@_wraps(np.heaviside, module='numpy')
@jit
def heaviside(x1: ArrayLike, x2: ArrayLike) -> Array:
def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("heaviside", x1, x2)
x1, x2 = _promote_dtypes_inexact(x1, x2)
zero = _lax_const(x1, 0)
@ -695,7 +695,7 @@ def heaviside(x1: ArrayLike, x2: ArrayLike) -> Array:
@_wraps(np.hypot, module='numpy')
@jit
def hypot(x1: ArrayLike, x2: ArrayLike) -> Array:
def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array:
_check_arraylike("hypot", x1, x2)
x1, x2 = _promote_dtypes_inexact(x1, x2)
x1 = lax.abs(x1)
@ -706,7 +706,7 @@ def hypot(x1: ArrayLike, x2: ArrayLike) -> Array:
@_wraps(np.reciprocal, module='numpy')
@partial(jit, inline=True)
def reciprocal(x: ArrayLike) -> Array:
def reciprocal(x: ArrayLike, /) -> Array:
_check_arraylike("reciprocal", x)
x, = _promote_dtypes_inexact(x)
return lax.integer_pow(x, -1)
@ -714,7 +714,7 @@ def reciprocal(x: ArrayLike) -> Array:
@_wraps(np.sinc, update_doc=False)
@jit
def sinc(x: ArrayLike) -> Array:
def sinc(x: ArrayLike, /) -> Array:
_check_arraylike("sinc", x)
x, = _promote_dtypes_inexact(x)
eq_zero = lax.eq(x, _lax_const(x, 0))

View File

@ -1337,7 +1337,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
@parameterized.parameters(['int', 'np.int', 'jnp.int'])
def testIntegerPower(self, ptype):
p = {'int': 2, 'np.int': np.int32(2), 'jnp.int': jnp.int32(2)}[ptype]
jaxpr = jax.make_jaxpr(partial(jnp.power, x2=p))(1)
jaxpr = jax.make_jaxpr(lambda x1: jnp.power(x1, p))(1)
eqns = jaxpr.jaxpr.eqns
self.assertLen(eqns, 1)
self.assertEqual(eqns[0].primitive, lax.integer_pow_p)