Refactor Ref abstract type to contain other AbstractValues

This commit is contained in:
Sharad Vikram 2023-02-17 12:45:39 -08:00
parent 8d0bdd2670
commit 4960e656af
7 changed files with 268 additions and 174 deletions

View File

@ -237,7 +237,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
if any(isinstance(op_aval, state.ShapedArrayRef) for op_aval in ops_avals):
if any(isinstance(op_aval, state.AbstractRef) for op_aval in ops_avals):
raise ValueError("Cannot pass `Ref`s into `cond`.")
true_jaxpr, false_jaxpr = jaxprs
out_tree, false_out_tree = out_trees
@ -866,5 +866,5 @@ def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear):
new_invals = []
for aval in in_avals:
new_invals.append(
next(ref_val_iter) if isinstance(aval, state.ShapedArrayRef) else None)
next(ref_val_iter) if isinstance(aval, state.AbstractRef) else None)
return new_invals, out_vals

View File

@ -54,7 +54,7 @@ ReadEffect = state.ReadEffect
WriteEffect = state.WriteEffect
AccumEffect = state.AccumEffect
StateEffect = state.StateEffect
ShapedArrayRef = state.ShapedArrayRef
AbstractRef = state.AbstractRef
ref_set = state.ref_set
ref_get = state.ref_get
ref_addupdate = state.ref_addupdate
@ -70,10 +70,10 @@ for_p.multiple_results = True
def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
all_const_avals = [var.aval for var in jaxpr.constvars]
is_const_ref = [isinstance(var.aval, ShapedArrayRef) for var in
is_const_ref = [isinstance(var.aval, AbstractRef) for var in
jaxpr.constvars]
const_avals, const_ref_avals = partition_list(is_const_ref, all_const_avals)
const_avals = [ShapedArrayRef(aval.shape, aval.dtype) for aval in const_avals] # pytype: disable=attribute-error
const_avals = map(AbstractRef, const_avals)
merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
i_aval, *arg_avals = [var.aval for var in jaxpr.invars]
in_avals = [i_aval, *merged_const_avals, *arg_avals]
@ -100,11 +100,11 @@ def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
f, state_avals)
return jaxpr, consts, out_tree_thunk()
def val_to_ref_aval(x) -> ShapedArrayRef:
def val_to_ref_aval(x) -> AbstractRef:
aval = core.raise_to_shaped(core.get_aval(x))
if type(aval) is not core.ShapedArray:
raise Exception(f"can't make ref from {x}")
return ShapedArrayRef(aval.shape, aval.dtype)
return AbstractRef(aval)
def for_loop(nsteps: Union[int, Sequence[int]],
body: Callable[[Array, Ref[S]], None], init_state: S,
@ -252,7 +252,7 @@ def _for_abstract_eval(*avals, jaxpr, **__):
aval_effects = [set(eff.replace(input_index=eff.input_index - 1)
for eff in effs) for aval, effs
in zip(avals, jaxpr_aval_effects)
if isinstance(aval, ShapedArrayRef)]
if isinstance(aval, AbstractRef)]
nonlocal_state_effects = core.join_effects(*aval_effects)
return list(avals), nonlocal_state_effects
@ -266,7 +266,7 @@ def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr,
unroll=unroll)
new_invals = []
for aval, out_val in zip(in_avals, out_vals):
new_invals.append(out_val if isinstance(aval, ShapedArrayRef) else None)
new_invals.append(out_val if isinstance(aval, AbstractRef) else None)
return new_invals, out_vals
def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll):
@ -661,10 +661,11 @@ def _convert_outputs_to_writes(
res_ref[i] = res_val
return []
# TODO(mattjj, sharadmv): better handling of tokens, which don't have shape/dtype
res_ref_avals = [ShapedArrayRef(v.aval.shape, v.aval.dtype) # pytype: disable=attribute-error
if loop_invar else
ShapedArrayRef((nsteps, *v.aval.shape), v.aval.dtype) # pytype: disable=attribute-error
for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
res_ref_avals: List[core.AbstractValue] = [
AbstractRef(v.aval) if loop_invar else # pytype: disable=attribute-error
AbstractRef(core.ShapedArray((nsteps, *v.aval.shape), # pytype: disable=attribute-error
v.aval.dtype)) # pytype: disable=attribute-error
for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [*in_avals, *res_ref_avals])
assert not consts
@ -685,9 +686,11 @@ def _convert_inputs_to_reads(
res_val_avals, (i_aval,), orig_ref_avals = \
split_list([v.aval for v in jaxpr.invars], [num_res, 1])
res_ref_avals = [ShapedArrayRef(aval.shape, aval.dtype) if loop_invar else
ShapedArrayRef((nsteps, *aval.shape), aval.dtype) # pytype: disable=attribute-error
for aval, loop_invar in zip(res_val_avals, loop_invar_res)]
res_ref_avals: List[core.AbstractValue] = [
AbstractRef(aval) if loop_invar else # pytype: disable=attribute-error
AbstractRef(core.ShapedArray((nsteps, *aval.shape), # pytype: disable=attribute-error
aval.dtype)) # pytype: disable=attribute-error
for aval, loop_invar in zip(res_val_avals, loop_invar_res)]
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])

View File

@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for state."""
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
AccumEffect, StateEffect, RefEffect,
get_ref_state_effects)
get_ref_state_effects, shaped_array_ref)
from jax._src.state.primitives import (ref_get, ref_set, ref_swap,
ref_addupdate, get_p, swap_p,
addupdate_p)

View File

@ -25,7 +25,7 @@ from jax.interpreters import partial_eval as pe
from jax._src import core
from jax._src import linear_util as lu
from jax._src.state.types import ShapedArrayRef
from jax._src.state.types import AbstractRef
from jax._src.state.primitives import get_p, swap_p, addupdate_p
from jax._src.util import safe_map, safe_zip, split_list
@ -47,8 +47,8 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
"""Converts a jaxpr that takes in `Ref`s into one that doesn't."""
if isinstance(should_discharge, bool):
should_discharge = [should_discharge] * len(jaxpr.invars)
in_avals = [core.ShapedArray(v.aval.shape, v.aval.dtype)
if type(v.aval) is ShapedArrayRef and d
in_avals = [v.aval.inner_aval
if type(v.aval) is AbstractRef and d
else v.aval for v, d in zip(jaxpr.invars, should_discharge)]
eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr,
should_discharge, consts))
@ -83,7 +83,7 @@ def register_discharge_rule(prim: core.Primitive):
return register
def _has_refs(eqn: core.JaxprEqn):
return any(isinstance(v.aval, ShapedArrayRef) for v in eqn.invars)
return any(isinstance(v.aval, AbstractRef) for v in eqn.invars)
def _eval_jaxpr_discharge_state(
jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any],
@ -97,7 +97,7 @@ def _eval_jaxpr_discharge_state(
refs_to_discharge = set(id(v.aval) for v, d
in zip(jaxpr.invars, should_discharge) if d
and isinstance(v.aval, ShapedArrayRef))
and isinstance(v.aval, AbstractRef))
for eqn in jaxpr.eqns:
if _has_refs(eqn) and any(id(v.aval) in refs_to_discharge
@ -237,7 +237,7 @@ def _closed_call_discharge_rule(
call_jaxpr=discharged_closed_jaxpr)
out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs])
ref_vals_iter = iter(ref_vals)
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, ShapedArrayRef)
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef)
else None for aval in in_avals)
assert next(ref_vals_iter, None) is None
return new_invals, out_vals

