mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
jnp.ndarray: raise TypeError for binary operations with builtin collections
This commit is contained in:
parent
7011de56ef
commit
39b0ff7eb6
@ -14,6 +14,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
* `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These
|
||||
classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
|
||||
* Added {class}`jax.scipy.gaussian_kde` ({jax-issue}`#11237`).
|
||||
* Binary operations between JAX arrays and built-in collections (`dict`, `list`, `set`, `tuple`)
|
||||
now raise a `TypeError` in all cases. Previously some cases (particularly equality and inequality)
|
||||
would return boolean scalars inconsistent with similar operations in NumPy ({jax-issue}`#11234`).
|
||||
|
||||
## jaxlib 0.3.15 (Unreleased)
|
||||
|
||||
|
@ -4575,20 +4575,22 @@ def _not_implemented(fun):
|
||||
|
||||
_scalar_types = (int, float, complex, np.generic)
|
||||
_accepted_binop_types = (int, float, complex, np.generic, np.ndarray, ndarray)
|
||||
_rejected_binop_types = (list, tuple, set, dict)
|
||||
|
||||
def _defer_to_unrecognized_arg(binary_op):
|
||||
def _defer_to_unrecognized_arg(opchar, binary_op, swap=False):
|
||||
# Ensure that other array types have the chance to override arithmetic.
|
||||
def deferring_binary_op(self, other):
|
||||
if hasattr(other, '__jax_array__'):
|
||||
other = other.__jax_array__()
|
||||
if not isinstance(other, _accepted_binop_types):
|
||||
args = (other, self) if swap else (self, other)
|
||||
if isinstance(other, _accepted_binop_types):
|
||||
return binary_op(*args)
|
||||
if isinstance(other, _rejected_binop_types):
|
||||
raise TypeError(f"unsupported operand type(s) for {opchar}: "
|
||||
f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
|
||||
return NotImplemented
|
||||
return binary_op(self, other)
|
||||
return deferring_binary_op
|
||||
|
||||
def _swap_args(f):
|
||||
return lambda x, y: f(y, x)
|
||||
|
||||
def _unimplemented_setitem(self, i, x):
|
||||
msg = ("'{}' object does not support item assignment. JAX arrays are "
|
||||
"immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` "
|
||||
@ -4615,44 +4617,44 @@ _operators = {
|
||||
"deepcopy": _deepcopy,
|
||||
"neg": negative,
|
||||
"pos": positive,
|
||||
"eq": _defer_to_unrecognized_arg(equal),
|
||||
"ne": _defer_to_unrecognized_arg(not_equal),
|
||||
"lt": _defer_to_unrecognized_arg(less),
|
||||
"le": _defer_to_unrecognized_arg(less_equal),
|
||||
"gt": _defer_to_unrecognized_arg(greater),
|
||||
"ge": _defer_to_unrecognized_arg(greater_equal),
|
||||
"eq": _defer_to_unrecognized_arg("==", equal),
|
||||
"ne": _defer_to_unrecognized_arg("!=", not_equal),
|
||||
"lt": _defer_to_unrecognized_arg("<", less),
|
||||
"le": _defer_to_unrecognized_arg("<=", less_equal),
|
||||
"gt": _defer_to_unrecognized_arg(">", greater),
|
||||
"ge": _defer_to_unrecognized_arg(">=", greater_equal),
|
||||
"abs": abs,
|
||||
"add": _defer_to_unrecognized_arg(add),
|
||||
"radd": _defer_to_unrecognized_arg(add),
|
||||
"sub": _defer_to_unrecognized_arg(subtract),
|
||||
"rsub": _defer_to_unrecognized_arg(_swap_args(subtract)),
|
||||
"mul": _defer_to_unrecognized_arg(multiply),
|
||||
"rmul": _defer_to_unrecognized_arg(multiply),
|
||||
"div": _defer_to_unrecognized_arg(divide),
|
||||
"rdiv": _defer_to_unrecognized_arg(_swap_args(divide)),
|
||||
"truediv": _defer_to_unrecognized_arg(true_divide),
|
||||
"rtruediv": _defer_to_unrecognized_arg(_swap_args(true_divide)),
|
||||
"floordiv": _defer_to_unrecognized_arg(floor_divide),
|
||||
"rfloordiv": _defer_to_unrecognized_arg(_swap_args(floor_divide)),
|
||||
"divmod": _defer_to_unrecognized_arg(divmod),
|
||||
"rdivmod": _defer_to_unrecognized_arg(_swap_args(divmod)),
|
||||
"mod": _defer_to_unrecognized_arg(mod),
|
||||
"rmod": _defer_to_unrecognized_arg(_swap_args(mod)),
|
||||
"pow": _defer_to_unrecognized_arg(power),
|
||||
"rpow": _defer_to_unrecognized_arg(_swap_args(power)),
|
||||
"matmul": _defer_to_unrecognized_arg(matmul),
|
||||
"rmatmul": _defer_to_unrecognized_arg(_swap_args(matmul)),
|
||||
"and": _defer_to_unrecognized_arg(bitwise_and),
|
||||
"rand": _defer_to_unrecognized_arg(bitwise_and),
|
||||
"or": _defer_to_unrecognized_arg(bitwise_or),
|
||||
"ror": _defer_to_unrecognized_arg(bitwise_or),
|
||||
"xor": _defer_to_unrecognized_arg(bitwise_xor),
|
||||
"rxor": _defer_to_unrecognized_arg(bitwise_xor),
|
||||
"add": _defer_to_unrecognized_arg("+", add),
|
||||
"radd": _defer_to_unrecognized_arg("+", add, swap=True),
|
||||
"sub": _defer_to_unrecognized_arg("-", subtract),
|
||||
"rsub": _defer_to_unrecognized_arg("-", subtract, swap=True),
|
||||
"mul": _defer_to_unrecognized_arg("*", multiply),
|
||||
"rmul": _defer_to_unrecognized_arg("*", multiply, swap=True),
|
||||
"div": _defer_to_unrecognized_arg("/", divide),
|
||||
"rdiv": _defer_to_unrecognized_arg("/", divide, swap=True),
|
||||
"truediv": _defer_to_unrecognized_arg("/", true_divide),
|
||||
"rtruediv": _defer_to_unrecognized_arg("/", true_divide, swap=True),
|
||||
"floordiv": _defer_to_unrecognized_arg("//", floor_divide),
|
||||
"rfloordiv": _defer_to_unrecognized_arg("//", floor_divide, swap=True),
|
||||
"divmod": _defer_to_unrecognized_arg("divmod", divmod),
|
||||
"rdivmod": _defer_to_unrecognized_arg("divmod", divmod, swap=True),
|
||||
"mod": _defer_to_unrecognized_arg("%", mod),
|
||||
"rmod": _defer_to_unrecognized_arg("%", mod, swap=True),
|
||||
"pow": _defer_to_unrecognized_arg("**", power),
|
||||
"rpow": _defer_to_unrecognized_arg("**", power, swap=True),
|
||||
"matmul": _defer_to_unrecognized_arg("@", matmul),
|
||||
"rmatmul": _defer_to_unrecognized_arg("@", matmul, swap=True),
|
||||
"and": _defer_to_unrecognized_arg("&", bitwise_and),
|
||||
"rand": _defer_to_unrecognized_arg("&", bitwise_and, swap=True),
|
||||
"or": _defer_to_unrecognized_arg("|", bitwise_or),
|
||||
"ror": _defer_to_unrecognized_arg("|", bitwise_or, swap=True),
|
||||
"xor": _defer_to_unrecognized_arg("^", bitwise_xor),
|
||||
"rxor": _defer_to_unrecognized_arg("^", bitwise_xor, swap=True),
|
||||
"invert": bitwise_not,
|
||||
"lshift": _defer_to_unrecognized_arg(left_shift),
|
||||
"rshift": _defer_to_unrecognized_arg(right_shift),
|
||||
"rlshift": _defer_to_unrecognized_arg(_swap_args(left_shift)),
|
||||
"rrshift": _defer_to_unrecognized_arg(_swap_args(right_shift)),
|
||||
"lshift": _defer_to_unrecognized_arg("<<", left_shift),
|
||||
"rshift": _defer_to_unrecognized_arg(">>", right_shift),
|
||||
"rlshift": _defer_to_unrecognized_arg("<<", left_shift, swap=True),
|
||||
"rrshift": _defer_to_unrecognized_arg(">>", right_shift, swap=True),
|
||||
"round": _operator_round,
|
||||
}
|
||||
|
||||
|
@ -941,8 +941,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
identity=True
|
||||
transforms=()
|
||||
] b
|
||||
_:f32[] = mul c 2.00
|
||||
d:f32[] = mul 1.00 2.00
|
||||
_:f32[] = mul 2.00 c
|
||||
d:f32[] = mul 2.00 1.00
|
||||
e:f32[] = outside_call[
|
||||
arg_treedef={treedef}
|
||||
callback=...
|
||||
@ -960,8 +960,8 @@ class HostCallbackTapTest(jtu.JaxTestCase):
|
||||
callback=...
|
||||
identity=True
|
||||
] b
|
||||
_:f32[] = mul c 2.00
|
||||
d:f32[] = mul 1.00 2.00
|
||||
_:f32[] = mul 2.00 c
|
||||
d:f32[] = mul 2.00 1.00
|
||||
e:f32[] = mul d 3.00
|
||||
in (e,) }}""", jaxpr)
|
||||
assertMultiLineStrippedEqual(self, "", testing_stream.output)
|
||||
|
@ -641,6 +641,36 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
with jtu.strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CompileAndCheck( fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype}
|
||||
for rec in JAX_OPERATOR_OVERLOADS if rec.nargs == 2
|
||||
for othertype in [dict, list, tuple, set]))
|
||||
def testOperatorOverloadErrors(self, name, othertype):
|
||||
# Test that binary operators with builtin collections raise a TypeError
|
||||
# and report the types in the correct order.
|
||||
data = [(1, 2), (2, 3)]
|
||||
arr = jnp.array(data)
|
||||
other = othertype(data)
|
||||
|
||||
msg = f"unsupported operand type.* 'DeviceArray' and '{othertype.__name__}'"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
getattr(arr, name)(other)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": f"{rec.test_name}_{othertype}", "name": rec.name, "othertype": othertype}
|
||||
for rec in JAX_RIGHT_OPERATOR_OVERLOADS if rec.nargs == 2
|
||||
for othertype in [dict, list, tuple, set]))
|
||||
def testRightOperatorOverloadErrors(self, name, othertype):
|
||||
# Test that binary operators with builtin collections raise a TypeError
|
||||
# and report the types in the correct order.
|
||||
data = [(1, 2), (2, 3)]
|
||||
arr = jnp.array(data)
|
||||
other = othertype(data)
|
||||
|
||||
msg = f"unsupported operand type.* '{othertype.__name__}' and 'DeviceArray'"
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
getattr(arr, name)(other)
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": rec.test_name + f"_{dtype}",
|
||||
"rng_factory": rec.rng_factory,
|
||||
|
Loading…
x
Reference in New Issue
Block a user