rocm_jax/tests/mutable_array_test.py
Matthew Johnson 4d2808c115 [mutable-arrays] limit implicit ref_swap dtype promotion
fixes #27683

In b7715e279, specifically this line:

b7715e279d (diff-8a1ad6e3b750565d66d30dbf4c9df0825bf5e87c4721e3352f44efbfb8b4a29cR193)

we started ignoring the value dtype completely when it was weakly typed. But that could lead to surprising implicit bitcasts like in #27683. A repro looks like:

```python
import jax.numpy as jnp
from jax._src import core

v = core.mutable_array(jnp.array([0, 0, 0]))
v[...] += 1.0
print(v)  # MutableArray([1065353216, 1065353216, 1065353216], dtype=int32)
```

We can't easily just drop this behavior because it seems many GPU x64 tests depend on it.

So in this change we're trying to
1. do the casting outside the bind, so that in jaxpr typechecking we can assert the value to assign has to match the ref dtype;
2. make that casting more restrictive, supporting only casts on weak-typed values between different precisions of floats or ints; and
3. do an ordinary cast rather than a bitcast.

I left a TODO to change this behavior, since it seems a bit ad-hoc. But we may not want to remove all implicit casting; for example, it's probably reasonable to support implicit casting of Python builtin numeric types when we don't lose any precision, e.g.

```python
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0  # don't error!
```

But we can do that with special-purpose carve-outs for Python builtin numerictypes. I left one way to do it in a comment.

PiperOrigin-RevId: 745198669
2025-04-08 10:25:10 -07:00

