jnp.ndarray: raise TypeError for binary operations with builtin collections

This commit is contained in:
Jake VanderPlas 2022-06-24 10:03:27 -07:00
parent 7011de56ef
commit 39b0ff7eb6
4 changed files with 82 additions and 47 deletions

View File

@ -14,6 +14,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
* Added {class}`jax.scipy.gaussian_kde` ({jax-issue}`#11237`).
* Binary operations between JAX arrays and built-in collections (`dict`, `list`, `set`, `tuple`)
now raise a `TypeError` in all cases. Previously some cases (particularly equality and inequality)
would return boolean scalars inconsistent with similar operations in NumPy ({jax-issue}`#11234`).
## jaxlib 0.3.15 (Unreleased)

View File

@ -4575,20 +4575,22 @@ def _not_implemented(fun):
_scalar_types = (int, float, complex, np.generic)
_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, ndarray)
_rejected_binop_types = (list, tuple, set, dict)
def _defer_to_unrecognized_arg(binary_op):
def _defer_to_unrecognized_arg(opchar, binary_op, swap=False):
# Ensure that other array types have the chance to override arithmetic.
def deferring_binary_op(self, other):
if hasattr(other, '__jax_array__'):
other = other.__jax_array__()
if not isinstance(other, _accepted_binop_types):
args = (other, self) if swap else (self, other)
if isinstance(other, _accepted_binop_types):
return binary_op(*args)
if isinstance(other, _rejected_binop_types):
raise TypeError(f"unsupported operand type(s) for {opchar}: "
f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
return NotImplemented
return binary_op(self, other)
return deferring_binary_op
def _swap_args(f):
return lambda x, y: f(y, x)
def _unimplemented_setitem(self, i, x):
msg = ("'{}' object does not support item assignment. JAX arrays are "
"immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` "
@ -4615,44 +4617,44 @@ _operators = {
"deepcopy": _deepcopy,
"neg": negative,
"pos": positive,
"eq": _defer_to_unrecognized_arg(equal),
"ne": _defer_to_unrecognized_arg(not_equal),
"lt": _defer_to_unrecognized_arg(less),
"le": _defer_to_unrecognized_arg(less_equal),
"gt": _defer_to_unrecognized_arg(greater),
"ge": _defer_to_unrecognized_arg(greater_equal),
"eq": _defer_to_unrecognized_arg("==", equal),
"ne": _defer_to_unrecognized_arg("!=", not_equal),
"lt": _defer_to_unrecognized_arg("<", less),
"le": _defer_to_unrecognized_arg("<=", less_equal),
"gt": _defer_to_unrecognized_arg(">", greater),
"ge": _defer_to_unrecognized_arg(">=", greater_equal),
"abs": abs,
"add": _defer_to_unrecognized_arg(add),
"radd": _defer_to_unrecognized_arg(add),
"sub": _defer_to_unrecognized_arg(subtract),
"rsub": _defer_to_unrecognized_arg(_swap_args(subtract)),
"mul": _defer_to_unrecognized_arg(multiply),
"rmul": _defer_to_unrecognized_arg(multiply),
"div": _defer_to_unrecognized_arg(divide),
"rdiv": _defer_to_unrecognized_arg(_swap_args(divide)),
"truediv": _defer_to_unrecognized_arg(true_divide),
"rtruediv": _defer_to_unrecognized_arg(_swap_args(true_divide)),
"floordiv": _defer_to_unrecognized_arg(floor_divide),
"rfloordiv": _defer_to_unrecognized_arg(_swap_args(floor_divide)),
"divmod": _defer_to_unrecognized_arg(divmod),
"rdivmod": _defer_to_unrecognized_arg(_swap_args(divmod)),
"mod": _defer_to_unrecognized_arg(mod),
"rmod": _defer_to_unrecognized_arg(_swap_args(mod)),
"pow": _defer_to_unrecognized_arg(power),
"rpow": _defer_to_unrecognized_arg(_swap_args(power)),
"matmul": _defer_to_unrecognized_arg(matmul),
"rmatmul": _defer_to_unrecognized_arg(_swap_args(matmul)),
"and": _defer_to_unrecognized_arg(bitwise_and),
"rand": _defer_to_unrecognized_arg(bitwise_and),
"or": _defer_to_unrecognized_arg(bitwise_or),
"ror": _defer_to_unrecognized_arg(bitwise_or),
"xor": _defer_to_unrecognized_arg(bitwise_xor),
"rxor": _defer_to_unrecognized_arg(bitwise_xor),
"add": _defer_to_unrecognized_arg("+", add),
"radd": _defer_to_unrecognized_arg("+", add, swap=True),
"sub": _defer_to_unrecognized_arg("-", subtract),
"rsub": _defer_to_unrecognized_arg("-", subtract, swap=True),
"mul": _defer_to_unrecognized_arg("*", multiply),
"rmul": _defer_to_unrecognized_arg("*", multiply, swap=True),
"div": _defer_to_unrecognized_arg("/", divide),
"rdiv": _defer_to_unrecognized_arg("/", divide, swap=True),
"truediv": _defer_to_unrecognized_arg("/", true_divide),
"rtruediv": _defer_to_unrecognized_arg("/", true_divide, swap=True),
"floordiv": _defer_to_unrecognized_arg("//", floor_divide),
"rfloordiv": _defer_to_unrecognized_arg("//", floor_divide, swap=True),
"divmod": _defer_to_unrecognized_arg("divmod", divmod),
"rdivmod": _defer_to_unrecognized_arg("divmod", divmod, swap=True),
"mod": _defer_to_unrecognized_arg("%", mod),
"rmod": _defer_to_unrecognized_arg("%", mod, swap=True),
"pow": _defer_to_unrecognized_arg("**", power),
"rpow": _defer_to_unrecognized_arg("**", power, swap=True),
"matmul": _defer_to_unrecognized_arg("@", matmul),
"rmatmul": _defer_to_unrecognized_arg("@", matmul, swap=True),
"and": _defer_to_unrecognized_arg("&", bitwise_and),
"rand": _defer_to_unrecognized_arg("&", bitwise_and, swap=True),
"or": _defer_to_unrecognized_arg("|", bitwise_or),
"ror": _defer_to_unrecognized_arg("|", bitwise_or, swap=True),
"xor": _defer_to_unrecognized_arg("^", bitwise_xor),
"rxor": _defer_to_unrecognized_arg("^", bitwise_xor, swap=True),
"invert": bitwise_not,
"lshift": _defer_to_unrecognized_arg(left_shift),
"rshift": _defer_to_unrecognized_arg(right_shift),
"rlshift": _defer_to_unrecognized_arg(_swap_args(left_shift)),
"rrshift": _defer_to_unrecognized_arg(_swap_args(right_shift)),
"lshift": _defer_to_unrecognized_arg("<<", left_shift),
"rshift": _defer_to_unrecognized_arg(">>", right_shift),
"rlshift": _defer_to_unrecognized_arg("<<", left_shift, swap=True),
"rrshift": _defer_to_unrecognized_arg(">>", right_shift, swap=True),
"round": _operator_round,
}

View File

@ -941,8 +941,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
identity=True
transforms=()
] b
_:f32[] = mul c 2.00
d:f32[] = mul 1.00 2.00
_:f32[] = mul 2.00 c
d:f32[] = mul 2.00 1.00
e:f32[] = outside_call[
arg_treedef={treedef}
callback=...
@ -960,8 +960,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
callback=...
identity=True
] b
_:f32[] = mul c 2.00
d:f32[] = mul 1.00 2.00
_:f32[] = mul 2.00 c
d:f32[] = mul 2.00 1.00
e:f32[] = mul d 3.00
in (e,) }}""", jaxpr)
assertMultiLineStrippedEqual(self, "", testing_stream.output)

View File

@ -641,6 +641,36 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CompileAndCheck( fun, args_maker, atol=tol, rtol=tol)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype}
for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2
for othertype in [dict, list, tuple, set]))
def testOperatorOverloadErrors(self, name, othertype):
# Test that binary operators with builtin collections raise a TypeError
# and report the types in the correct order.
data = [(1, 2), (2, 3)]
arr = jnp.array(data)
other = othertype(data)
msg = f"unsupported operand type.* 'DeviceArray' and '{othertype.__name__}'"
with self.assertRaisesRegex(TypeError, msg):
getattr(arr, name)(other)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype}
for rec in JAX_RIGHT_OPERATOR_OVERLOADS if rec.nargs == 2
for othertype in [dict, list, tuple, set]))
def testRightOperatorOverloadErrors(self, name, othertype):
# Test that binary operators with builtin collections raise a TypeError
# and report the types in the correct order.
data = [(1, 2), (2, 3)]
arr = jnp.array(data)
other = othertype(data)
msg = f"unsupported operand type.* '{othertype.__name__}' and 'DeviceArray'"
with self.assertRaisesRegex(TypeError, msg):
getattr(arr, name)(other)
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": rec.test_name + f"_{dtype}",
"rng_factory": rec.rng_factory,