mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[mutable-arrays] move MutableArray, add eager, improve tests, fix bug
1. move MutableArray to core.py, and some handlers to their respective files 2. fix a bug in aliasing setup (it was just broken before, now better test coverage) 3. add eager support by enabling get_p, swap_p, and addupdate_p impls 4. improve tests slightly
This commit is contained in:
parent
2761f266d5
commit
3a403f2a0e
@ -1912,6 +1912,30 @@ class bint(dtypes.ExtendedDType):
|
||||
AxisSize = Union[int, DArray, Tracer, Var, DBIdx, InDBIdx, OutDBIdx]
|
||||
|
||||
|
||||
class MutableArray:
|
||||
_aval: ShapedArray
|
||||
_buf: Array
|
||||
def __init__(self, aval, buf):
|
||||
self._aval = aval
|
||||
self._buf = buf
|
||||
aval = property(lambda self: self._aval)
|
||||
shape = property(lambda self: self._aval.shape)
|
||||
dtype = property(lambda self: self._aval.dtype)
|
||||
def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
|
||||
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
|
||||
pytype_aval_mappings[MutableArray] = lambda x: x._aval
|
||||
|
||||
def mutable_array(init_val):
|
||||
return mutable_array_p.bind(init_val)
|
||||
mutable_array_p = Primitive('mutable_array')
|
||||
|
||||
@mutable_array_p.def_impl
|
||||
def _mutable_array_impl(init_val):
|
||||
from jax._src.state.types import AbstractRef # type: ignore[import]
|
||||
aval = raise_to_shaped(get_aval(init_val))
|
||||
return MutableArray(AbstractRef(aval), init_val)
|
||||
|
||||
|
||||
class AbstractToken(AbstractValue):
|
||||
def join(self, other):
|
||||
if isinstance(other, AbstractToken):
|
||||
|
@ -2730,6 +2730,13 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
|
||||
return prim.bind(*subfuns, *args, **bind_params)
|
||||
|
||||
|
||||
def _error_staging_mutable_array_p(trace, x):
|
||||
raise Exception(
|
||||
"mutable_array constructor can't be staged out, and in particular can't "
|
||||
"be used under a jax.jit or jax.lax.scan")
|
||||
custom_staging_rules[core.mutable_array_p] = _error_staging_mutable_array_p
|
||||
|
||||
|
||||
# TODO(mattjj): the following are deprecated; update callers to _nounits version
|
||||
# See https://github.com/google/jax/pull/9498
|
||||
@lu.transformation
|
||||
|
@ -160,6 +160,10 @@ def _shard_darray(x, sharding):
|
||||
return shard_arg(x._data, sharding)
|
||||
shard_arg_handlers[core.DArray] = _shard_darray
|
||||
|
||||
def _shard_mutable_array(x, sharding):
|
||||
return shard_arg(x._buf, sharding)
|
||||
shard_arg_handlers[core.MutableArray] = _shard_mutable_array
|
||||
|
||||
def batched_device_put(aval: core.ShapedArray,
|
||||
sharding: jax.sharding.Sharding, xs: Sequence[Any],
|
||||
devices: Sequence[jax.Device], committed: bool = True):
|
||||
@ -1778,17 +1782,16 @@ def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
|
||||
@weakref_lru_cache
|
||||
def _discharge_refs(
|
||||
jaxpr: core.ClosedJaxpr
|
||||
) -> tuple[core.ClosedJaxpr, None | Sequence[int | None], None | Sequence[int | None]]:
|
||||
) -> tuple[core.ClosedJaxpr, Sequence[int | None], Sequence[int | None]]:
|
||||
from jax._src.state.discharge import discharge_state
|
||||
out_mut = [None] * len(jaxpr.out_avals) + [
|
||||
i for i, a in enumerate(jaxpr.in_avals) if isinstance(a, AbstractRef)]
|
||||
count = it.count()
|
||||
inout_aliases = tuple(next(count) if isinstance(a, AbstractRef) else None
|
||||
for a in jaxpr.in_avals)
|
||||
jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
|
||||
assert len(inout_aliases) == len(jaxpr.in_avals)
|
||||
assert len(out_mut) == len(jaxpr.out_avals)
|
||||
return jaxpr, inout_aliases, out_mut
|
||||
new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
|
||||
count = it.count(len(jaxpr.out_avals)) # new outputs are appended to the end
|
||||
inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals)
|
||||
if isinstance(a, AbstractRef)}
|
||||
outin_map = {j: i for i, j in inout_map.items()}
|
||||
inout_aliases = tuple(map(inout_map.get, range(len(new_jaxpr.in_avals))))
|
||||
out_mut = tuple(map(outin_map.get, range(len(new_jaxpr.out_avals))))
|
||||
return new_jaxpr, inout_aliases, out_mut
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
|
@ -22,7 +22,6 @@ import dataclasses
|
||||
import functools
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
import operator
|
||||
from typing import Any, Callable, Protocol, Union
|
||||
|
||||
import numpy as np
|
||||
@ -166,6 +165,7 @@ canonicalize_dtype_handlers.update(
|
||||
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
|
||||
canonicalize_dtype_handlers[core.Token] = identity
|
||||
canonicalize_dtype_handlers[core.DArray] = identity
|
||||
canonicalize_dtype_handlers[core.MutableArray] = identity
|
||||
|
||||
def abstractify(x) -> Any:
|
||||
typ = type(x)
|
||||
@ -196,7 +196,8 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
|
||||
|
||||
|
||||
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
|
||||
pytype_aval_mappings[core.DArray] = operator.attrgetter('_aval')
|
||||
pytype_aval_mappings[core.DArray] = lambda x: x._aval
|
||||
pytype_aval_mappings[core.MutableArray] = lambda x: x._aval
|
||||
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
|
||||
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
|
||||
for t in numpy_scalar_types)
|
||||
|
@ -22,6 +22,7 @@ import numpy as np
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import tree_util
|
||||
from jax._src.interpreters import ad
|
||||
@ -53,11 +54,7 @@ zip, unsafe_zip = safe_zip, zip
|
||||
# `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
|
||||
# a:f32[3] <- x[]
|
||||
get_p = core.Primitive("get")
|
||||
|
||||
def _get_impl(ref: AbstractRef, *args: Any, tree):
|
||||
del ref, args, tree
|
||||
raise ValueError("Cannot run stateful primitive.")
|
||||
get_p.def_impl(_get_impl)
|
||||
get_p.def_impl(partial(dispatch.apply_primitive, get_p))
|
||||
|
||||
Indexer = tuple[Union[int, slice, Array], ...]
|
||||
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
|
||||
@ -113,11 +110,7 @@ def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
|
||||
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
|
||||
# x:Ref{f32[3]}[i, j] <- a
|
||||
swap_p = core.Primitive("swap")
|
||||
|
||||
def _swap_impl(ref: AbstractRef, value: Array, *idx: Any, tree):
|
||||
del ref, value, idx, tree
|
||||
raise ValueError("Cannot run stateful primitive.")
|
||||
swap_p.def_impl(_swap_impl)
|
||||
swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))
|
||||
|
||||
def ref_swap(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array,
|
||||
_function_name: str = "ref_swap") -> Array:
|
||||
@ -143,11 +136,7 @@ def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Arra
|
||||
# ```
|
||||
addupdate_p = core.Primitive('addupdate')
|
||||
addupdate_p.multiple_results = True
|
||||
|
||||
def _addupdate_impl(ref: AbstractRef, value: Array, *args: Any, tree):
|
||||
del ref, value, args, tree
|
||||
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
|
||||
addupdate_p.def_impl(_addupdate_impl)
|
||||
addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p))
|
||||
|
||||
def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
|
||||
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
|
||||
|
@ -55,18 +55,6 @@ config.parse_flags_with_absl()
|
||||
|
||||
class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
|
||||
def test_cant_eval_get_primitive(self):
|
||||
with self.assertRaises(ValueError):
|
||||
get_p.bind(jnp.ones(5), tree=None)
|
||||
|
||||
def test_cant_eval_swap_primitive(self):
|
||||
with self.assertRaises(ValueError):
|
||||
swap_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)
|
||||
|
||||
def test_cant_eval_addupdate_primitive(self):
|
||||
with self.assertRaises(ValueError):
|
||||
addupdate_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)
|
||||
|
||||
def test_get_abstract_aval_must_take_in_refs(self):
|
||||
ref_aval = core.ShapedArray((), jnp.float32)
|
||||
def f(x_ref):
|
||||
@ -1508,55 +1496,54 @@ class RunStateTest(jtu.JaxTestCase):
|
||||
jtu.check_grads(f, (0.5,), order=3)
|
||||
|
||||
|
||||
class MutableArray:
|
||||
_aval: core.ShapedArray
|
||||
_buf: jax.Array
|
||||
def __init__(self, aval, buf):
|
||||
self._aval = aval
|
||||
self._buf = buf
|
||||
aval = property(lambda self: self._aval)
|
||||
shape = property(lambda self: self._aval.shape)
|
||||
dtype = property(lambda self: self._aval.dtype)
|
||||
|
||||
def mutable_array(init_val):
|
||||
return mutable_array_p.bind(init_val)
|
||||
mutable_array_p = core.Primitive('mutable_array')
|
||||
|
||||
@mutable_array_p.def_impl
|
||||
def _mutable_array_impl(init_val):
|
||||
aval = core.raise_to_shaped(core.get_aval(init_val))
|
||||
return MutableArray(AbstractRef(aval), init_val)
|
||||
|
||||
def _error_on_staging(trace, x):
|
||||
raise Exception
|
||||
pe.custom_staging_rules[mutable_array_p] = _error_on_staging
|
||||
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
xla.canonicalize_dtype_handlers[MutableArray] = lambda x: x
|
||||
xla.pytype_aval_mappings[MutableArray] = lambda x: x._aval
|
||||
pxla.shard_arg_handlers[MutableArray] = lambda x, s: pxla.shard_arg(x._buf, s)
|
||||
core.pytype_aval_mappings[MutableArray] = lambda x: x._aval
|
||||
|
||||
class MutableArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def test_basic(self):
|
||||
read = jax.jit(lambda x_ref: x_ref[...])
|
||||
|
||||
@jax.jit
|
||||
@parameterized.parameters([True, False])
|
||||
def test_basic(self, jit):
|
||||
def f(x_mut):
|
||||
x_mut[...] += 1.
|
||||
x_mut[0] += 1
|
||||
x_mut[1] += 5
|
||||
|
||||
x_mut = mutable_array(jnp.zeros(3))
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
x_mut = core.mutable_array(jnp.zeros(3))
|
||||
f(x_mut)
|
||||
|
||||
self.assertAllClose(read(x_mut), jnp.array([2., 6., 1.]), check_dtypes=False)
|
||||
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
|
||||
check_dtypes=False)
|
||||
|
||||
jaxpr = jax.make_jaxpr(f)(x_mut)
|
||||
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))
|
||||
|
||||
def test_staging_error(self):
|
||||
x = jnp.zeros(3)
|
||||
with self.assertRaises(Exception):
|
||||
jax.jit(core.mutable_array)(x)
|
||||
|
||||
@parameterized.parameters([True, False])
|
||||
def test_multiple_inputs_and_outputs(self, jit):
|
||||
def f(x_mut, y, z_mut, w):
|
||||
x_mut[...] += 1
|
||||
z_mut[...] += 1
|
||||
return x_mut[...] + y + z_mut[...] + w, y + w
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
x_mut = core.mutable_array(jnp.zeros((1, 3)))
|
||||
y = jnp.ones((2, 3))
|
||||
z_mut = core.mutable_array(jnp.zeros((2, 3)))
|
||||
w = jnp.ones((2, 1))
|
||||
|
||||
out1, out2 = f(x_mut, y, z_mut, w)
|
||||
|
||||
self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False)
|
||||
self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False)
|
||||
self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
|
||||
self.assertAllClose(out2, y + w, check_dtypes=False)
|
||||
|
||||
|
||||
if CAN_USE_HYPOTHESIS:
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user