414 lines
12 KiB
Python

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import jax
from jax._src import core
from jax._src import config
from jax._src import test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P, AxisType
import jax.numpy as jnp
from jax._src.state.types import (RefEffect)
config.parse_flags_with_absl()
class MutableArrayTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
def test_basic(self, jit):
def f(x_mut):
x_mut[...] += 1.
x_mut[0] += 1
x_mut[1] += 5
if jit:
f = jax.jit(f)
x_mut = core.mutable_array(jnp.zeros(3))
f(x_mut)
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
check_dtypes=False)
jaxpr = jax.make_jaxpr(f)(x_mut)
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))
@parameterized.parameters([True, False])
def test_multiple_inputs_and_outputs(self, jit):
def f(x_mut, y, z_mut, w):
x_mut[...] += 1
z_mut[...] += 1
return x_mut[...] + y + z_mut[...] + w, y + w
if jit:
f = jax.jit(f)
x_mut = core.mutable_array(jnp.zeros((1, 3)))
y = jnp.ones((2, 3))
z_mut = core.mutable_array(jnp.zeros((2, 3)))
w = jnp.ones((2, 1))
out1, out2 = f(x_mut, y, z_mut, w)
self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False)
self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False)
self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
self.assertAllClose(out2, y + w, check_dtypes=False)
@parameterized.parameters([True, False])
def test_closed_over_basic(self, jit):
x_mut = core.mutable_array(jnp.zeros(3))
def f():
x_mut[...] += 1.
x_mut[0] += 1
x_mut[1] += 5
if jit:
f = jax.jit(f)
f()
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
check_dtypes=False)
jaxpr = jax.make_jaxpr(f)()
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))
@parameterized.parameters([True, False])
def test_closed_over_nested(self, jit):
x_mut = core.mutable_array(jnp.zeros(3))
@jax.jit
def f(y_mut, z):
x_mut[...] += 1.
x_mut[0] += 1
x_mut[1] += 5
y_mut[2] += 7
return z + 9
if jit:
f = jax.jit(f)
y_mut = core.mutable_array(np.zeros(3))
w = f(y_mut, 1)
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
check_dtypes=False)
self.assertAllClose(y_mut[...], jnp.array([0., 0., 7.]),
check_dtypes=False)
self.assertAllClose(w, 10, check_dtypes=False)
@parameterized.parameters([True, False])
def test_len_mutable_array(self, jit):
x_mut = core.mutable_array(jnp.zeros(3))
def f():
return jnp.int32(len(x_mut))
if jit:
f = jax.jit(f)
self.assertEqual(f(), 3)
@parameterized.parameters([True, False])
def test_internal_mutarray_basic(self, jit):
def f():
x_mut = core.mutable_array(jnp.zeros(3))
x_mut[0] += 1
x_mut[0] += 1
x_mut[2] += 1
return x_mut[...]
if jit:
f = jax.jit(f)
out = f()
self.assertAllClose(out, jnp.array([2., 0., 1.]), check_dtypes=False)
@parameterized.parameters([True, False])
def test_scan_internal_mut_array(self, jit):
def body_fun(_, x):
x_mut = core.mutable_array(x)
x_mut[...] += 2
return ((), x_mut[...])
doit = lambda: jax.lax.scan(body_fun, (), np.arange(5))
if jit:
doit = jax.jit(doit)
_, xs = doit()
self.assertAllClose(xs, (np.arange(5) + 2), check_dtypes=False)
@parameterized.parameters([True, False])
def test_scan_closed_over_mut_array(self, jit):
x_mut = core.mutable_array(0)
def body_fun(_, x):
x_mut[...] += 2
return ((), x_mut[...])
doit = lambda: jax.lax.scan(body_fun, (), np.arange(5))
if jit:
doit = jax.jit(doit)
_, xs = doit()
self.assertAllClose(x_mut[...], 10)
self.assertAllClose(xs, np.arange(5) * 2 + 2, check_dtypes=False)
@parameterized.parameters([True, False])
def test_scan_scanned_mut_array(self, jit):
def body_fun(_, index_x):
(index, x) = index_x
x[...] += index
return ((), x[...])
x_mut = core.mutable_array(np.arange(5))
doit = lambda: jax.lax.scan(body_fun, (), (np.arange(5), x_mut))
if jit:
doit = jax.jit(doit)
_, xs = doit()
self.assertAllClose(xs, (np.arange(5) * 2), check_dtypes=False)
def test_double_jit_mutable_array(self):
@jax.jit
@jax.jit
def f():
x_ref = core.mutable_array(jnp.zeros(8))
return x_ref[...]
x = f()
self.assertArraysEqual(x, jnp.zeros(8))
def test_grad_mutable_array(self):
@jax.jit
def f(x):
x_ = core.mutable_array(x)
x_[()] = x_[()] + x_[()]
y = core.freeze(x_)
return y
ans = jax.grad(f)(1.)
expected = 2.0
self.assertAllClose(ans, expected, check_dtypes=False)
def test_defensive_copy(self):
x = jnp.arange(3.)
_ = jax.jit(lambda x_ref: x_ref[...])(core.mutable_array(x))
x + 1 # don't crash
def test_sharding_persists(self):
mesh = jtu.create_mesh((1,), ('i',))
x = jax.device_put(jnp.arange(2), NamedSharding(mesh, P('i')))
s = x.sharding
a = core.mutable_array(x)
self.assertEqual(s, a.sharding)
self.assertEqual(s, a[...].sharding)
f = jax.jit(lambda: a[...])
y = f()
self.assertEqual(s, a.sharding)
self.assertEqual(s, y.sharding)
def test_explicit_sharding_after_indexing(self):
# https://github.com/jax-ml/jax/issues/26936
mesh = jtu.create_mesh((1, 1), ('x', 'y'),
axis_types=(AxisType.Explicit,) * 2)
sharding = NamedSharding(mesh, P('x', 'y'))
@jax.jit
def f(x_ref):
self.assertEqual(core.typeof(x_ref).sharding.spec,
core.typeof(x_ref[...]).sharding.spec)
y = x_ref[...] + 1
return y
with jax.sharding.use_mesh(mesh):
x = jnp.zeros((4, 4), jnp.int32, device=sharding)
x_ref = core.mutable_array(x)
y = f(x_ref)
def test_vmap_basic(self):
@jax.vmap
def f(x):
x_ref = core.mutable_array(x)
x_ref[...] = x_ref[...] * x_ref[...]
return x_ref[...]
xs = jnp.arange(4.)
ys = f(xs)
self.assertAllClose(ys, xs ** 2, check_dtypes=False)
def test_implicit_bitcast_regression(self):
# https://github.com/jax-ml/jax/issues/27683
v = core.mutable_array(jnp.array([0, 0, 0]))
with self.assertRaises(ValueError):
v[...] += 1.0
def test_implicit_cast_in_swap(self):
v = core.mutable_array(jnp.array(0, dtype='bfloat16'))
v[...] += 1.0 # don't crash
@jtu.with_config(jax_mutable_array_checks=True)
class MutableArrayErrorsTest(jtu.JaxTestCase):
def test_return_from_jit(self):
with self.assertRaisesRegex(
ValueError,
r"traced for jit returned a mutable array reference.*\n\n"
r".*was created on line"):
jax.jit(core.mutable_array)(jnp.arange(3))
def test_return_from_jit_arg(self):
with self.assertRaisesRegex(
ValueError,
r"traced for jit returned a mutable array reference.*\n\n"
r".*was passed in as the argument x_ref"):
jax.jit(lambda x_ref: x_ref)(core.mutable_array(jnp.arange(3)))
def test_return_from_jit_pytree(self):
with self.assertRaisesRegex(
ValueError,
r"tree path result\['hi'\]"):
jax.jit(lambda x_ref: {'hi': x_ref})(core.mutable_array(jnp.arange(3)))
def test_return_from_jit_closure(self):
with self.assertRaisesRegex(
ValueError,
r"tree path result\['hi'\]"):
x_ref = core.mutable_array(jnp.arange(3))
jax.jit(lambda: {'hi': x_ref})()
def test_argument_aliases_jit(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, "appeared at both x_ref and y_ref"):
jax.jit(lambda x_ref, y_ref: x_ref[...] + y_ref[...])(x_ref, x_ref)
def test_closure_and_argument_aliases_jit(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, "closed over and passed as the argument y_ref"):
jax.jit(lambda y_ref: x_ref[...] + y_ref[...])(x_ref)
def test_return_from_scan(self):
with self.assertRaisesRegex(
ValueError, "traced for scan returned a mutable array reference of type"):
jax.lax.scan(lambda c, x: (core.mutable_array(c), x), 0, jnp.arange(3))
def test_argument_aliases_scan(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, r"appeared at both c\[0\] and c\[1\]"):
jax.lax.scan(lambda c, _: (None, None), (x_ref, x_ref), None, length=1)
def test_closure_and_argument_aliases_scan(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, r"closed over and passed as the argument y_ref"):
jax.lax.scan(lambda y_ref, _: (x_ref[...] + y_ref[...], None), x_ref,
None, length=1)
def test_return_from_cond(self):
with self.assertRaisesRegex(
ValueError, "traced for cond returned a mutable array reference of type"):
jax.lax.cond(True, lambda: core.mutable_array(1.0), lambda: core.mutable_array(2.0))
def test_argument_aliases_cond(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex( ValueError, r"for cond.*at both x1 and x2"):
jax.lax.cond(True, lambda x1, x2: ..., lambda x1, x2: ..., x_ref, x_ref)
def test_closure_and_argument_aliases_cond(self):
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, r"closed over and passed as the argument y_ref"):
jax.lax.cond(True,
lambda y_ref: x_ref[...] + y_ref[...],
lambda y_ref: x_ref[...] + y_ref[...],
x_ref)
@parameterized.parameters([False, True])
def test_return_from_custom_vjp_primal(self, jit):
@jax.custom_vjp
def f(ref):
return ref
f.defvjp(lambda ref: ..., lambda *_: ...)
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, "custom_vjp primal function"):
f(x_ref)
@parameterized.parameters([False, True])
def test_return_from_custom_vjp_fwd(self, jit):
@jax.custom_vjp
def f(x, ref):
return x
f.defvjp(lambda x, ref: (x, ref), lambda ref, g: g)
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(
ValueError, "custom_vjp fwd function"):
jax.vjp(f, 3., x_ref)
@parameterized.parameters([False, True])
def test_argument_aliases_custom_vjp_primal(self, jit):
@jax.custom_vjp
def f(x_ref, y_ref):
...
f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
if jit:
f = jax.jit(f)
x_ref = core.mutable_array(0.)
with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
f(x_ref, x_ref)
# TODO(mattjj): re-enable test after direct-linearize
# @parameterized.parameters([False, True])
# def test_argument_aliases_custom_vjp_fwd(self, jit):
# @jax.custom_vjp
# def f(x_ref, y_ref):
# ...
# f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
# if jit:
# f = jax.jit(f)
# x_ref = core.mutable_array(0.)
# with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
# jax.vjp(f, x_ref, x_ref)
# TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp
@parameterized.parameters([False, True])
def test_cond_both_branches_close_over_same_mutable_array(self, jit):
# see also test_cond_with_ref_reuse in state_test.py
x_ref = core.mutable_array(0.)
def f(pred):
def true_fun():
x_ref[()] = 1.
def false_fun():
x_ref[()] = 2.
jax.lax.cond(pred, true_fun, false_fun)
if jit:
f = jax.jit(f)
out_true = f(True)
self.assertAllClose(x_ref[...], 1.)
out_false = f(False)
self.assertAllClose(x_ref[...], 2.)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())