mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[mutable-arrays] limit implicit ref_swap dtype promotion
fixes #27683
In b7715e279, specifically this line:
b7715e279d (diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193)
we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in #27683. A repro looks like:
```python
import jax.numpy as jnp
from jax._src import core
v = core.mutable_array(jnp.array([0, 0, 0]))
v[...] += 1.0
print(v) # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32)
```
We can't easily just drop this behavior because it seems many GPU x64 tests depend on it.
So in this change we're trying to
1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype;
2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and
3. do an ordinary cast rather than a bitcast.
I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g.
```python
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0 # don't error!
```
But we can do that with special-purpose carve-outs for Python builtin numerictypes. I left one way to do it in a comment.
PiperOrigin-RevId: 745198669
This commit is contained in:
parent
e02faabfb2
commit
4d2808c115
@ -18,9 +18,12 @@ from functools import partial
|
||||
import types
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import traceback_util
|
||||
from jax._src import tree_util
|
||||
@ -40,7 +43,6 @@ from jax._src.state.types import (
|
||||
)
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
import numpy as np
|
||||
|
||||
|
||||
## General utilities
|
||||
@ -142,10 +144,25 @@ def ref_swap(
|
||||
_function_name: str = "ref_swap",
|
||||
) -> Array:
|
||||
"""Sets a `Ref`'s value and returns the original value."""
|
||||
if hasattr(ref_or_view, 'dtype'):
|
||||
value = _maybe_implicit_cast(ref_or_view.dtype, value)
|
||||
ref, transforms = get_ref_and_transforms(ref_or_view, idx, _function_name)
|
||||
flat_transforms, tree = tree_util.tree_flatten(transforms)
|
||||
return swap_p.bind(ref, value, *flat_transforms, tree=tree)
|
||||
|
||||
# TODO(slebedev,mattjj): replace with special handling of Python numeric types:
|
||||
# if (isinstance(value, (int, float, complex)) and
|
||||
# value == np.array(value, dtype).item()): return cast
|
||||
def _maybe_implicit_cast(dtype, value):
|
||||
aval = core.typeof(value)
|
||||
if (aval.weak_type and
|
||||
(dtypes.issubdtype(dtype, np.floating) and
|
||||
dtypes.issubdtype(aval.dtype, np.floating)) or
|
||||
(dtypes.issubdtype(dtype, np.integer) and
|
||||
dtypes.issubdtype(aval.dtype, np.integer))):
|
||||
return lax.convert_element_type(value, dtype)
|
||||
return value
|
||||
|
||||
|
||||
def ref_set(
|
||||
ref_or_view: AbstractRef | TransformedRef,
|
||||
@ -246,7 +263,7 @@ def _swap_abstract_eval(ref_aval: AbstractRef,
|
||||
f"Expected shape: {expected_out_shape}. "
|
||||
f"Value shape: {val_aval.shape}. "
|
||||
f"Transforms: {transforms}. ")
|
||||
if expected_out_dtype != val_aval.dtype and not val_aval.weak_type:
|
||||
if expected_out_dtype != val_aval.dtype:
|
||||
raise ValueError(
|
||||
"Invalid dtype for `swap`. "
|
||||
f"Ref dtype: {expected_out_dtype}. "
|
||||
|
@ -249,6 +249,16 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
ys = f(xs)
|
||||
self.assertAllClose(ys, xs ** 2, check_dtypes=False)
|
||||
|
||||
def test_implicit_bitcast_regression(self):
|
||||
# https://github.com/jax-ml/jax/issues/27683
|
||||
v = core.mutable_array(jnp.array([0, 0, 0]))
|
||||
with self.assertRaises(ValueError):
|
||||
v[...] += 1.0
|
||||
|
||||
def test_implicit_cast_in_swap(self):
|
||||
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
|
||||
v[...] += 1.0 # don't crash
|
||||
|
||||
|
||||
@jtu.with_config(jax_mutable_array_checks=True)
|
||||
class MutableArrayErrorsTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user