mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add discharge rules for scan with mutable arrays. Move mutable array tests to separate file.
Co-authored-by: Matt Johnson <mattjj@google.com>
This commit is contained in:
parent
47fdc7b08f
commit
e63b35d550
@ -1999,13 +1999,15 @@ def mutable_array(init_val):
|
||||
return mutable_array_p.bind(init_val)
|
||||
mutable_array_p = Primitive('mutable_array')
|
||||
|
||||
class InternalMutableArray(effects.Effect):
|
||||
class InternalMutableArrayEffect(effects.Effect):
|
||||
pass
|
||||
internal_mutable_array_effect = InternalMutableArrayEffect()
|
||||
effects.control_flow_allowed_effects.add_type(InternalMutableArrayEffect)
|
||||
|
||||
@mutable_array_p.def_effectful_abstract_eval
|
||||
def mutable_array_abstract_eval(init_aval):
|
||||
from jax._src.state.types import AbstractRef # type: ignore[import]
|
||||
return AbstractRef(init_aval), {InternalMutableArray}
|
||||
return AbstractRef(init_aval), {internal_mutable_array_effect}
|
||||
|
||||
@mutable_array_p.def_impl
|
||||
def _mutable_array_impl(init_val):
|
||||
|
@ -2046,7 +2046,7 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
|
||||
assert next(out_layouts_, None) is None
|
||||
else:
|
||||
inout_aliases = mut = None
|
||||
if any(isinstance(e, core.InternalMutableArray) for e in closed_jaxpr.effects):
|
||||
if any(isinstance(e, core.InternalMutableArrayEffect) for e in closed_jaxpr.effects):
|
||||
closed_jaxpr = _discharge_internal_refs(closed_jaxpr)
|
||||
|
||||
return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
|
||||
|
@ -55,7 +55,8 @@ from jax._src.numpy.ufuncs import logaddexp
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
|
||||
unzip2, weakref_lru_cache, merge_lists)
|
||||
split_list_checked, unzip2, weakref_lru_cache,
|
||||
merge_lists)
|
||||
import numpy as np
|
||||
|
||||
from jax._src.lax.control_flow.common import (
|
||||
@ -1201,57 +1202,74 @@ def _scan_pp_rule(eqn, context, settings):
|
||||
def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts,
|
||||
num_carry, linear, unroll, reverse, length,
|
||||
_split_transpose):
|
||||
jaxpr, consts = jaxpr.jaxpr, jaxpr.consts
|
||||
# We're shuffling parameters between three signatures for the scan body:
|
||||
# jaxpr : (n_consts, n_carry, n_xs) -> (n_carry, n_ys)
|
||||
# discharged : (n_consts, n_carry, n_xs) -> (n_carry, n_ys, n_ref_consts, n_ref_xs)
|
||||
# wrapped : (n_val_consts, (n_ref_consts, n_carry), (n_val_xs, n_ref_xs))
|
||||
# -> ((n_ref_consts, n_carry), (n_ys, n_ref_xs))
|
||||
# where we partition consts and xs between ref and non-ref versions:
|
||||
# n_consts = (n_val_consts, n_ref_consts)
|
||||
# n_xs = (n_val_xs, n_ref_xs)
|
||||
|
||||
# avals from jaxpr (i.e. rank-reduced) rather than from caller
|
||||
jaxpr, in_avals, out_avals, consts = jaxpr.jaxpr, jaxpr.in_avals, jaxpr.out_avals, jaxpr.consts
|
||||
if consts: raise NotImplementedError
|
||||
consts, carry, xs = split_list(args, [num_consts, num_carry])
|
||||
consts_linear, carry_linear, xs_linear = split_list(
|
||||
linear, [num_consts, num_carry])
|
||||
consts_avals, carry_avals, xs_avals = split_list(in_avals,
|
||||
[num_consts, num_carry])
|
||||
is_ref = [isinstance(a, state.AbstractRef) for a in consts_avals]
|
||||
remaining_const_avals, in_ref_avals = partition_list(is_ref, consts_avals)
|
||||
remaining_consts, in_refs = partition_list(is_ref, consts)
|
||||
remaining_consts_linear, in_refs_linear = partition_list(is_ref, consts_linear)
|
||||
num_refs = sum(is_ref)
|
||||
num_extensive_in = len(in_avals) - num_carry - num_consts
|
||||
num_extensive_out = len(out_avals) - num_carry
|
||||
num_remaining_consts = num_consts - num_refs
|
||||
n_consts = num_consts
|
||||
n_carry = num_carry
|
||||
n_xs = len(in_avals) - n_consts - n_carry
|
||||
n_ys = len(out_avals) - n_carry
|
||||
consts_avals, carry_avals, xs_avals = split_list_checked(in_avals,
|
||||
[n_consts, n_carry, n_xs])
|
||||
is_ref_const = [isinstance(a, state.AbstractRef) for a in consts_avals]
|
||||
assert not any(isinstance(a, state.AbstractRef) for a in carry_avals)
|
||||
is_ref_xs = [isinstance(a, state.AbstractRef) for a in xs_avals]
|
||||
n_ref_consts = sum(is_ref_const)
|
||||
n_val_consts = n_consts - n_ref_consts
|
||||
n_ref_xs = sum(is_ref_xs)
|
||||
n_val_xs = n_xs - n_ref_xs
|
||||
discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
|
||||
if discharged_consts:
|
||||
raise NotImplementedError("Discharged jaxpr has consts. If you see this, "
|
||||
"please open an issue at "
|
||||
"https://github.com/google/jax/issues")
|
||||
# The discharged jaxpr will have output refs stashed at the end
|
||||
def wrapped(*refs_and_args):
|
||||
consts, refs, carry, xs = split_list(refs_and_args, [num_remaining_consts,
|
||||
num_refs,
|
||||
num_carry])
|
||||
consts_with_refs = merge_lists(is_ref, consts, refs)
|
||||
outs_and_refs = core.eval_jaxpr(discharged_jaxpr, (), *consts_with_refs,
|
||||
*carry, *xs)
|
||||
carry, ys, out_refs = split_list(outs_and_refs, [num_carry,
|
||||
num_extensive_out])
|
||||
assert len(out_refs) == num_refs
|
||||
return [*out_refs, *carry, *ys]
|
||||
new_in_avals = [*remaining_const_avals, *[a.inner_aval for a in in_ref_avals],
|
||||
*carry_avals,
|
||||
*[core.mapped_aval(length, 0, a) for a in xs_avals]]
|
||||
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), new_in_avals)
|
||||
new_linear = (*remaining_consts_linear, *in_refs_linear,
|
||||
*carry_linear, *xs_linear)
|
||||
all_out = scan_p.bind(*remaining_consts, *in_refs, *carry, *xs,
|
||||
def wrapped(*wrapped_args):
|
||||
val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args,
|
||||
[n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs])
|
||||
consts = merge_lists(is_ref_const, val_consts, ref_consts_in)
|
||||
xs = merge_lists(is_ref_xs, val_xs, ref_xs_in)
|
||||
outs = core.eval_jaxpr(discharged_jaxpr, (), *consts, *carry_in, *xs)
|
||||
carry_out, ys, ref_consts_out, ref_xs_out = split_list_checked(outs,
|
||||
[n_carry, n_ys, n_ref_consts, n_ref_xs])
|
||||
return [*ref_consts_out, *carry_out, *ys, *ref_xs_out]
|
||||
|
||||
def arrange_jaxpr_args_for_wrapped(args):
|
||||
consts, carry_in, xs = split_list_checked(args, [n_consts, n_carry, n_xs])
|
||||
val_consts, ref_consts_in = partition_list(is_ref_const, consts)
|
||||
val_xs, ref_xs_in = partition_list(is_ref_xs, xs)
|
||||
return *val_consts, *ref_consts_in, *carry_in, *val_xs, *ref_xs_in
|
||||
|
||||
args_for_wrapped = arrange_jaxpr_args_for_wrapped(args)
|
||||
linear_for_wrapped = arrange_jaxpr_args_for_wrapped(linear)
|
||||
avals_for_wrapped = arrange_jaxpr_args_for_wrapped(in_avals)
|
||||
avals_for_wrapped_no_refs = [aval.inner_aval if isinstance(aval, state.AbstractRef) else aval
|
||||
for aval in avals_for_wrapped]
|
||||
new_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(wrapped), avals_for_wrapped_no_refs)
|
||||
all_out = scan_p.bind(*args_for_wrapped,
|
||||
jaxpr=core.ClosedJaxpr(new_jaxpr, ()),
|
||||
length=length,
|
||||
num_consts=num_remaining_consts,
|
||||
num_carry=num_refs + num_carry,
|
||||
num_consts=n_val_consts,
|
||||
num_carry=n_ref_consts + n_carry,
|
||||
unroll=unroll,
|
||||
reverse=reverse,
|
||||
linear=new_linear, _split_transpose=_split_transpose)
|
||||
refs_out, carry_out, ys_out = split_list(all_out, [num_refs, num_carry])
|
||||
new_invals = [*merge_lists(is_ref, [None] * num_remaining_consts, refs_out),
|
||||
*[None] * num_carry, *[None] * num_extensive_in]
|
||||
assert len(new_invals) == len(in_avals)
|
||||
return new_invals, [*carry_out, *ys_out]
|
||||
linear=linear_for_wrapped, _split_transpose=_split_transpose)
|
||||
ref_consts_out, carry_out, ys, ref_xs_out = split_list_checked(all_out,
|
||||
[n_ref_consts, n_carry, n_ys, n_ref_xs])
|
||||
refs_out_matching_in_avals = [
|
||||
*merge_lists(is_ref_const, [None] * n_val_consts, ref_consts_out),
|
||||
*[None] * n_carry,
|
||||
*merge_lists(is_ref_xs, [None] * n_val_xs, ref_xs_out)]
|
||||
assert len(refs_out_matching_in_avals) == len(in_avals)
|
||||
return refs_out_matching_in_avals, [*carry_out, *ys]
|
||||
|
||||
def scan_bind(*args, **params):
|
||||
if config.enable_checks.value:
|
||||
|
@ -112,10 +112,10 @@ def _eval_jaxpr_discharge_state(
|
||||
for eqn in jaxpr.eqns:
|
||||
if eqn.primitive is core.mutable_array_p:
|
||||
[invar], [outvar] = eqn.invars, eqn.outvars
|
||||
init_val = env.read(invar)
|
||||
env.write(outvar, init_val)
|
||||
ans = env.read(invar)
|
||||
refs_to_discharge.add(id(outvar.aval))
|
||||
elif any(id(v.aval) in refs_to_discharge for v in eqn.invars):
|
||||
elif (any(id(v.aval) in refs_to_discharge for v in eqn.invars)
|
||||
or core.internal_mutable_array_effect in eqn.effects ):
|
||||
if eqn.primitive not in _discharge_rules:
|
||||
raise NotImplementedError("No state discharge rule implemented for "
|
||||
f"primitive: {eqn.primitive}")
|
||||
|
@ -131,6 +131,15 @@ def split_list(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]:
|
||||
lists.append(args)
|
||||
return lists
|
||||
|
||||
def split_list_checked(args: Sequence[T], ns: Sequence[int]) -> list[list[T]]:
  """Split ``args`` into consecutive chunks whose lengths are given by ``ns``.

  The chunk sizes must account for every element of ``args``; no remainder
  chunk is produced.

  Args:
    args: the sequence to split. It is copied into a list first.
    ns: the length of each chunk, in order. Every entry must be non-negative
      and the entries must sum to ``len(args)``.

  Returns:
    A list of ``len(ns)`` lists; the i-th list holds the next ``ns[i]``
    elements of ``args``.

  Raises:
    AssertionError: if any size is negative or the sizes do not sum to
      ``len(args)``.
  """
  args = list(args)
  # Materialize the sizes so a one-shot iterable is not exhausted by the
  # checks below before the splitting loop runs.
  ns = list(ns)
  # A negative size could still satisfy the sum check while slicing from the
  # wrong end, silently producing wrong chunks.
  assert all(n >= 0 for n in ns), f"chunk sizes must be non-negative, got {ns}"
  assert sum(ns) == len(args), (
      f"chunk sizes {ns} sum to {sum(ns)} but there are {len(args)} args")
  lists = []
  start = 0
  for n in ns:
    # Slice by running offset instead of repeatedly copying the tail.
    lists.append(args[start:start + n])
    start += n
  return lists
|
||||
|
||||
def partition_list(bs: Sequence[bool], l: Sequence[T]) -> tuple[list[T], list[T]]:
|
||||
assert len(bs) == len(l)
|
||||
lists = [], [] # type: ignore
|
||||
|
@ -23,6 +23,7 @@ from jax._src.util import (
|
||||
safe_zip as safe_zip,
|
||||
split_dict as split_dict,
|
||||
split_list as split_list,
|
||||
split_list_checked as split_list_checked,
|
||||
split_merge as split_merge,
|
||||
subvals as subvals,
|
||||
toposort as toposort,
|
||||
|
@ -1289,6 +1289,11 @@ jax_test(
|
||||
deps = py_deps("hypothesis"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "mutable_array_test",
|
||||
srcs = ["mutable_array_test.py"],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "for_loop_test",
|
||||
srcs = ["for_loop_test.py"],
|
||||
|
227
tests/mutable_array_test.py
Normal file
227
tests/mutable_array_test.py
Normal file
@ -0,0 +1,227 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax._src.state.types import (RefEffect)
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
class MutableArrayTest(jtu.JaxTestCase):
  """Tests for `core.mutable_array` refs, run both eagerly and under `jax.jit`."""

  @parameterized.parameters([True, False])
  def test_basic(self, jit):
    # A ref passed as an argument is updated in place by the callee.
    def f(x_mut):
      x_mut[...] += 1.
      x_mut[0] += 1
      x_mut[1] += 5

    if jit:
      f = jax.jit(f)

    x_mut = core.mutable_array(jnp.zeros(3))
    f(x_mut)

    self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
                        check_dtypes=False)

    # Tracing a function that mutates its ref argument must record a RefEffect.
    jaxpr = jax.make_jaxpr(f)(x_mut)
    self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))

  # disabling this test for now. TODO(dougalm): re-enable once we add checks to
  # ensure mutable arrays aren't returned or duplicated etc.
  # def test_staging_error(self):
  #   x = jnp.zeros(3)
  #   with self.assertRaises(Exception):
  #     jax.jit(core.mutable_array)(x)

  @parameterized.parameters([True, False])
  def test_multiple_inputs_and_outputs(self, jit):
    # Refs and ordinary arrays are interleaved in the argument list; both refs
    # must be updated and both regular outputs returned.
    def f(x_mut, y, z_mut, w):
      x_mut[...] += 1
      z_mut[...] += 1
      return x_mut[...] + y + z_mut[...] + w, y + w

    if jit:
      f = jax.jit(f)

    x_mut = core.mutable_array(jnp.zeros((1, 3)))
    y = jnp.ones((2, 3))
    z_mut = core.mutable_array(jnp.zeros((2, 3)))
    w = jnp.ones((2, 1))

    out1, out2 = f(x_mut, y, z_mut, w)

    self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False)
    self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False)
    self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
    self.assertAllClose(out2, y + w, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_closed_over_basic(self, jit):
    # The ref is captured by closure rather than passed as an argument.
    x_mut = core.mutable_array(jnp.zeros(3))

    def f():
      x_mut[...] += 1.
      x_mut[0] += 1
      x_mut[1] += 5

    if jit:
      f = jax.jit(f)

    f()

    self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
                        check_dtypes=False)

    jaxpr = jax.make_jaxpr(f)()
    self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))

  @parameterized.parameters([True, False])
  def test_closed_over_nested(self, jit):
    # One ref is closed over and one is passed in. `f` is always jitted once;
    # with jit=True it is wrapped in a second, outer jit.
    x_mut = core.mutable_array(jnp.zeros(3))

    @jax.jit
    def f(y_mut, z):
      x_mut[...] += 1.
      x_mut[0] += 1
      x_mut[1] += 5

      y_mut[2] += 7
      return z + 9

    if jit:
      f = jax.jit(f)

    y_mut = core.mutable_array(np.zeros(3))

    w = f(y_mut, 1)

    self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
                        check_dtypes=False)
    self.assertAllClose(y_mut[...], jnp.array([0., 0., 7.]),
                        check_dtypes=False)
    self.assertAllClose(w, 10, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_internal_mutarray_basic(self, jit):
    # The ref is created, mutated and read entirely inside the function; only
    # its final value is returned.
    def f():
      x_mut = core.mutable_array(jnp.zeros(3))
      x_mut[0] += 1
      x_mut[0] += 1
      x_mut[2] += 1
      return x_mut[...]

    if jit:
      f = jax.jit(f)

    out = f()
    self.assertAllClose(out, jnp.array([2., 0., 1.]), check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_refs_in_vjps(self, jit):
    # A custom_vjp whose backward rule writes into a ref threaded through as a
    # residual: each backward pass pushes a new entry into the history.
    def gradient_history_calculator_fwd(x, ref):
      return x, ref

    def gradient_history_calculator_bwd(amax_history, grad_output):
      amax_update = jnp.max(jnp.abs(grad_output))
      shifted = jnp.roll(amax_history[:], 1)
      shifted = shifted.at[0].set(amax_update)
      amax_history[:] = shifted
      amax_from_history = jnp.max(amax_history[:])
      grad_output = grad_output / amax_from_history
      return grad_output, None

    @jax.custom_vjp
    def gradient_history_calculator(x, ref):
      return x

    gradient_history_calculator.defvjp(
        gradient_history_calculator_fwd,
        gradient_history_calculator_bwd)

    class DotOp:
      def __init__(self):
        self.amax_history = core.mutable_array(jnp.zeros(5,))

      def forward(self, x, y):
        out = jnp.dot(x, y)
        out = gradient_history_calculator(out, self.amax_history)
        return out

    dot_op = DotOp()
    x_top = jnp.ones((5,))
    y_top = jnp.ones((5,))

    def loss(x, y):
      return dot_op.forward(x, y).sum()

    if jit:
      loss = jax.jit(loss)

    # After the i-th gradient evaluation the first i+1 history slots hold 1.0.
    for i in range(3):
      jax.grad(loss, (0,1))(x_top, y_top)
      self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_scan_internal_mut_array(self, jit):
    # The ref is created inside the scan body from the scanned-over value.
    def body_fun(_, x):
      x_mut = core.mutable_array(x)
      x_mut[...] += 2
      return ((), x_mut[...])
    doit = lambda: jax.lax.scan(body_fun, (), np.arange(5))
    if jit:
      doit = jax.jit(doit)
    _, xs = doit()
    self.assertAllClose(xs, (np.arange(5) + 2), check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_scan_closed_over_mut_array(self, jit):
    # The scan body mutates a ref captured by closure; the ref accumulates
    # across the five iterations.
    x_mut = core.mutable_array(0)
    def body_fun(_, x):
      x_mut[...] += 2
      return ((), x_mut[...])

    doit = lambda: jax.lax.scan(body_fun, (), np.arange(5))
    if jit:
      doit = jax.jit(doit)
    _, xs = doit()
    self.assertAllClose(x_mut[...], 10)
    self.assertAllClose(xs, np.arange(5) * 2 + 2, check_dtypes=False)

  @parameterized.parameters([True, False])
  def test_scan_scanned_mut_array(self, jit):
    # The ref itself is passed to scan as part of the scanned-over `xs`, next
    # to an ordinary index array.
    def body_fun(_, index_x):
      (index, x) = index_x
      x[...] += index
      return ((), x[...])

    x_mut = core.mutable_array(np.arange(5))
    doit = lambda: jax.lax.scan(body_fun, (), (np.arange(5), x_mut))
    if jit:
      doit = jax.jit(doit)
    _, xs = doit()
    self.assertAllClose(xs, (np.arange(5) * 2), check_dtypes=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -48,7 +48,7 @@ from jax._src.state.primitives import (get_p, swap_p, addupdate_p,
|
||||
ref_addupdate, ref_get, ref_set,
|
||||
ref_swap)
|
||||
from jax._src.state.types import (shaped_array_ref, ReadEffect, WriteEffect,
|
||||
AccumEffect, RefEffect, AbstractRef)
|
||||
AccumEffect, AbstractRef)
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -1495,159 +1495,6 @@ class RunStateTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, (0.5,), order=3)
|
||||
|
||||
|
||||
class MutableArrayTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_basic(self, jit):
|
||||
def f(x_mut):
|
||||
x_mut[...] += 1.
|
||||
x_mut[0] += 1
|
||||
x_mut[1] += 5
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
x_mut = core.mutable_array(jnp.zeros(3))
|
||||
f(x_mut)
|
||||
|
||||
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
|
||||
check_dtypes=False)
|
||||
|
||||
jaxpr = jax.make_jaxpr(f)(x_mut)
|
||||
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))
|
||||
|
||||
def test_staging_error(self):
|
||||
x = jnp.zeros(3)
|
||||
with self.assertRaises(Exception):
|
||||
jax.jit(core.mutable_array)(x)
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_multiple_inputs_and_outputs(self, jit):
|
||||
def f(x_mut, y, z_mut, w):
|
||||
x_mut[...] += 1
|
||||
z_mut[...] += 1
|
||||
return x_mut[...] + y + z_mut[...] + w, y + w
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
x_mut = core.mutable_array(jnp.zeros((1, 3)))
|
||||
y = jnp.ones((2, 3))
|
||||
z_mut = core.mutable_array(jnp.zeros((2, 3)))
|
||||
w = jnp.ones((2, 1))
|
||||
|
||||
out1, out2 = f(x_mut, y, z_mut, w)
|
||||
|
||||
self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False)
|
||||
self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False)
|
||||
self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
|
||||
self.assertAllClose(out2, y + w, check_dtypes=False)
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_closed_over_basic(self, jit):
|
||||
x_mut = core.mutable_array(jnp.zeros(3))
|
||||
def f():
|
||||
x_mut[...] += 1.
|
||||
x_mut[0] += 1
|
||||
x_mut[1] += 5
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
f()
|
||||
|
||||
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
|
||||
check_dtypes=False)
|
||||
|
||||
jaxpr = jax.make_jaxpr(f)()
|
||||
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_closed_over_nested(self, jit):
|
||||
x_mut = core.mutable_array(jnp.zeros(3))
|
||||
|
||||
@jax.jit
|
||||
def f(y_mut, z):
|
||||
x_mut[...] += 1.
|
||||
x_mut[0] += 1
|
||||
x_mut[1] += 5
|
||||
|
||||
y_mut[2] += 7
|
||||
return z + 9
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
y_mut = core.mutable_array(np.zeros(3))
|
||||
|
||||
w = f(y_mut, 1)
|
||||
|
||||
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(y_mut[...], jnp.array([0., 0., 7.]),
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(w, 10, check_dtypes=False)
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_internal_mutarray_basic(self, jit):
|
||||
def f():
|
||||
x_mut = core.mutable_array(jnp.zeros(3))
|
||||
x_mut[0] += 1
|
||||
x_mut[0] += 1
|
||||
x_mut[2] += 1
|
||||
return x_mut[...]
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
out = f()
|
||||
self.assertAllClose(out, jnp.array([2., 0., 1.]), check_dtypes=False)
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_refs_in_vjps(self, jit):
|
||||
def gradient_history_calculator_fwd(x, ref):
|
||||
return x, ref
|
||||
|
||||
def gradient_history_calculator_bwd(amax_history, grad_output):
|
||||
amax_update = jnp.max(jnp.abs(grad_output))
|
||||
shifted = jnp.roll(amax_history[:], 1)
|
||||
shifted = shifted.at[0].set(amax_update)
|
||||
amax_history[:] = shifted
|
||||
amax_from_history = jnp.max(amax_history[:])
|
||||
grad_output = grad_output / amax_from_history
|
||||
return grad_output, None
|
||||
|
||||
@jax.custom_vjp
|
||||
def gradient_history_calculator(x, ref):
|
||||
return x
|
||||
|
||||
gradient_history_calculator.defvjp(
|
||||
gradient_history_calculator_fwd,
|
||||
gradient_history_calculator_bwd)
|
||||
|
||||
class DotOp:
|
||||
def __init__(self):
|
||||
self.amax_history = core.mutable_array(jnp.zeros(5,))
|
||||
|
||||
def forward(self, x, y):
|
||||
out = jnp.dot(x, y)
|
||||
out = gradient_history_calculator(out, self.amax_history)
|
||||
return out
|
||||
|
||||
dot_op = DotOp()
|
||||
x_top = jnp.ones((5,))
|
||||
y_top = jnp.ones((5,))
|
||||
|
||||
def loss(x, y):
|
||||
return dot_op.forward(x, y).sum()
|
||||
|
||||
if jit:
|
||||
loss = jax.jit(loss)
|
||||
|
||||
for i in range(3):
|
||||
jax.grad(loss, (0,1))(x_top, y_top)
|
||||
self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False)
|
||||
|
||||
if CAN_USE_HYPOTHESIS:
|
||||
|
||||
class FuncSpec(NamedTuple):
|
||||
|
Loading…
x
Reference in New Issue
Block a user