View File

@ -14,7 +14,7 @@
"""Module for state primitives."""
from functools import partial
from typing import Any, List, Protocol, Tuple, TypeVar, Union
from typing import Any, List, Tuple, Union
import numpy as np
@ -27,24 +27,13 @@ from jax._src import ad_util
from jax._src import core
from jax._src import pretty_printer as pp
from jax._src.typing import Array
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
AccumEffect)
from jax._src.util import safe_map, safe_zip, partition_list, tuple_insert
## General utilities
T = TypeVar('T')
class Ref(Protocol):
@property
def shape(self) -> Tuple[int, ...]:
...
@property
def ndim(self) -> int:
...
## JAX utilities
map, unsafe_map = safe_map, map
@ -57,11 +46,11 @@ zip, unsafe_zip = safe_zip, zip
# or we can read using indices:
# a = get_p.bind(x, 0, 1)
# Staging out `a = get_p.bind(x)` where the aval of `x` is
# `ShapedArrayRef((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# a:f32[3] <- x[]
get_p = core.Primitive("get")
def _get_impl(ref: Ref, *idx: int, **_):
def _get_impl(ref: AbstractRef, *idx: int, **_):
del ref, idx
raise ValueError("Cannot run stateful primitive.")
get_p.def_impl(_get_impl)
@ -69,9 +58,18 @@ get_p.def_impl(_get_impl)
Indexer = Tuple[Union[int, slice, Array], ...]
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
def _is_trivial_indexer(idx: Indexer) -> bool:
if idx is ...:
return True
if type(idx) is tuple:
if len(idx) == 0:
return True
return len(idx) == 1 and idx[0] is ...
return False
def _unpack_idx(idx: Indexer, ndim: int
) -> Tuple[Tuple[Array, ...], Tuple[bool, ...]]:
if idx is ... or (type(idx) is tuple and len(idx) == 1 and idx[0] is ...):
if _is_trivial_indexer(idx):
idx = tuple(slice(None) for _ in range(ndim))
indexed_dims_ = [type(i) != slice for i in idx]
_, non_slice_idx = partition_list(indexed_dims_, idx)
@ -88,9 +86,23 @@ def _get_slice_output_shape(in_shape: Tuple[int, ...],
shape = (*shape_prefix, *shape_suffix)
return shape
def ref_get(ref: Ref, idx: Indexer) -> Array:
def _get_indexer(ref: AbstractRef, idx: Indexer
) -> Tuple[Indexer, Tuple[bool, ...]]:
if isinstance(ref.inner_aval, core.ShapedArray):
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
else:
if not _is_trivial_indexer(idx):
raise ValueError(
f"Cannot use nontrivial slice on non-shaped `Ref`: {idx}.")
non_slice_idx, indexed_dims = (), ()
return non_slice_idx, indexed_dims
def ref_get(ref: Any, idx: Indexer) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `get` on a `Ref`: {ref}")
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
return get_p.bind(ref, *non_slice_idx, indexed_dims=indexed_dims)
# `swap` mutates a `Ref`, setting its value and returns its previous value.
@ -101,27 +113,30 @@ def ref_get(ref: Ref, idx: Indexer) -> Array:
# `swap_p` also takes in index arguments following the value, i.e.:
# _ = swap_p.bind(x, a, 0, 1)
# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is
# `ShapedArrayRef((3,), np.dtype('float32'))` and the aval of `a` is
# `Ref((3,), np.dtype('float32'))` and the aval of `a` is
# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
# b:f32[3], x:Ref{f32[3]} <- x, a
# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is
# `ShapedArrayRef((3,), np.dtype('float32'))` , the aval of `a` is
# `Ref((3,), np.dtype('float32'))` , the aval of `a` is
# `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j`
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
# x:Ref{f32[3]}[i, j] <- a
swap_p = core.Primitive("swap")
def _swap_impl(ref: Ref, value: Array, *idx: int, **_):
def _swap_impl(ref: AbstractRef, value: Array, *idx: int, **_):
del ref, value, idx
raise ValueError("Cannot run stateful primitive.")
swap_p.def_impl(_swap_impl)
def ref_swap(ref: Ref, idx: Indexer, value: Array) -> Array:
def ref_swap(ref: AbstractRef, idx: Indexer, value: Array) -> Array:
"""Sets a `Ref`'s value and returns the original value."""
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `swap` on a `Ref`: {ref}")
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
return swap_p.bind(ref, value, *non_slice_idx, indexed_dims=indexed_dims)
def ref_set(ref: Ref, idx: Indexer, value: Array) -> None:
def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None:
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
ref_swap(ref, idx, value)
@ -139,81 +154,108 @@ def ref_set(ref: Ref, idx: Indexer, value: Array) -> None:
addupdate_p = core.Primitive('addupdate')
addupdate_p.multiple_results = True
def _addupdate_impl(ref: Ref, value: Array, *idx: int):
def _addupdate_impl(ref: AbstractRef, value: Array, *idx: int):
del ref, idx, value
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
addupdate_p.def_impl(_addupdate_impl)
def ref_addupdate(ref: Ref, idx: Indexer, x: Array) -> None:
def ref_addupdate(ref: AbstractRef, idx: Indexer, x: Array) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
ref_aval = core.get_aval(ref)
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"Can only call `addupdate` on a `Ref`: {ref}")
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
return addupdate_p.bind(ref, x, *non_slice_idx, indexed_dims=indexed_dims)
## get/set/addupdate abstract evaluation rules
def _get_abstract_eval(ref_aval: ShapedArrayRef, *idx, indexed_dims):
if not isinstance(ref_aval, ShapedArrayRef):
def _get_abstract_eval(ref_aval: AbstractRef, *idx,
indexed_dims):
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
idx_shapes = tuple(i.shape for i in idx)
shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
return (core.ShapedArray(shape, ref_aval.dtype), {ReadEffect(0)})
if isinstance(ref_aval.inner_aval, core.ShapedArray):
if not isinstance(ref_aval.inner_aval, core.ShapedArray):
raise ValueError("`get` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
idx_shapes = tuple(i.shape for i in idx)
shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
out_aval = ref_aval.inner_aval.update(shape=shape)
else:
if idx:
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
out_aval = ref_aval.inner_aval
return (out_aval, {ReadEffect(0)})
get_p.def_effectful_abstract_eval(_get_abstract_eval)
def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue,
def _swap_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
*idx: core.ShapedArray, indexed_dims: Tuple[bool]):
if not isinstance(ref_aval, ShapedArrayRef):
out_aval: core.AbstractValue
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
idx_shapes = tuple(i.shape for i in idx)
expected_output_shape = _get_slice_output_shape(
ref_aval.shape, idx_shapes, indexed_dims)
if expected_output_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `swap`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
return (core.ShapedArray(expected_output_shape, ref_aval.dtype),
{WriteEffect(0)})
if isinstance(ref_aval.inner_aval, core.ShapedArray):
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
idx_shapes = tuple(i.shape for i in idx)
expected_output_shape = _get_slice_output_shape(
ref_aval.shape, idx_shapes, indexed_dims)
if expected_output_shape != val_aval.shape:
raise ValueError("Invalid shape for `swap`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `swap`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
out_aval = core.ShapedArray(expected_output_shape, ref_aval.dtype)
else:
if idx:
raise ValueError("`swap` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
out_aval = ref_aval.inner_aval
return (out_aval, {WriteEffect(0)})
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
def _addupdate_abstract_eval(ref_aval: ShapedArrayRef,
def _addupdate_abstract_eval(ref_aval: AbstractRef,
val_aval: core.AbstractValue,
*idx: core.ShapedArray, indexed_dims: Tuple[bool]):
if not isinstance(ref_aval, ShapedArrayRef):
if not isinstance(ref_aval, AbstractRef):
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
idx_shapes = tuple(i.shape for i in idx)
slice_shape = _get_slice_output_shape(
ref_aval.shape, idx_shapes, indexed_dims)
if slice_shape != val_aval.shape:
raise ValueError("Invalid shape for `addupdate`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `addupdate`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
if idx and not isinstance(ref_aval.inner_aval, core.ShapedArray):
raise ValueError("`addupdate` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
if isinstance(ref_aval.inner_aval, core.ShapedArray):
if len(indexed_dims) != len(ref_aval.shape):
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
if sum(indexed_dims) != len(idx):
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
val_aval = core.raise_to_shaped(val_aval)
assert isinstance(val_aval, core.ShapedArray)
idx_shapes = tuple(i.shape for i in idx)
slice_shape = _get_slice_output_shape(
ref_aval.shape, idx_shapes, indexed_dims)
if slice_shape != val_aval.shape:
raise ValueError("Invalid shape for `addupdate`. "
f"Ref shape: {ref_aval.shape}. "
f"Value shape: {val_aval.shape}. "
f"Indices: {idx}. ")
if ref_aval.dtype != val_aval.dtype:
raise ValueError("Invalid dtype for `addupdate`. "
f"Ref dtype: {ref_aval.dtype}. "
f"Value shape: {val_aval.dtype}. ")
elif idx:
raise ValueError("`addupdate` with nontrivial indexing must be called "
f"on `ShapedArray` `Ref`: {ref_aval}.")
return [], {AccumEffect(0)}
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
@ -279,18 +321,18 @@ core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
def _get_jvp(primals: List[Any], tangents: List[Any], **params: Any):
ref_primal, *idx = primals
assert isinstance(ref_primal.aval, ShapedArrayRef)
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, ShapedArrayRef)
assert isinstance(ref_tangent.aval, AbstractRef)
return (get_p.bind(ref_primal, *idx, **params),
get_p.bind(ref_tangent, *idx, **params)) # type: ignore[arg-type]
ad.primitive_jvps[get_p] = _get_jvp
def _swap_jvp(primals: List[Any], tangents: List[Any], **params: Any):
ref_primal, x_primal, *idx = primals
assert isinstance(ref_primal.aval, ShapedArrayRef)
assert isinstance(ref_primal.aval, AbstractRef)
ref_tangent, x_tangent, *_ = tangents
assert isinstance(ref_tangent.aval, ShapedArrayRef)
assert isinstance(ref_tangent.aval, AbstractRef)
x_tangent = ad_util.instantiate(x_tangent)
return (swap_p.bind(ref_primal, x_primal, *idx, **params), # type: ignore[arg-type]
swap_p.bind(ref_tangent, x_tangent, *idx, **params)) # type: ignore[arg-type]

View File

@ -14,13 +14,13 @@
"""Module for state types."""
from __future__ import annotations
from typing import Any, List, Sequence, Set, Union
from typing import Any, Generic, List, Sequence, Set, Tuple, TypeVar, Union
from jax._src import core
from jax._src import effects
from jax._src import pretty_printer as pp
from jax._src.lib import xla_bridge, xla_client
from jax._src.util import safe_map, safe_zip, tuple_insert, tuple_delete, prod
from jax._src.util import safe_map, safe_zip, prod
xc = xla_client
xb = xla_bridge
@ -74,23 +74,34 @@ StateEffect = Union[ReadEffect, WriteEffect, AccumEffect]
# ## `Ref`s
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
# A `ShapedArrayRef` is a abstract value for mutable containers of array types
class ShapedArrayRef(core.AbstractValue):
__slots__ = ["shape", "dtype"]
Aval = TypeVar("Aval", bound=core.AbstractValue)
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
class AbstractRef(core.AbstractValue, Generic[Aval]):
__slots__ = ["inner_aval"]
def __init__(self, inner_aval: core.AbstractValue):
self.inner_aval = inner_aval
def join(self, other):
assert core.symbolic_equal_shape(self.shape, other.shape)
assert self.dtype == other.dtype
return self
assert isinstance(other, AbstractRef)
return AbstractRef(self.inner_aval.join(other.inner_aval))
ndim = property(lambda self: len(self.shape))
size = property(lambda self: prod(self.shape))
@property
def shape(self):
if not isinstance(self.inner_aval, core.ShapedArray):
raise ValueError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.")
return self.inner_aval.shape
@property
def dtype(self):
if not isinstance(self.inner_aval, core.UnshapedArray):
raise ValueError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.")
return self.inner_aval.dtype
@core.aval_method
@staticmethod
def get(tracer, idx=()):
@ -116,35 +127,39 @@ class ShapedArrayRef(core.AbstractValue):
return ref_set(tracer, idx, value)
def __repr__(self) -> str:
a = core.ShapedArray(self.shape, self.dtype)
return f'Ref{{{a.str_short()}}}'
return f'Ref{{{self.inner_aval.str_short()}}}'
def at_least_vspace(self):
return self
return AbstractRef(self.inner_aval.at_least_vspace())
def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape)
return (type(self) is type(other) and self.inner_aval == other.inner_aval)
def __hash__(self):
return hash((self.shape, self.dtype))
return hash((self.__class__, self.inner_aval))
def _ref_raise_to_shaped(ref_aval: AbstractRef, weak_type):
return AbstractRef(core.raise_to_shaped(ref_aval.inner_aval, weak_type))
core.raise_to_shaped_mappings[AbstractRef] = _ref_raise_to_shaped
core.raise_to_shaped_mappings[ShapedArrayRef] = lambda aval, _: aval
def _map_ref(size, axis, ref_aval):
return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval))
def _map_ref(size, axis, aval):
if axis is None: return aval
return ShapedArrayRef(tuple_delete(aval.shape, axis), aval.dtype)
def _unmap_ref(size, axis_name, axis, ref_aval):
return AbstractRef(core.unmapped_aval(size, axis_name, axis,
ref_aval.inner_aval))
def _unmap_ref(size, axis_name, axis, aval):
if axis is None: return aval
return ShapedArrayRef(tuple_insert(aval.shape, axis, size), aval.dtype)
core.aval_mapping_handlers[ShapedArrayRef] = (_map_ref, _unmap_ref)
core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref)
def get_ref_state_effects(
avals: Sequence[core.AbstractValue],
effects: core.Effects) -> List[Set[StateEffect]]:
return [{eff for eff in effects
if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect))
and eff.input_index == i} for i, aval in enumerate(avals)]
and eff.input_index == i} for i, _ in enumerate(avals)]
def shaped_array_ref(shape: Tuple[int, ...], dtype,
weak_type: bool = False,
named_shape = None) -> AbstractRef[core.ShapedArray]:
return AbstractRef(core.ShapedArray(shape, dtype, weak_type=weak_type,
named_shape=named_shape))

