mirror of
https://github.com/ROCm/jax.git
synced 2025-04-26 05:06:07 +00:00

Previously, we represented a missing arg name with `None`, and a missing result path with the empty string. We now adopt the same convention for arg names and use empty strings. This simplifies the typing, and prevents the string "None" from appearing in error messages. I changed how we encode the result paths. Previously for a function that returns a single array the path was the empty string (the same as for an unknown path). And for a function that returns a pair of arrays it was `([0], [1])`. Now we add the "result" prefix: `("result",)` for a function returning a single array and `(result[0], result[1])` for a function returning a pair of arrays. Finally, in debug_info_test, I removed the `check_tracer_arg_name` so that all spied tracers are printed with the argument name they depend on.
408 lines
12 KiB
Python
408 lines
12 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
import numpy as np
|
|
import jax
|
|
from jax._src import core
|
|
from jax._src import config
|
|
from jax._src import test_util as jtu
|
|
from jax.sharding import NamedSharding, PartitionSpec as P
|
|
import jax.numpy as jnp
|
|
|
|
from jax._src.state.types import (RefEffect)
|
|
|
|
# Runs at import time, ahead of the test classes defined below.
# NOTE(review): presumably wires JAX config options into absl's command-line
# flag parsing so they can be set from the test invocation -- confirm against
# jax._src.config.
config.parse_flags_with_absl()
|
|
|
|
class MutableArrayTest(jtu.JaxTestCase):
  """Tests of `core.mutable_array` under jit, scan, grad and custom_vjp.

  Tests parameterized over `jit` run the same scenario twice: once calling
  the Python function directly and once through `jax.jit`.
  """

  @parameterized.parameters([True, False])
  def test_basic(self, jit):
    # In-place updates made through a mutable-array argument must be visible
    # to the caller once the call returns.
    def bump(ref):
      ref[...] += 1.
      ref[0] += 1
      ref[1] += 5

    fn = jax.jit(bump) if jit else bump

    x_mut = core.mutable_array(jnp.zeros(3))
    fn(x_mut)

    self.assertAllClose(
        x_mut[...], jnp.array([2., 6., 1.]), check_dtypes=False)

    # Tracing the same callable must report the mutation as a RefEffect.
    traced = jax.make_jaxpr(fn)(x_mut)
    is_ref_effect = [isinstance(eff, RefEffect) for eff in traced.effects]
    self.assertTrue(any(is_ref_effect))

  @parameterized.parameters([True, False])
  def test_multiple_inputs_and_outputs(self, jit):
    # Mixes two references with two ordinary arrays, and returns two values
    # computed from the already-updated references.
    def step(x_ref, y, z_ref, w):
      x_ref[...] += 1
      z_ref[...] += 1
      combined = x_ref[...] + y + z_ref[...] + w
      return combined, y + w

    fn = jax.jit(step) if jit else step

    x_mut = core.mutable_array(jnp.zeros((1, 3)))
    y = jnp.ones((2, 3))
    z_mut = core.mutable_array(jnp.zeros((2, 3)))
    w = jnp.ones((2, 1))

    out1, out2 = fn(x_mut, y, z_mut, w)

    # Both references were incremented exactly once.
    self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False)
    self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False)
    # Four all-ones operands broadcast to shape (2, 3).
    self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
    self.assertAllClose(out2, y + w, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_closed_over_basic(self, jit):
    # Same updates as test_basic, but the reference is captured by closure
    # rather than passed as an argument.
    x_mut = core.mutable_array(jnp.zeros(3))

    def bump():
      x_mut[...] += 1.
      x_mut[0] += 1
      x_mut[1] += 5

    fn = jax.jit(bump) if jit else bump
    fn()

    self.assertAllClose(
        x_mut[...], jnp.array([2., 6., 1.]), check_dtypes=False)

    traced = jax.make_jaxpr(fn)()
    is_ref_effect = [isinstance(eff, RefEffect) for eff in traced.effects]
    self.assertTrue(any(is_ref_effect))

  @parameterized.parameters([True, False])
  def test_closed_over_nested(self, jit):
    # The inner function is always jitted; with jit=True it is wrapped in a
    # second jit, so the closed-over reference is reached through two levels.
    x_mut = core.mutable_array(jnp.zeros(3))

    @jax.jit
    def inner(y_ref, z):
      x_mut[...] += 1.
      x_mut[0] += 1
      x_mut[1] += 5

      y_ref[2] += 7
      return z + 9

    fn = jax.jit(inner) if jit else inner

    y_mut = core.mutable_array(np.zeros(3))
    w = fn(y_mut, 1)

    self.assertAllClose(
        x_mut[...], jnp.array([2., 6., 1.]), check_dtypes=False)
    self.assertAllClose(
        y_mut[...], jnp.array([0., 0., 7.]), check_dtypes=False)
    self.assertAllClose(w, 10, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_internal_mutarray_basic(self, jit):
    # The reference is created, updated and read entirely inside the
    # function; only its final value is returned.
    def build():
      scratch = core.mutable_array(jnp.zeros(3))
      scratch[0] += 1
      scratch[0] += 1
      scratch[2] += 1
      return scratch[...]

    fn = jax.jit(build) if jit else build

    self.assertAllClose(fn(), jnp.array([2., 0., 1.]), check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_refs_in_vjps(self, jit):
    # A custom_vjp whose forward rule hands a reference to the backward rule,
    # which then writes into it. Every gradient evaluation should record one
    # more entry in the history.
    @jax.custom_vjp
    def gradient_history_calculator(x, ref):
      return x

    def gradient_history_calculator_fwd(x, ref):
      # The reference itself is what the backward rule receives.
      return x, ref

    def gradient_history_calculator_bwd(amax_history, grad_output):
      newest = jnp.max(jnp.abs(grad_output))
      # Shift the history by one slot and store the newest value at index 0.
      rolled = jnp.roll(amax_history[:], 1).at[0].set(newest)
      amax_history[:] = rolled
      scale = jnp.max(amax_history[:])
      # No cotangent is returned for the `ref` argument.
      return grad_output / scale, None

    gradient_history_calculator.defvjp(
        gradient_history_calculator_fwd,
        gradient_history_calculator_bwd)

    class DotOp:
      def __init__(self):
        self.amax_history = core.mutable_array(jnp.zeros(5,))

      def forward(self, x, y):
        product = jnp.dot(x, y)
        return gradient_history_calculator(product, self.amax_history)

    dot_op = DotOp()
    x_top = jnp.ones((5,))
    y_top = jnp.ones((5,))

    def loss(x, y):
      return dot_op.forward(x, y).sum()

    loss_fn = jax.jit(loss) if jit else loss

    for i in range(3):
      jax.grad(loss_fn, (0, 1))(x_top, y_top)
      # After i + 1 evaluations the first i + 1 history slots hold 1.0.
      expected_history = jnp.zeros((5,)).at[:i + 1].set(1.0)
      self.assertAllClose(
          dot_op.amax_history[:], expected_history, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_scan_internal_mut_array(self, jit):
    # The scan body creates a fresh reference from each scanned element.
    def body_fun(_, x):
      x_mut = core.mutable_array(x)
      x_mut[...] += 2
      return ((), x_mut[...])

    def doit():
      return jax.lax.scan(body_fun, (), np.arange(5))

    run = jax.jit(doit) if jit else doit
    _, xs = run()
    self.assertAllClose(xs, (np.arange(5) + 2), check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_scan_closed_over_mut_array(self, jit):
    # The scan body updates a reference captured by closure, so the updates
    # accumulate across iterations.
    x_mut = core.mutable_array(0)

    def body_fun(_, x):
      x_mut[...] += 2
      return ((), x_mut[...])

    def doit():
      return jax.lax.scan(body_fun, (), np.arange(5))

    run = jax.jit(doit) if jit else doit
    _, xs = run()

    # Five iterations, each adding 2.
    self.assertAllClose(x_mut[...], 10)
    self.assertAllClose(xs, np.arange(5) * 2 + 2, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_scan_scanned_mut_array(self, jit):
    # The reference itself is one of the scanned-over inputs; each iteration
    # adds the matching index to the element it receives.
    def body_fun(_, index_x):
      index, ref = index_x
      ref[...] += index
      return ((), ref[...])

    x_mut = core.mutable_array(np.arange(5))

    def doit():
      return jax.lax.scan(body_fun, (), (np.arange(5), x_mut))

    run = jax.jit(doit) if jit else doit
    _, xs = run()
    self.assertAllClose(xs, (np.arange(5) * 2), check_dtypes=False)

  def test_double_jit_mutable_array(self):
    # A reference created under two stacked jit wrappers.
    def make_zeros():
      x_ref = core.mutable_array(jnp.zeros(8))
      return x_ref[...]

    f = jax.jit(jax.jit(make_zeros))
    self.assertArraysEqual(f(), jnp.zeros(8))

  def test_grad_mutable_array(self):
    # Differentiates through a write to, and a freeze of, an internal
    # reference: the function computes x + x, so the expected gradient is 2.
    @jax.jit
    def f(x):
      x_ = core.mutable_array(x)
      doubled = x_[()] + x_[()]
      x_[()] = doubled
      return core.freeze(x_)

    ans = jax.grad(f)(1.)
    self.assertAllClose(ans, 2.0, check_dtypes=False)

  def test_defensive_copy(self):
    # After a reference built from `x` has been passed through jit, `x` itself
    # must still be usable.
    x = jnp.arange(3.)
    read = jax.jit(lambda x_ref: x_ref[...])
    _ = read(core.mutable_array(x))
    x + 1  # don't crash

  def test_sharding_persists(self):
    # The sharding of the source array must be reported by the reference, by
    # reads from it, and by values read from it under jit.
    mesh = jax.make_mesh((1,), ('i',))
    x = jax.device_put(jnp.arange(2), NamedSharding(mesh, P('i')))
    expected = x.sharding

    a = core.mutable_array(x)
    self.assertEqual(expected, a.sharding)
    self.assertEqual(expected, a[...].sharding)

    read = jax.jit(lambda: a[...])
    y = read()
    self.assertEqual(expected, a.sharding)
    self.assertEqual(expected, y.sharding)
|
|
|
|
|
|
@jtu.with_config(jax_mutable_array_checks=True)
class MutableArrayErrorsTest(jtu.JaxTestCase):
  """Tests of the errors reported for misuse of mutable array references.

  The whole class runs with `jax_mutable_array_checks=True`. Apart from the
  final cond test, every test expects a `ValueError` whose message matches
  the given pattern.

  Several patterns contain the parameter names of the traced function
  (`x_ref`, `y_ref`, `c`, `x1`, `x2`) or the key of the returned pytree
  (`'hi'`), so those names must not be changed.
  """

  def test_return_from_jit(self):
    # Returning a reference that was created inside the jitted function.
    with self.assertRaisesRegex(
        ValueError,
        r"traced for jit returned a mutable array reference.*\n\n"
        r".*was created on line"):
      jax.jit(core.mutable_array)(jnp.arange(3))

  def test_return_from_jit_arg(self):
    # Returning a reference that came in as an argument.
    with self.assertRaisesRegex(
        ValueError,
        r"traced for jit returned a mutable array reference.*\n\n"
        r".*was passed in as the argument x_ref"):
      jax.jit(lambda x_ref: x_ref)(core.mutable_array(jnp.arange(3)))

  def test_return_from_jit_pytree(self):
    # The message is expected to give the location inside the returned pytree.
    with self.assertRaisesRegex(
        ValueError,
        r"tree path result\['hi'\]"):
      jax.jit(lambda x_ref: {'hi': x_ref})(core.mutable_array(jnp.arange(3)))

  def test_return_from_jit_closure(self):
    # Build the reference before entering the context manager, as the other
    # tests do, so that only the jitted call can satisfy the expectation.
    x_ref = core.mutable_array(jnp.arange(3))
    with self.assertRaisesRegex(
        ValueError,
        r"tree path result\['hi'\]"):
      jax.jit(lambda: {'hi': x_ref})()

  def test_argument_aliases_jit(self):
    # The same reference passed as two different arguments.
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, "appeared at both x_ref and y_ref"):
      jax.jit(lambda x_ref, y_ref: x_ref[...] + y_ref[...])(x_ref, x_ref)

  def test_closure_and_argument_aliases_jit(self):
    # The same reference both closed over and passed as an argument.
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, "closed over and passed as the argument y_ref"):
      jax.jit(lambda y_ref: x_ref[...] + y_ref[...])(x_ref)

  def test_return_from_scan(self):
    # The scan body returns a reference as its carry.
    with self.assertRaisesRegex(
        ValueError, "traced for scan returned a mutable array reference of type"):
      jax.lax.scan(lambda c, x: (core.mutable_array(c), x), 0, jnp.arange(3))

  def test_argument_aliases_scan(self):
    # The same reference appears twice inside the carry tuple.
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, r"appeared at both c\[0\] and c\[1\]"):
      jax.lax.scan(lambda c, _: (None, None), (x_ref, x_ref), None, length=1)

  def test_closure_and_argument_aliases_scan(self):
    # The same reference is closed over by the body and passed as the carry.
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, r"closed over and passed as the argument y_ref"):
      jax.lax.scan(lambda y_ref, _: (x_ref[...] + y_ref[...], None), x_ref,
                   None, length=1)

  def test_return_from_cond(self):
    # Both branches return a freshly created reference.
    with self.assertRaisesRegex(
        ValueError, "traced for cond returned a mutable array reference of type"):
      jax.lax.cond(True, lambda: core.mutable_array(1.0), lambda: core.mutable_array(2.0))

  def test_argument_aliases_cond(self):
    # The same reference passed as both operands of cond.
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(ValueError, r"for cond.*at both x1 and x2"):
      jax.lax.cond(True, lambda x1, x2: ..., lambda x1, x2: ..., x_ref, x_ref)

  def test_closure_and_argument_aliases_cond(self):
    # Both branches close over the reference that is also the cond operand.
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, r"closed over and passed as the argument y_ref"):
      jax.lax.cond(True,
                   lambda y_ref: x_ref[...] + y_ref[...],
                   lambda y_ref: x_ref[...] + y_ref[...],
                   x_ref)

  @parameterized.parameters([False, True])
  def test_return_from_custom_vjp_primal(self, jit):
    # The primal function of a custom_vjp returns its reference argument.
    @jax.custom_vjp
    def f(ref):
      return ref
    f.defvjp(lambda ref: ..., lambda *_: ...)
    if jit:
      f = jax.jit(f)
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, "custom_vjp primal function"):
      f(x_ref)

  @parameterized.parameters([False, True])
  def test_return_from_custom_vjp_fwd(self, jit):
    # The fwd rule returns the reference; jax.vjp is what invokes that rule.
    @jax.custom_vjp
    def f(x, ref):
      return x
    f.defvjp(lambda x, ref: (x, ref), lambda ref, g: g)
    if jit:
      f = jax.jit(f)
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, "custom_vjp fwd function"):
      jax.vjp(f, 3., x_ref)

  @parameterized.parameters([False, True])
  def test_argument_aliases_custom_vjp_primal(self, jit):
    # The same reference passed twice to a direct call of the custom_vjp
    # function.
    @jax.custom_vjp
    def f(x_ref, y_ref):
      ...
    f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
    if jit:
      f = jax.jit(f)
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
      f(x_ref, x_ref)

  @parameterized.parameters([False, True])
  def test_argument_aliases_custom_vjp_fwd(self, jit):
    # Same aliasing as the primal variant, but reached through jax.vjp.
    @jax.custom_vjp
    def f(x_ref, y_ref):
      ...
    f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))
    if jit:
      f = jax.jit(f)
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
      jax.vjp(f, x_ref, x_ref)

  # TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp

  @parameterized.parameters([False, True])
  def test_cond_both_branches_close_over_same_mutable_array(self, jit):
    # see also test_cond_with_ref_reuse in state_test.py
    # Positive case: both branches writing to one closed-over reference is
    # allowed, and the branch that was taken determines the stored value.
    x_ref = core.mutable_array(0.)
    def f(pred):
      def true_fun():
        x_ref[()] = 1.
      def false_fun():
        x_ref[()] = 2.
      jax.lax.cond(pred, true_fun, false_fun)
    if jit:
      f = jax.jit(f)
    # f returns nothing; it is called only for its effect on x_ref.
    f(True)
    self.assertAllClose(x_ref[...], 1.)
    f(False)
    self.assertAllClose(x_ref[...], 2.)
|
|
|
|
|
|
if __name__ == '__main__':
  # When run as a script, hand the module to absltest using the test loader
  # provided by jax's test utilities.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|