mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Refactor `Ref` abstract type to contain other `AbstractValue`s
This commit is contained in:
parent
8d0bdd2670
commit
4960e656af
@ -237,7 +237,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
|
||||
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
|
||||
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
|
||||
if any(isinstance(op_aval, state.ShapedArrayRef) for op_aval in ops_avals):
|
||||
if any(isinstance(op_aval, state.AbstractRef) for op_aval in ops_avals):
|
||||
raise ValueError("Cannot pass `Ref`s into `cond`.")
|
||||
true_jaxpr, false_jaxpr = jaxprs
|
||||
out_tree, false_out_tree = out_trees
|
||||
@ -866,5 +866,5 @@ def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear):
|
||||
new_invals = []
|
||||
for aval in in_avals:
|
||||
new_invals.append(
|
||||
next(ref_val_iter) if isinstance(aval, state.ShapedArrayRef) else None)
|
||||
next(ref_val_iter) if isinstance(aval, state.AbstractRef) else None)
|
||||
return new_invals, out_vals
|
||||
|
@ -54,7 +54,7 @@ ReadEffect = state.ReadEffect
|
||||
WriteEffect = state.WriteEffect
|
||||
AccumEffect = state.AccumEffect
|
||||
StateEffect = state.StateEffect
|
||||
ShapedArrayRef = state.ShapedArrayRef
|
||||
AbstractRef = state.AbstractRef
|
||||
ref_set = state.ref_set
|
||||
ref_get = state.ref_get
|
||||
ref_addupdate = state.ref_addupdate
|
||||
@ -70,10 +70,10 @@ for_p.multiple_results = True
|
||||
|
||||
def _hoist_consts_to_refs(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
all_const_avals = [var.aval for var in jaxpr.constvars]
|
||||
is_const_ref = [isinstance(var.aval, ShapedArrayRef) for var in
|
||||
is_const_ref = [isinstance(var.aval, AbstractRef) for var in
|
||||
jaxpr.constvars]
|
||||
const_avals, const_ref_avals = partition_list(is_const_ref, all_const_avals)
|
||||
const_avals = [ShapedArrayRef(aval.shape, aval.dtype) for aval in const_avals] # pytype: disable=attribute-error
|
||||
const_avals = map(AbstractRef, const_avals)
|
||||
merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
|
||||
i_aval, *arg_avals = [var.aval for var in jaxpr.invars]
|
||||
in_avals = [i_aval, *merged_const_avals, *arg_avals]
|
||||
@ -100,11 +100,11 @@ def _trace_to_jaxpr_with_refs(f, state_tree: PyTreeDef,
|
||||
f, state_avals)
|
||||
return jaxpr, consts, out_tree_thunk()
|
||||
|
||||
def val_to_ref_aval(x) -> ShapedArrayRef:
|
||||
def val_to_ref_aval(x) -> AbstractRef:
|
||||
aval = core.raise_to_shaped(core.get_aval(x))
|
||||
if type(aval) is not core.ShapedArray:
|
||||
raise Exception(f"can't make ref from {x}")
|
||||
return ShapedArrayRef(aval.shape, aval.dtype)
|
||||
return AbstractRef(aval)
|
||||
|
||||
def for_loop(nsteps: Union[int, Sequence[int]],
|
||||
body: Callable[[Array, Ref[S]], None], init_state: S,
|
||||
@ -252,7 +252,7 @@ def _for_abstract_eval(*avals, jaxpr, **__):
|
||||
aval_effects = [set(eff.replace(input_index=eff.input_index - 1)
|
||||
for eff in effs) for aval, effs
|
||||
in zip(avals, jaxpr_aval_effects)
|
||||
if isinstance(aval, ShapedArrayRef)]
|
||||
if isinstance(aval, AbstractRef)]
|
||||
nonlocal_state_effects = core.join_effects(*aval_effects)
|
||||
return list(avals), nonlocal_state_effects
|
||||
|
||||
@ -266,7 +266,7 @@ def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr,
|
||||
unroll=unroll)
|
||||
new_invals = []
|
||||
for aval, out_val in zip(in_avals, out_vals):
|
||||
new_invals.append(out_val if isinstance(aval, ShapedArrayRef) else None)
|
||||
new_invals.append(out_val if isinstance(aval, AbstractRef) else None)
|
||||
return new_invals, out_vals
|
||||
|
||||
def _for_impl(*args, jaxpr, nsteps, reverse, which_linear, unroll):
|
||||
@ -661,10 +661,11 @@ def _convert_outputs_to_writes(
|
||||
res_ref[i] = res_val
|
||||
return []
|
||||
# TODO(mattjj, sharadmv): better handling of tokens, which don't have shape/dtype
|
||||
res_ref_avals = [ShapedArrayRef(v.aval.shape, v.aval.dtype) # pytype: disable=attribute-error
|
||||
if loop_invar else
|
||||
ShapedArrayRef((nsteps, *v.aval.shape), v.aval.dtype) # pytype: disable=attribute-error
|
||||
for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
|
||||
res_ref_avals: List[core.AbstractValue] = [
|
||||
AbstractRef(v.aval) if loop_invar else # pytype: disable=attribute-error
|
||||
AbstractRef(core.ShapedArray((nsteps, *v.aval.shape), # pytype: disable=attribute-error
|
||||
v.aval.dtype)) # pytype: disable=attribute-error
|
||||
for v, loop_invar in zip(jaxpr.outvars, loop_invar_res)]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
eval_jaxpr, [*in_avals, *res_ref_avals])
|
||||
assert not consts
|
||||
@ -685,9 +686,11 @@ def _convert_inputs_to_reads(
|
||||
|
||||
res_val_avals, (i_aval,), orig_ref_avals = \
|
||||
split_list([v.aval for v in jaxpr.invars], [num_res, 1])
|
||||
res_ref_avals = [ShapedArrayRef(aval.shape, aval.dtype) if loop_invar else
|
||||
ShapedArrayRef((nsteps, *aval.shape), aval.dtype) # pytype: disable=attribute-error
|
||||
for aval, loop_invar in zip(res_val_avals, loop_invar_res)]
|
||||
res_ref_avals: List[core.AbstractValue] = [
|
||||
AbstractRef(aval) if loop_invar else # pytype: disable=attribute-error
|
||||
AbstractRef(core.ShapedArray((nsteps, *aval.shape), # pytype: disable=attribute-error
|
||||
aval.dtype)) # pytype: disable=attribute-error
|
||||
for aval, loop_invar in zip(res_val_avals, loop_invar_res)]
|
||||
|
||||
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
|
||||
eval_jaxpr, [i_aval, *res_ref_avals, *orig_ref_avals])
|
||||
|
@ -12,9 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for state."""
|
||||
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
|
||||
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
|
||||
AccumEffect, StateEffect, RefEffect,
|
||||
get_ref_state_effects)
|
||||
get_ref_state_effects, shaped_array_ref)
|
||||
from jax._src.state.primitives import (ref_get, ref_set, ref_swap,
|
||||
ref_addupdate, get_p, swap_p,
|
||||
addupdate_p)
|
||||
|
@ -25,7 +25,7 @@ from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.state.types import ShapedArrayRef
|
||||
from jax._src.state.types import AbstractRef
|
||||
from jax._src.state.primitives import get_p, swap_p, addupdate_p
|
||||
from jax._src.util import safe_map, safe_zip, split_list
|
||||
|
||||
@ -47,8 +47,8 @@ def discharge_state(jaxpr: core.Jaxpr, consts: Sequence[Any], * ,
|
||||
"""Converts a jaxpr that takes in `Ref`s into one that doesn't."""
|
||||
if isinstance(should_discharge, bool):
|
||||
should_discharge = [should_discharge] * len(jaxpr.invars)
|
||||
in_avals = [core.ShapedArray(v.aval.shape, v.aval.dtype)
|
||||
if type(v.aval) is ShapedArrayRef and d
|
||||
in_avals = [v.aval.inner_aval
|
||||
if type(v.aval) is AbstractRef and d
|
||||
else v.aval for v, d in zip(jaxpr.invars, should_discharge)]
|
||||
eval_jaxpr = lu.wrap_init(partial(_eval_jaxpr_discharge_state, jaxpr,
|
||||
should_discharge, consts))
|
||||
@ -83,7 +83,7 @@ def register_discharge_rule(prim: core.Primitive):
|
||||
return register
|
||||
|
||||
def _has_refs(eqn: core.JaxprEqn):
|
||||
return any(isinstance(v.aval, ShapedArrayRef) for v in eqn.invars)
|
||||
return any(isinstance(v.aval, AbstractRef) for v in eqn.invars)
|
||||
|
||||
def _eval_jaxpr_discharge_state(
|
||||
jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any],
|
||||
@ -97,7 +97,7 @@ def _eval_jaxpr_discharge_state(
|
||||
|
||||
refs_to_discharge = set(id(v.aval) for v, d
|
||||
in zip(jaxpr.invars, should_discharge) if d
|
||||
and isinstance(v.aval, ShapedArrayRef))
|
||||
and isinstance(v.aval, AbstractRef))
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
if _has_refs(eqn) and any(id(v.aval) in refs_to_discharge
|
||||
@ -237,7 +237,7 @@ def _closed_call_discharge_rule(
|
||||
call_jaxpr=discharged_closed_jaxpr)
|
||||
out_vals, ref_vals = split_list(out_and_ref_vals, [num_outs])
|
||||
ref_vals_iter = iter(ref_vals)
|
||||
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, ShapedArrayRef)
|
||||
new_invals = tuple(next(ref_vals_iter) if isinstance(aval, AbstractRef)
|
||||
else None for aval in in_avals)
|
||||
assert next(ref_vals_iter, None) is None
|
||||
return new_invals, out_vals
|
||||
|
@ -14,7 +14,7 @@
|
||||
"""Module for state primitives."""
|
||||
from functools import partial
|
||||
|
||||
from typing import Any, List, Protocol, Tuple, TypeVar, Union
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -27,24 +27,13 @@ from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src.typing import Array
|
||||
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
|
||||
from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect,
|
||||
AccumEffect)
|
||||
from jax._src.util import safe_map, safe_zip, partition_list, tuple_insert
|
||||
|
||||
|
||||
## General utilities
|
||||
|
||||
T = TypeVar('T')
|
||||
class Ref(Protocol):
|
||||
|
||||
@property
|
||||
def shape(self) -> Tuple[int, ...]:
|
||||
...
|
||||
|
||||
@property
|
||||
def ndim(self) -> int:
|
||||
...
|
||||
|
||||
## JAX utilities
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
@ -57,11 +46,11 @@ zip, unsafe_zip = safe_zip, zip
|
||||
# or we can read using indices:
|
||||
# a = get_p.bind(x, 0, 1)
|
||||
# Staging out `a = get_p.bind(x)` where the aval of `x` is
|
||||
# `ShapedArrayRef((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
|
||||
# `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
|
||||
# a:f32[3] <- x[]
|
||||
get_p = core.Primitive("get")
|
||||
|
||||
def _get_impl(ref: Ref, *idx: int, **_):
|
||||
def _get_impl(ref: AbstractRef, *idx: int, **_):
|
||||
del ref, idx
|
||||
raise ValueError("Cannot run stateful primitive.")
|
||||
get_p.def_impl(_get_impl)
|
||||
@ -69,9 +58,18 @@ get_p.def_impl(_get_impl)
|
||||
Indexer = Tuple[Union[int, slice, Array], ...]
|
||||
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
|
||||
|
||||
def _is_trivial_indexer(idx: Indexer) -> bool:
  """Returns whether `idx` addresses an entire `Ref`.

  An indexer counts as trivial when it is a bare `...`, the empty tuple `()`,
  or the one-element tuple `(...,)`. Anything else is nontrivial.
  """
  if idx is ...:
    return True
  # Exact type check (not `isinstance`): only plain tuples are inspected.
  if type(idx) is not tuple:
    return False
  if not idx:
    return True
  return len(idx) == 1 and idx[0] is ...
|
||||
|
||||
def _unpack_idx(idx: Indexer, ndim: int
|
||||
) -> Tuple[Tuple[Array, ...], Tuple[bool, ...]]:
|
||||
if idx is ... or (type(idx) is tuple and len(idx) == 1 and idx[0] is ...):
|
||||
if _is_trivial_indexer(idx):
|
||||
idx = tuple(slice(None) for _ in range(ndim))
|
||||
indexed_dims_ = [type(i) != slice for i in idx]
|
||||
_, non_slice_idx = partition_list(indexed_dims_, idx)
|
||||
@ -88,9 +86,23 @@ def _get_slice_output_shape(in_shape: Tuple[int, ...],
|
||||
shape = (*shape_prefix, *shape_suffix)
|
||||
return shape
|
||||
|
||||
def ref_get(ref: Ref, idx: Indexer) -> Array:
|
||||
def _get_indexer(ref: AbstractRef, idx: Indexer
                 ) -> Tuple[Indexer, Tuple[bool, ...]]:
  """Turns `idx` into the `(non_slice_idx, indexed_dims)` pair for `ref`.

  When `ref.inner_aval` is a `core.ShapedArray`, the indexer is unpacked by
  `_unpack_idx` against `ref.ndim`. For any other inner aval only a trivial
  indexer is accepted, and a pair of empty tuples is returned.

  Raises:
    ValueError: if `idx` is nontrivial and `ref.inner_aval` is not a
      `core.ShapedArray`.
  """
  if isinstance(ref.inner_aval, core.ShapedArray):
    unpacked_idx, dims_mask = _unpack_idx(idx, ref.ndim)
    return unpacked_idx, dims_mask
  # Non-shaped inner avals have no dimensions to slice into.
  if not _is_trivial_indexer(idx):
    raise ValueError(
        f"Cannot use nontrivial slice on non-shaped `Ref`: {idx}.")
  return (), ()
|
||||
|
||||
def ref_get(ref: Any, idx: Indexer) -> Array:
|
||||
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
|
||||
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
|
||||
ref_aval = core.get_aval(ref)
|
||||
if not isinstance(ref_aval, AbstractRef):
|
||||
raise ValueError(f"Can only call `get` on a `Ref`: {ref}")
|
||||
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
|
||||
return get_p.bind(ref, *non_slice_idx, indexed_dims=indexed_dims)
|
||||
|
||||
# `swap` mutates a `Ref`, setting its value and returns its previous value.
|
||||
@ -101,27 +113,30 @@ def ref_get(ref: Ref, idx: Indexer) -> Array:
|
||||
# `swap_p` also takes in index arguments following the value, i.e.:
|
||||
# _ = swap_p.bind(x, a, 0, 1)
|
||||
# Staging out `b = swap_p.bind(x, a)` where the aval of `x` is
|
||||
# `ShapedArrayRef((3,), np.dtype('float32'))` and the aval of `a` is
|
||||
# `Ref((3,), np.dtype('float32'))` and the aval of `a` is
|
||||
# `ShapedArray((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
|
||||
# b:f32[3], x:Ref{f32[3]} <- x, a
|
||||
# Staging out `_ = swap_p.bind(x, a, i, j)` where the aval of `x` is
|
||||
# `ShapedArrayRef((3,), np.dtype('float32'))` , the aval of `a` is
|
||||
# `Ref((3,), np.dtype('float32'))` , the aval of `a` is
|
||||
# `ShapedArray((3,), np.dtype('float32'))`, and the avals of both `i` and `j`
|
||||
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
|
||||
# x:Ref{f32[3]}[i, j] <- a
|
||||
swap_p = core.Primitive("swap")
|
||||
|
||||
def _swap_impl(ref: Ref, value: Array, *idx: int, **_):
|
||||
def _swap_impl(ref: AbstractRef, value: Array, *idx: int, **_):
|
||||
del ref, value, idx
|
||||
raise ValueError("Cannot run stateful primitive.")
|
||||
swap_p.def_impl(_swap_impl)
|
||||
|
||||
def ref_swap(ref: Ref, idx: Indexer, value: Array) -> Array:
|
||||
def ref_swap(ref: AbstractRef, idx: Indexer, value: Array) -> Array:
|
||||
"""Sets a `Ref`'s value and returns the original value."""
|
||||
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
|
||||
ref_aval = core.get_aval(ref)
|
||||
if not isinstance(ref_aval, AbstractRef):
|
||||
raise ValueError(f"Can only call `swap` on a `Ref`: {ref}")
|
||||
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
|
||||
return swap_p.bind(ref, value, *non_slice_idx, indexed_dims=indexed_dims)
|
||||
|
||||
def ref_set(ref: Ref, idx: Indexer, value: Array) -> None:
|
||||
def ref_set(ref: AbstractRef, idx: Indexer, value: Array) -> None:
|
||||
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
|
||||
ref_swap(ref, idx, value)
|
||||
|
||||
@ -139,81 +154,108 @@ def ref_set(ref: Ref, idx: Indexer, value: Array) -> None:
|
||||
addupdate_p = core.Primitive('addupdate')
|
||||
addupdate_p.multiple_results = True
|
||||
|
||||
def _addupdate_impl(ref: Ref, value: Array, *idx: int):
|
||||
def _addupdate_impl(ref: AbstractRef, value: Array, *idx: int):
|
||||
del ref, idx, value
|
||||
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
|
||||
addupdate_p.def_impl(_addupdate_impl)
|
||||
|
||||
def ref_addupdate(ref: Ref, idx: Indexer, x: Array) -> None:
|
||||
def ref_addupdate(ref: AbstractRef, idx: Indexer, x: Array) -> None:
|
||||
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
|
||||
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
|
||||
ref_aval = core.get_aval(ref)
|
||||
if not isinstance(ref_aval, AbstractRef):
|
||||
raise ValueError(f"Can only call `addupdate` on a `Ref`: {ref}")
|
||||
non_slice_idx, indexed_dims = _get_indexer(ref, idx)
|
||||
return addupdate_p.bind(ref, x, *non_slice_idx, indexed_dims=indexed_dims)
|
||||
|
||||
## get/set/addupdate abstract evaluation rules
|
||||
|
||||
def _get_abstract_eval(ref_aval: ShapedArrayRef, *idx, indexed_dims):
|
||||
if not isinstance(ref_aval, ShapedArrayRef):
|
||||
def _get_abstract_eval(ref_aval: AbstractRef, *idx,
|
||||
indexed_dims):
|
||||
if not isinstance(ref_aval, AbstractRef):
|
||||
raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.")
|
||||
if len(indexed_dims) != len(ref_aval.shape):
|
||||
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
|
||||
if sum(indexed_dims) != len(idx):
|
||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
||||
idx_shapes = tuple(i.shape for i in idx)
|
||||
shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
|
||||
return (core.ShapedArray(shape, ref_aval.dtype), {ReadEffect(0)})
|
||||
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||
if not isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||
raise ValueError("`get` with nontrivial indexing must be called "
|
||||
f"on `ShapedArray` `Ref`: {ref_aval}.")
|
||||
if len(indexed_dims) != len(ref_aval.shape):
|
||||
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
|
||||
if sum(indexed_dims) != len(idx):
|
||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
||||
idx_shapes = tuple(i.shape for i in idx)
|
||||
shape = _get_slice_output_shape(ref_aval.shape, idx_shapes, indexed_dims)
|
||||
out_aval = ref_aval.inner_aval.update(shape=shape)
|
||||
else:
|
||||
if idx:
|
||||
raise ValueError("Cannot index non-shaped array with nontrivial indices.")
|
||||
out_aval = ref_aval.inner_aval
|
||||
return (out_aval, {ReadEffect(0)})
|
||||
get_p.def_effectful_abstract_eval(_get_abstract_eval)
|
||||
|
||||
|
||||
def _swap_abstract_eval(ref_aval: ShapedArrayRef, val_aval: core.AbstractValue,
|
||||
def _swap_abstract_eval(ref_aval: AbstractRef,
|
||||
val_aval: core.AbstractValue,
|
||||
*idx: core.ShapedArray, indexed_dims: Tuple[bool]):
|
||||
if not isinstance(ref_aval, ShapedArrayRef):
|
||||
out_aval: core.AbstractValue
|
||||
if not isinstance(ref_aval, AbstractRef):
|
||||
raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.")
|
||||
if len(indexed_dims) != len(ref_aval.shape):
|
||||
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
|
||||
if sum(indexed_dims) != len(idx):
|
||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
||||
val_aval = core.raise_to_shaped(val_aval)
|
||||
assert isinstance(val_aval, core.ShapedArray)
|
||||
idx_shapes = tuple(i.shape for i in idx)
|
||||
expected_output_shape = _get_slice_output_shape(
|
||||
ref_aval.shape, idx_shapes, indexed_dims)
|
||||
if expected_output_shape != val_aval.shape:
|
||||
raise ValueError("Invalid shape for `swap`. "
|
||||
f"Ref shape: {ref_aval.shape}. "
|
||||
f"Value shape: {val_aval.shape}. "
|
||||
f"Indices: {idx}. ")
|
||||
if ref_aval.dtype != val_aval.dtype:
|
||||
raise ValueError("Invalid dtype for `swap`. "
|
||||
f"Ref dtype: {ref_aval.dtype}. "
|
||||
f"Value shape: {val_aval.dtype}. ")
|
||||
return (core.ShapedArray(expected_output_shape, ref_aval.dtype),
|
||||
{WriteEffect(0)})
|
||||
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||
if len(indexed_dims) != len(ref_aval.shape):
|
||||
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
|
||||
if sum(indexed_dims) != len(idx):
|
||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
||||
val_aval = core.raise_to_shaped(val_aval)
|
||||
assert isinstance(val_aval, core.ShapedArray)
|
||||
idx_shapes = tuple(i.shape for i in idx)
|
||||
expected_output_shape = _get_slice_output_shape(
|
||||
ref_aval.shape, idx_shapes, indexed_dims)
|
||||
if expected_output_shape != val_aval.shape:
|
||||
raise ValueError("Invalid shape for `swap`. "
|
||||
f"Ref shape: {ref_aval.shape}. "
|
||||
f"Value shape: {val_aval.shape}. "
|
||||
f"Indices: {idx}. ")
|
||||
if ref_aval.dtype != val_aval.dtype:
|
||||
raise ValueError("Invalid dtype for `swap`. "
|
||||
f"Ref dtype: {ref_aval.dtype}. "
|
||||
f"Value shape: {val_aval.dtype}. ")
|
||||
out_aval = core.ShapedArray(expected_output_shape, ref_aval.dtype)
|
||||
else:
|
||||
if idx:
|
||||
raise ValueError("`swap` with nontrivial indexing must be called "
|
||||
f"on `ShapedArray` `Ref`: {ref_aval}.")
|
||||
out_aval = ref_aval.inner_aval
|
||||
return (out_aval, {WriteEffect(0)})
|
||||
swap_p.def_effectful_abstract_eval(_swap_abstract_eval)
|
||||
|
||||
|
||||
def _addupdate_abstract_eval(ref_aval: ShapedArrayRef,
|
||||
def _addupdate_abstract_eval(ref_aval: AbstractRef,
|
||||
val_aval: core.AbstractValue,
|
||||
*idx: core.ShapedArray, indexed_dims: Tuple[bool]):
|
||||
if not isinstance(ref_aval, ShapedArrayRef):
|
||||
if not isinstance(ref_aval, AbstractRef):
|
||||
raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.")
|
||||
if len(indexed_dims) != len(ref_aval.shape):
|
||||
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
|
||||
if sum(indexed_dims) != len(idx):
|
||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
||||
val_aval = core.raise_to_shaped(val_aval)
|
||||
assert isinstance(val_aval, core.ShapedArray)
|
||||
idx_shapes = tuple(i.shape for i in idx)
|
||||
slice_shape = _get_slice_output_shape(
|
||||
ref_aval.shape, idx_shapes, indexed_dims)
|
||||
if slice_shape != val_aval.shape:
|
||||
raise ValueError("Invalid shape for `addupdate`. "
|
||||
f"Ref shape: {ref_aval.shape}. "
|
||||
f"Value shape: {val_aval.shape}. "
|
||||
f"Indices: {idx}. ")
|
||||
if ref_aval.dtype != val_aval.dtype:
|
||||
raise ValueError("Invalid dtype for `addupdate`. "
|
||||
f"Ref dtype: {ref_aval.dtype}. "
|
||||
f"Value shape: {val_aval.dtype}. ")
|
||||
if idx and not isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||
raise ValueError("`addupdate` with nontrivial indexing must be called "
|
||||
f"on `ShapedArray` `Ref`: {ref_aval}.")
|
||||
if isinstance(ref_aval.inner_aval, core.ShapedArray):
|
||||
if len(indexed_dims) != len(ref_aval.shape):
|
||||
raise ValueError("`indexed_dims` must be the same length as `Ref` shape.")
|
||||
if sum(indexed_dims) != len(idx):
|
||||
raise ValueError(f"Invalid `idx` and `indexed_dims`: {idx}, {indexed_dims}")
|
||||
val_aval = core.raise_to_shaped(val_aval)
|
||||
assert isinstance(val_aval, core.ShapedArray)
|
||||
idx_shapes = tuple(i.shape for i in idx)
|
||||
slice_shape = _get_slice_output_shape(
|
||||
ref_aval.shape, idx_shapes, indexed_dims)
|
||||
if slice_shape != val_aval.shape:
|
||||
raise ValueError("Invalid shape for `addupdate`. "
|
||||
f"Ref shape: {ref_aval.shape}. "
|
||||
f"Value shape: {val_aval.shape}. "
|
||||
f"Indices: {idx}. ")
|
||||
if ref_aval.dtype != val_aval.dtype:
|
||||
raise ValueError("Invalid dtype for `addupdate`. "
|
||||
f"Ref dtype: {ref_aval.dtype}. "
|
||||
f"Value shape: {val_aval.dtype}. ")
|
||||
elif idx:
|
||||
raise ValueError("`addupdate` with nontrivial indexing must be called "
|
||||
f"on `ShapedArray` `Ref`: {ref_aval}.")
|
||||
return [], {AccumEffect(0)}
|
||||
addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval)
|
||||
|
||||
@ -279,18 +321,18 @@ core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule
|
||||
|
||||
def _get_jvp(primals: List[Any], tangents: List[Any], **params: Any):
|
||||
ref_primal, *idx = primals
|
||||
assert isinstance(ref_primal.aval, ShapedArrayRef)
|
||||
assert isinstance(ref_primal.aval, AbstractRef)
|
||||
ref_tangent, *_ = tangents
|
||||
assert isinstance(ref_tangent.aval, ShapedArrayRef)
|
||||
assert isinstance(ref_tangent.aval, AbstractRef)
|
||||
return (get_p.bind(ref_primal, *idx, **params),
|
||||
get_p.bind(ref_tangent, *idx, **params)) # type: ignore[arg-type]
|
||||
ad.primitive_jvps[get_p] = _get_jvp
|
||||
|
||||
def _swap_jvp(primals: List[Any], tangents: List[Any], **params: Any):
|
||||
ref_primal, x_primal, *idx = primals
|
||||
assert isinstance(ref_primal.aval, ShapedArrayRef)
|
||||
assert isinstance(ref_primal.aval, AbstractRef)
|
||||
ref_tangent, x_tangent, *_ = tangents
|
||||
assert isinstance(ref_tangent.aval, ShapedArrayRef)
|
||||
assert isinstance(ref_tangent.aval, AbstractRef)
|
||||
x_tangent = ad_util.instantiate(x_tangent)
|
||||
return (swap_p.bind(ref_primal, x_primal, *idx, **params), # type: ignore[arg-type]
|
||||
swap_p.bind(ref_tangent, x_tangent, *idx, **params)) # type: ignore[arg-type]
|
||||
|
@ -14,13 +14,13 @@
|
||||
"""Module for state types."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Sequence, Set, Union
|
||||
from typing import Any, Generic, List, Sequence, Set, Tuple, TypeVar, Union
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import effects
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src.lib import xla_bridge, xla_client
|
||||
from jax._src.util import safe_map, safe_zip, tuple_insert, tuple_delete, prod
|
||||
from jax._src.util import safe_map, safe_zip, prod
|
||||
|
||||
xc = xla_client
|
||||
xb = xla_bridge
|
||||
@ -74,23 +74,34 @@ StateEffect = Union[ReadEffect, WriteEffect, AccumEffect]
|
||||
|
||||
# ## `Ref`s
|
||||
|
||||
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
|
||||
# A `ShapedArrayRef` is a abstract value for mutable containers of array types
|
||||
class ShapedArrayRef(core.AbstractValue):
|
||||
__slots__ = ["shape", "dtype"]
|
||||
Aval = TypeVar("Aval", bound=core.AbstractValue)
|
||||
|
||||
def __init__(self, shape, dtype):
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
# We need an aval for `Ref`s so we can represent `get` and `swap` in Jaxprs.
|
||||
class AbstractRef(core.AbstractValue, Generic[Aval]):
|
||||
__slots__ = ["inner_aval"]
|
||||
|
||||
def __init__(self, inner_aval: core.AbstractValue):
|
||||
self.inner_aval = inner_aval
|
||||
|
||||
def join(self, other):
|
||||
assert core.symbolic_equal_shape(self.shape, other.shape)
|
||||
assert self.dtype == other.dtype
|
||||
return self
|
||||
assert isinstance(other, AbstractRef)
|
||||
return AbstractRef(self.inner_aval.join(other.inner_aval))
|
||||
|
||||
ndim = property(lambda self: len(self.shape))
|
||||
size = property(lambda self: prod(self.shape))
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
if not isinstance(self.inner_aval, core.ShapedArray):
|
||||
raise ValueError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.")
|
||||
return self.inner_aval.shape
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
if not isinstance(self.inner_aval, core.UnshapedArray):
|
||||
raise ValueError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.")
|
||||
return self.inner_aval.dtype
|
||||
|
||||
@core.aval_method
|
||||
@staticmethod
|
||||
def get(tracer, idx=()):
|
||||
@ -116,35 +127,39 @@ class ShapedArrayRef(core.AbstractValue):
|
||||
return ref_set(tracer, idx, value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
a = core.ShapedArray(self.shape, self.dtype)
|
||||
return f'Ref{{{a.str_short()}}}'
|
||||
return f'Ref{{{self.inner_aval.str_short()}}}'
|
||||
|
||||
def at_least_vspace(self):
|
||||
return self
|
||||
return AbstractRef(self.inner_aval.at_least_vspace())
|
||||
|
||||
def __eq__(self, other):
|
||||
return (type(self) is type(other)
|
||||
and self.dtype == other.dtype and self.shape == other.shape)
|
||||
return (type(self) is type(other) and self.inner_aval == other.inner_aval)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.shape, self.dtype))
|
||||
return hash((self.__class__, self.inner_aval))
|
||||
|
||||
def _ref_raise_to_shaped(ref_aval: AbstractRef, weak_type):
  """Returns a new `AbstractRef` around the shaped-raised inner aval.

  The inner aval of `ref_aval` is passed through `core.raise_to_shaped`
  together with `weak_type`, and the result is re-wrapped in an `AbstractRef`.
  """
  shaped_inner = core.raise_to_shaped(ref_aval.inner_aval, weak_type)
  return AbstractRef(shaped_inner)
|
||||
core.raise_to_shaped_mappings[AbstractRef] = _ref_raise_to_shaped
|
||||
|
||||
core.raise_to_shaped_mappings[ShapedArrayRef] = lambda aval, _: aval
|
||||
def _map_ref(size, axis, ref_aval):
|
||||
return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval))
|
||||
|
||||
def _map_ref(size, axis, aval):
|
||||
if axis is None: return aval
|
||||
return ShapedArrayRef(tuple_delete(aval.shape, axis), aval.dtype)
|
||||
def _unmap_ref(size, axis_name, axis, ref_aval):
|
||||
return AbstractRef(core.unmapped_aval(size, axis_name, axis,
|
||||
ref_aval.inner_aval))
|
||||
|
||||
def _unmap_ref(size, axis_name, axis, aval):
|
||||
if axis is None: return aval
|
||||
return ShapedArrayRef(tuple_insert(aval.shape, axis, size), aval.dtype)
|
||||
|
||||
core.aval_mapping_handlers[ShapedArrayRef] = (_map_ref, _unmap_ref)
|
||||
core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref)
|
||||
|
||||
def get_ref_state_effects(
|
||||
avals: Sequence[core.AbstractValue],
|
||||
effects: core.Effects) -> List[Set[StateEffect]]:
|
||||
return [{eff for eff in effects
|
||||
if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect))
|
||||
and eff.input_index == i} for i, aval in enumerate(avals)]
|
||||
and eff.input_index == i} for i, _ in enumerate(avals)]
|
||||
|
||||
def shaped_array_ref(shape: Tuple[int, ...], dtype,
                     weak_type: bool = False,
                     named_shape = None) -> AbstractRef[core.ShapedArray]:
  """Builds an `AbstractRef` whose inner aval is a `core.ShapedArray`.

  All arguments are forwarded unchanged to the `core.ShapedArray` constructor
  (`weak_type` and `named_shape` as keywords).
  """
  inner_aval = core.ShapedArray(shape, dtype, weak_type=weak_type,
                                named_shape=named_shape)
  return AbstractRef(inner_aval)
|
||||
|
@ -88,7 +88,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
)
|
||||
def test_get_abstract_eval(self, ref_shape, ref_dtype, idx, out_shape=None,
|
||||
out_dtype=None, should_error=False):
|
||||
ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
|
||||
ref_aval = state.AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
|
||||
def f(x_ref):
|
||||
out = state.ref_get(x_ref, idx)
|
||||
return [out]
|
||||
@ -154,7 +154,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
def test_swap_abstract_eval(self, ref_shape, ref_dtype,
|
||||
val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
|
||||
should_error=False):
|
||||
ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
|
||||
ref_aval = state.AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
|
||||
val_aval = core.ShapedArray(val_shape, val_dtype)
|
||||
def f(x_ref, val):
|
||||
out = state.ref_swap(x_ref, idx, val)
|
||||
@ -210,7 +210,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
def test_addupdate_abstract_eval(self, ref_shape, ref_dtype,
|
||||
val_shape, val_dtype, idx, out_shape=None, out_dtype=None,
|
||||
should_error=False):
|
||||
ref_aval = state.ShapedArrayRef(ref_shape, ref_dtype)
|
||||
ref_aval = state.AbstractRef(core.ShapedArray(ref_shape, ref_dtype))
|
||||
val_aval = core.ShapedArray(val_shape, val_dtype)
|
||||
def f(x_ref, val):
|
||||
state.ref_addupdate(x_ref, idx, val)
|
||||
@ -240,7 +240,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
x[()] = jnp.int32(2)
|
||||
return (x[()],)
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
|
||||
self.assertLen(consts, 0)
|
||||
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, state.swap_p)
|
||||
@ -253,7 +253,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
state.ref_addupdate(x, (), jnp.int32(1))
|
||||
return (x[()],)
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
|
||||
self.assertLen(consts, 0)
|
||||
self.assertListEqual(out_avals, [core.ShapedArray((), jnp.int32)])
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
|
||||
@ -263,14 +263,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
x = x_ref[()]
|
||||
return [x]
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
|
||||
self.assertIn("b:i32[] <- a[]", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def body(x_ref):
|
||||
x = x_ref[:, 0]
|
||||
return [x]
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32)])
|
||||
lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32)])
|
||||
self.assertIn("b:i32[1] <- a[:,0]", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def test_set_custom_pretty_printing_rule(self):
|
||||
@ -278,14 +278,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
x_ref[()] = jnp.int32(2)
|
||||
return []
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
|
||||
self.assertIn("a[] <- 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def body(x_ref, val):
|
||||
x_ref[:, 0] = val
|
||||
return []
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
|
||||
lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32),
|
||||
core.ShapedArray((1,), jnp.int32)])
|
||||
self.assertIn("a[:,0] <- b", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
@ -294,14 +294,14 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
x = state.ref_swap(x_ref, (), jnp.int32(2))
|
||||
return [x]
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
|
||||
self.assertIn("b:i32[], a[] <- a[], 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
def body(x_ref, val):
|
||||
x = state.ref_swap(x_ref, (slice(None), 0), val)
|
||||
return [x]
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
|
||||
lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32),
|
||||
core.ShapedArray((1,), jnp.int32)])
|
||||
self.assertIn("c:i32[1], a[:,0] <- a[:,0], b",
|
||||
jaxpr.pretty_print(use_color=False))
|
||||
@ -311,7 +311,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
state.ref_addupdate(x_ref, (), jnp.int32(2))
|
||||
return []
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((), jnp.int32)])
|
||||
lu.wrap_init(body), [state.shaped_array_ref((), jnp.int32)])
|
||||
|
||||
self.assertIn("a[] += 2", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
@ -319,7 +319,7 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
state.ref_addupdate(x_ref, (slice(None), 0), val)
|
||||
return []
|
||||
jaxpr, _ , _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(body), [state.ShapedArrayRef((1, 2), jnp.int32),
|
||||
lu.wrap_init(body), [state.shaped_array_ref((1, 2), jnp.int32),
|
||||
core.ShapedArray((1,), jnp.int32)])
|
||||
self.assertIn("a[:,0] += b", jaxpr.pretty_print(use_color=False))
|
||||
|
||||
@ -333,8 +333,8 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
def g(r, rdot):
|
||||
return jax.jvp(f, (r,), (rdot,))
|
||||
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
state.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
|
||||
state.shaped_array_ref((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)
|
||||
@ -349,8 +349,8 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
def g(r, rdot):
|
||||
return jax.jvp(f, (r,), (rdot,))
|
||||
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
state.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
|
||||
state.shaped_array_ref((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, state.get_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, state.get_p)
|
||||
@ -369,8 +369,8 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
def g(r, rdot):
|
||||
return jax.jvp(f, (r,), (rdot,))
|
||||
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
state.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
|
||||
state.shaped_array_ref((), jnp.dtype('float32'))]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(lu.wrap_init(g), in_avals)
|
||||
self.assertEqual(jaxpr.eqns[0].primitive, state.addupdate_p)
|
||||
self.assertEqual(jaxpr.eqns[1].primitive, state.addupdate_p)
|
||||
@ -420,8 +420,8 @@ class StatePrimitivesTest(jtu.JaxTestCase):
|
||||
return tuple_insert(shape, idx, axis_size)
|
||||
|
||||
batched_ref_shape = maybe_insert(ref_shape, ref_bdim)
|
||||
ref_aval = state.ShapedArrayRef(ref_shape, float_)
|
||||
bat_ref_aval = state.ShapedArrayRef(batched_ref_shape, float_)
|
||||
ref_aval = state.shaped_array_ref(ref_shape, float_)
|
||||
bat_ref_aval = state.shaped_array_ref(batched_ref_shape, float_)
|
||||
|
||||
idx_avals = [core.ShapedArray(idx_shape, int_)
|
||||
for _ in idx_bdims]
|
||||
@ -465,7 +465,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
def f(a_ref):
|
||||
a = state.ref_get(a_ref, ())
|
||||
return [a + 1]
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
@ -481,7 +481,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
def f(a_ref):
|
||||
a = state.ref_get(a_ref, (0, 1))
|
||||
return [a + 1]
|
||||
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
@ -500,7 +500,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
def f(a_ref):
|
||||
a = a_ref[jnp.array([0, 1])]
|
||||
return [a + 1]
|
||||
in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((4, 3), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), in_avals)
|
||||
discharged_jaxpr, discharged_consts = state.discharge_state(
|
||||
@ -514,7 +514,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
def f(a_ref, b):
|
||||
state.ref_set(a_ref, (), b + 1)
|
||||
return []
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
|
||||
core.ShapedArray((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
@ -532,7 +532,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
def f(a_ref):
|
||||
state.ref_set(a_ref, (0, 1), jnp.ones(2, dtype=jnp.dtype('float32')))
|
||||
return []
|
||||
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
# Discharging should just turn this into a jaxpr that just adds 1.
|
||||
@ -552,7 +552,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
def f(a_ref):
|
||||
a_ref[jnp.array([0, 1])] = jnp.ones((2, 3), 'float32')
|
||||
return []
|
||||
in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((4, 3), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
discharged_jaxpr, discharged_consts = state.discharge_state(
|
||||
@ -565,7 +565,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
def f(a_ref, b):
|
||||
state.ref_addupdate(a_ref, (), b + 1)
|
||||
return []
|
||||
in_avals = [state.ShapedArrayRef((), jnp.dtype('float32')),
|
||||
in_avals = [state.shaped_array_ref((), jnp.dtype('float32')),
|
||||
core.ShapedArray((), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
@ -584,7 +584,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
state.ref_addupdate(a_ref, (0, 1),
|
||||
jnp.ones(2, dtype=jnp.dtype('float32')))
|
||||
return []
|
||||
in_avals = [state.ShapedArrayRef((4, 3, 2), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((4, 3, 2), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
|
||||
@ -605,7 +605,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
state.ref_addupdate(a_ref, (jnp.array([0, 1]),),
|
||||
jnp.ones((2, 3), 'float32'))
|
||||
return []
|
||||
in_avals = [state.ShapedArrayRef((4, 3), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((4, 3), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
discharged_jaxpr, discharged_consts = state.discharge_state(
|
||||
@ -619,7 +619,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
a = state.ref_get(a_ref, ())
|
||||
b = a + 1
|
||||
return [a, b]
|
||||
in_avals = [state.ShapedArrayRef((4,), jnp.dtype('float32'))]
|
||||
in_avals = [state.shaped_array_ref((4,), jnp.dtype('float32'))]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
discharged_jaxpr, _ = state.discharge_state(stateful_jaxpr, consts)
|
||||
@ -637,8 +637,8 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
state.ref_set(b_ref, (), jnp.ones(4, jnp.float32))
|
||||
return []
|
||||
in_avals = [
|
||||
state.ShapedArrayRef((4,), jnp.dtype('float32')),
|
||||
state.ShapedArrayRef((4,), jnp.dtype('float32'))
|
||||
state.shaped_array_ref((4,), jnp.dtype('float32')),
|
||||
state.shaped_array_ref((4,), jnp.dtype('float32'))
|
||||
]
|
||||
stateful_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
|
||||
in_avals)
|
||||
@ -646,7 +646,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
stateful_jaxpr, consts, should_discharge=[False, True])
|
||||
self.assertLen(discharged_jaxpr.invars, 2)
|
||||
self.assertLen(discharged_jaxpr.outvars, 1)
|
||||
self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.ShapedArrayRef)
|
||||
self.assertIsInstance(discharged_jaxpr.invars[0].aval, state.AbstractRef)
|
||||
self.assertIsInstance(discharged_jaxpr.invars[1].aval, core.ShapedArray)
|
||||
self.assertEqual(discharged_jaxpr.effects,
|
||||
{state.WriteEffect(len(discharged_jaxpr.constvars))})
|
||||
@ -659,7 +659,7 @@ class StateDischargeTest(jtu.JaxTestCase):
|
||||
ref[...]
|
||||
return []
|
||||
|
||||
in_avals = [state.ShapedArrayRef((), jnp.float32)]
|
||||
in_avals = [state.shaped_array_ref((), jnp.float32)]
|
||||
pe.trace_to_jaxpr_dynamic(lu.wrap_init(f), in_avals)
|
||||
|
||||
|
||||
@ -672,7 +672,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
Shape = tuple[int, ...]
|
||||
|
||||
class IndexParam(NamedTuple):
|
||||
ref_aval: state.ShapedArrayRef
|
||||
ref_aval: state.shaped_array_ref
|
||||
ref_shape: Shape
|
||||
indexed_dims: list[bool]
|
||||
idx_avals: tuple[core.ShapedArray, ...]
|
||||
@ -692,7 +692,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
slice_shape = (*idx_shape, *sliced_shape)
|
||||
else:
|
||||
slice_shape = ref_shape
|
||||
ref_aval = state.ShapedArrayRef(ref_shape, np.float32)
|
||||
ref_aval = state.shaped_array_ref(ref_shape, np.float32)
|
||||
idx_avals = tuple(core.ShapedArray(idx_shape, np.int32) for _ in
|
||||
range(sum(indexed_dims)))
|
||||
slice_aval = core.ShapedArray(slice_shape, np.float32)
|
||||
@ -704,7 +704,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
ref_bdim: Optional[int]
|
||||
non_slice_idx_bdims: tuple[Optional[int], ...]
|
||||
slice_bdim: int
|
||||
bat_ref_aval: state.ShapedArrayRef
|
||||
bat_ref_aval: state.shaped_array_ref
|
||||
bat_ref_shape: Shape
|
||||
bat_non_slice_idx_avals: tuple[core.ShapedArray, ...]
|
||||
bat_non_slice_idx_shapes: tuple[Shape, ...]
|
||||
@ -752,7 +752,7 @@ if CAN_USE_HYPOTHESIS:
|
||||
min_value=0, max_value=len(index_param.slice_shape)))
|
||||
|
||||
bat_ref_shape = maybe_tuple_insert(index_param.ref_shape, ref_bdim, axis_size)
|
||||
bat_ref_aval = state.ShapedArrayRef(bat_ref_shape, np.float32)
|
||||
bat_ref_aval = state.shaped_array_ref(bat_ref_shape, np.float32)
|
||||
bat_non_slice_idx_avals = tuple(
|
||||
core.ShapedArray(shape, np.int32) for shape in bat_non_slice_idx_shapes)
|
||||
bat_slice_shape = maybe_tuple_insert(index_param.slice_shape, slice_bdim, axis_size)
|
||||
@ -1025,5 +1025,39 @@ class StateControlFlowTest(jtu.JaxTestCase):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
jax.grad(f)(3.)
|
||||
|
||||
class GeneralRefTest(jtu.JaxTestCase):
|
||||
|
||||
def test_unshaped_ref(self):
|
||||
def f(x_ref):
|
||||
x = x_ref[...]
|
||||
x_ref[...] = x
|
||||
state.ref_addupdate(x_ref, (), x)
|
||||
return [x]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [state.AbstractRef(core.UnshapedArray(jnp.int32))])
|
||||
self.assertIs(type(jaxpr.outvars[0].aval), core.UnshapedArray)
|
||||
self.assertEqual(jaxpr.outvars[0].aval.dtype, jnp.dtype("int32"))
|
||||
|
||||
def test_token(self):
|
||||
def f(x_ref):
|
||||
x = x_ref[...]
|
||||
x_ref[...] = x
|
||||
state.ref_addupdate(x_ref, (), x)
|
||||
return [x]
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f), [state.AbstractRef(core.AbstractToken())])
|
||||
self.assertIs(type(jaxpr.outvars[0].aval), core.AbstractToken)
|
||||
|
||||
def test_ref_of_ref(self):
|
||||
def f(x_ref_ref):
|
||||
x_ref = x_ref_ref[...]
|
||||
return [x_ref]
|
||||
# Not sure why you'd ever want to do this, but it works!
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(f),
|
||||
[state.AbstractRef(state.AbstractRef(core.ShapedArray((), jnp.int32)))])
|
||||
self.assertIs(type(jaxpr.outvars[0].aval), state.AbstractRef)
|
||||
self.assertIs(type(jaxpr.outvars[0].aval.inner_aval), core.ShapedArray)
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user