View File

@ -88,7 +88,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
)
def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None,
out_dtype=None, should_error=False):
ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
ref_aval = state.AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
def f(x_ref):
out = state.ref_get(x_ref, idx)
return [out]
@ -154,7 +154,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def test_swap_abstract_eval(self, ref_shape, ref_dtype,
val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
should_error=False):
ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
ref_aval = state.AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
val_aval = core.ShapedArray(val_shape, val_dtype)
def f(x_ref, val):
out = state.ref_swap(x_ref, idx, val)
@ -210,7 +210,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def test_addupdate_abstract_eval(self, ref_shape, ref_dtype,
val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
should_error=False):
ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
ref_aval = state.AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
val_aval = core.ShapedArray(val_shape, val_dtype)
def f(x_ref, val):
state.ref_addupdate(x_ref, idx, val)
@ -240,7 +240,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x[()] = jnp.int32(2)
return (x[()],)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
self.assertEqual(jaxpr.eqns[0].primitive, state.swap_p)
@ -253,7 +253,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
state.ref_addupdate(x, (), jnp.int32(1))
return (x[()],)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
self.assertLen(consts, 0)
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
@ -263,14 +263,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x = x_ref[()]
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))
def body(x_ref):
x = x_ref[:, 0]
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32)])
lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32)])
self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False))
def test_set_custom_pretty_printing_rule(self):
@ -278,14 +278,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x_ref[()] = jnp.int32(2)
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))
def body(x_ref, val):
x_ref[:, 0] = val
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False))
@ -294,14 +294,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
x = state.ref_swap(x_ref, (), jnp.int32(2))
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))
def body(x_ref, val):
x = state.ref_swap(x_ref, (slice(None), 0), val)
return [x]
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("c:i32[1], a[:,0] <- a[:,0], b",
jaxpr.pretty_print(use_color=False))
@ -311,7 +311,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
state.ref_addupdate(x_ref, (), jnp.int32(2))
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
@ -319,7 +319,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
state.ref_addupdate(x_ref, (slice(None), 0), val)
return []
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32),
core.ShapedArray((1,), jnp.int32)])
self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False))
@ -333,8 +333,8 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def g(r, rdot):
return jax.jvp(f, (r,), (rdot,))
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
state.ShapedArrayRef((), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
state.shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)
@ -349,8 +349,8 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def g(r, rdot):
return jax.jvp(f, (r,), (rdot,))
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
state.ShapedArrayRef((), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
state.shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)
@ -369,8 +369,8 @@ class StatePrimitivesTest(jtu.JaxTestCase):
def g(r, rdot):
return jax.jvp(f, (r,), (rdot,))
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
state.ShapedArrayRef((), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
state.shaped_array_ref((), jnp.dtype('float32'))]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
self.assertEqual(jaxpr.eqns[1].primitive, state.addupdate_p)
@ -420,8 +420,8 @@ class StatePrimitivesTest(jtu.JaxTestCase):
return tuple_insert(shape, idx, axis_size)
batched_ref_shape = maybe_insert(ref_shape, ref_bdim)
ref_aval = state.ShapedArrayRef(ref_shape, float_)
bat_ref_aval = state.ShapedArrayRef(batched_ref_shape, float_)
ref_aval = state.shaped_array_ref(ref_shape, float_)
bat_ref_aval = state.shaped_array_ref(batched_ref_shape, float_)
idx_avals = [core.ShapedArray(idx_shape, int_)
for _ in idx_bdims]
@ -465,7 +465,7 @@ class StateDischargeTest(jtu.JaxTestCase):
def f(a_ref):
a = state.ref_get(a_ref, ())
return [a + 1]
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
@ -481,7 +481,7 @@ class StateDischargeTest(jtu.JaxTestCase):
def f(a_ref):
a = state.ref_get(a_ref, (0, 1))
return [a + 1]
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
@ -500,7 +500,7 @@ class StateDischargeTest(jtu.JaxTestCase):
def f(a_ref):
a = a_ref[jnp.array([0, 1])]
return [a + 1]
in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), in_avals)
discharged_jaxpr, discharged_consts = state.discharge_state(
@ -514,7 +514,7 @@ class StateDischargeTest(jtu.JaxTestCase):
def f(a_ref, b):
state.ref_set(a_ref, (), b + 1)
return []
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
@ -532,7 +532,7 @@ class StateDischargeTest(jtu.JaxTestCase):
def f(a_ref):
state.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
# Discharging should just turn this into a jaxpr that just adds 1.
@ -552,7 +552,7 @@ class StateDischargeTest(jtu.JaxTestCase):
def f(a_ref):
a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
return []
in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, discharged_consts = state.discharge_state(
@ -565,7 +565,7 @@ class StateDischargeTest(jtu.JaxTestCase):
def f(a_ref, b):
state.ref_addupdate(a_ref, (), b + 1)
return []
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
core.ShapedArray((), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
@ -584,7 +584,7 @@ class StateDischargeTest(jtu.JaxTestCase):
state.ref_addupdate(a_ref, (0, 1),
jnp.ones(2, dtype=jnp.dtype('float32')))
return []
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
@ -605,7 +605,7 @@ class StateDischargeTest(jtu.JaxTestCase):
state.ref_addupdate(a_ref, (jnp.array([0, 1]),),
jnp.ones((2, 3), 'float32'))
return []
in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((4, 3), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, discharged_consts = state.discharge_state(
@ -619,7 +619,7 @@ class StateDischargeTest(jtu.JaxTestCase):
a = state.ref_get(a_ref, ())
b = a + 1
return [a, b]
in_avals = [state.ShapedArrayRef((4,), jnp.dtype('float32'))]
in_avals = [state.shaped_array_ref((4,), jnp.dtype('float32'))]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
@ -637,8 +637,8 @@ class StateDischargeTest(jtu.JaxTestCase):
state.ref_set(b_ref, (), jnp.ones(4, jnp.float32))
return []
in_avals = [
state.ShapedArrayRef((4,), jnp.dtype('float32')),
state.ShapedArrayRef((4,), jnp.dtype('float32'))
state.shaped_array_ref((4,), jnp.dtype('float32')),
state.shaped_array_ref((4,), jnp.dtype('float32'))
]
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
in_avals)
@ -646,7 +646,7 @@ class StateDischargeTest(jtu.JaxTestCase):
stateful_jaxpr, consts, should_discharge=[False, True])
self.assertLen(discharged_jaxpr.invars, 2)
self.assertLen(discharged_jaxpr.outvars, 1)
self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.ShapedArrayRef)
self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.AbstractRef)
self.assertIsInstance(discharged_jaxpr.invars[1].aval, core.ShapedArray)
self.assertEqual(discharged_jaxpr.effects,
{state.WriteEffect(len(discharged_jaxpr.constvars))})
@ -659,7 +659,7 @@ class StateDischargeTest(jtu.JaxTestCase):
ref[...]
return []
in_avals = [state.ShapedArrayRef((), jnp.float32)]
in_avals = [state.shaped_array_ref((), jnp.float32)]
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals)
@ -672,7 +672,7 @@ if CAN_USE_HYPOTHESIS:
Shape = tuple[int, ...]
class IndexParam(NamedTuple):
ref_aval: state.ShapedArrayRef
ref_aval: state.shaped_array_ref
ref_shape: Shape
indexed_dims: list[bool]
idx_avals: tuple[core.ShapedArray, ...]
@ -692,7 +692,7 @@ if CAN_USE_HYPOTHESIS:
slice_shape = (*idx_shape, *sliced_shape)
else:
slice_shape = ref_shape
ref_aval = state.ShapedArrayRef(ref_shape, np.float32)
ref_aval = state.shaped_array_ref(ref_shape, np.float32)
idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in
range(sum(indexed_dims)))
slice_aval = core.ShapedArray(slice_shape, np.float32)
@ -704,7 +704,7 @@ if CAN_USE_HYPOTHESIS:
ref_bdim: Optional[int]
non_slice_idx_bdims: tuple[Optional[int], ...]
slice_bdim: int
bat_ref_aval: state.ShapedArrayRef
bat_ref_aval: state.shaped_array_ref
bat_ref_shape: Shape
bat_non_slice_idx_avals: tuple[core.ShapedArray, ...]
bat_non_slice_idx_shapes: tuple[Shape, ...]
@ -752,7 +752,7 @@ if CAN_USE_HYPOTHESIS:
min_value=0, max_value=len(index_param.slice_shape)))
bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size)
bat_ref_aval = state.ShapedArrayRef(bat_ref_shape, np.float32)
bat_ref_aval = state.shaped_array_ref(bat_ref_shape, np.float32)
bat_non_slice_idx_avals = tuple(
core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes)
bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size)
@ -1025,5 +1025,39 @@ class StateControlFlowTest(jtu.JaxTestCase):
with self.assertRaises(NotImplementedError):
jax.grad(f)(3.)
class GeneralRefTest(jtu.JaxTestCase):
def test_unshaped_ref(self):
def f(x_ref):
x = x_ref[...]
x_ref[...] = x
state.ref_addupdate(x_ref, (), x)
return [x]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [state.AbstractRef(core.UnshapedArray(jnp.int32))])
self.assertIs(type(jaxpr.outvars[0].aval), core.UnshapedArray)
self.assertEqual(jaxpr.outvars[0].aval.dtype, jnp.dtype("int32"))
def test_token(self):
def f(x_ref):
x = x_ref[...]
x_ref[...] = x
state.ref_addupdate(x_ref, (), x)
return [x]
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f), [state.AbstractRef(core.AbstractToken())])
self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken)
def test_ref_of_ref(self):
def f(x_ref_ref):
x_ref = x_ref_ref[...]
return [x_ref]
# Not sure why you'd ever want to do this, but it works!
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(f),
[state.AbstractRef(state.AbstractRef(core.ShapedArray((), jnp.int32)))])
self.assertIs(type(jaxpr.outvars[0].aval), state.AbstractRef)
self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())