[mutable-arrays] limit implicit ref_swap dtype promotion

fixes #27683

In b7715e279, specifically this line:

b7715e279d (diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193)

we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in #27683. A repro looks like:

```python
import jax.numpy as jnp
from jax._src import core

v = core.mutable_array(jnp.array([0, 0, 0]))
v[...] += 1.0
print(v)  # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32)
```

We can't easily just drop this behavior because it seems many GPU x64 tests depend on it.

So in this change we're trying to
1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype;
2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and
3. do an ordinary cast rather than a bitcast.

I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g.

```python
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0  # don't error!
```

But we can do that with special-purpose carve-outs for Python builtin numerictypes. I left one way to do it in a comment.

PiperOrigin-RevId: 745198669
This commit is contained in:
Matthew Johnson 2025-04-08 10:24:14 -07:00 committed by jax authors
parent e02faabfb2
commit 4d2808c115
2 changed files with 29 additions and 2 deletions

View File

@ -18,9 +18,12 @@ from functools import partial
import types
from typing import Any, Union
import numpy as np
from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import traceback_util
from jax._src import tree_util
@ -40,7 +43,6 @@ from jax._src.state.types import (
)
from jax._src.typing import Array
from jax._src.util import safe_map, safe_zip
import numpy as np
## General utilities
@ -142,10 +144,25 @@ def ref_swap(
_function_name: str = "ref_swap",
) -> Array:
"""Sets a `Ref`'s value and returns the original value."""
if hasattr(ref_or_view, 'dtype'):
value = _maybe_implicit_cast(ref_or_view.dtype, value)
ref, transforms = get_ref_and_transforms(ref_or_view, idx, _function_name)
flat_transforms, tree = tree_util.tree_flatten(transforms)
return swap_p.bind(ref, value, *flat_transforms, tree=tree)
# TODO(slebedev,mattjj): replace with special handling of Python numeric types:
# if (isinstance(value, (int, float, complex)) and
# value == np.array(value, dtype).item()): return cast
def _maybe_implicit_cast(dtype, value):
aval = core.typeof(value)
if (aval.weak_type and
(dtypes.issubdtype(dtype, np.floating) and
dtypes.issubdtype(aval.dtype, np.floating)) or
(dtypes.issubdtype(dtype, np.integer) and
dtypes.issubdtype(aval.dtype, np.integer))):
return lax.convert_element_type(value, dtype)
return value
def ref_set(
ref_or_view: AbstractRef | TransformedRef,
@ -246,7 +263,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef,
f"Expected shape: {expected_out_shape}. "
f"Value shape: {val_aval.shape}. "
f"Transforms: {transforms}. ")
if expected_out_dtype != val_aval.dtype and not val_aval.weak_type:
if expected_out_dtype != val_aval.dtype:
raise ValueError(
"Invalid dtype for `swap`. "
f"Ref dtype: {expected_out_dtype}. "

View File

@ -249,6 +249,16 @@ class MutableArrayTest(jtu.JaxTestCase):
ys = f(xs)
self.assertAllClose(ys, xs ** 2, check_dtypes=False)
def test_implicit_bitcast_regression(self):
# https://github.com/jax-ml/jax/issues/27683
v = core.mutable_array(jnp.array([0, 0, 0]))
with self.assertRaises(ValueError):
v[...] += 1.0
def test_implicit_cast_in_swap(self):
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0 # don't crash
@jtu.with_config(jax_mutable_array_checks=True)
class MutableArrayErrorsTest(jtu.JaxTestCase):