commit ea63aeab01
jax authors, 2024-12-12 10:43:17 -08:00

Merge pull request #25442 from jakevdp:raise-to-shaped

PiperOrigin-RevId: 705556199

25 changed files with 50 additions and 67 deletions
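Across the files below the change follows one pattern: call sites that wrapped an abstract value in core.raise_to_shaped (directly, or through small helpers such as _abstractify, get_shaped_aval, or abstractify) now call core.get_aval alone, since get_aval already yields the shaped abstract value these callers need. A minimal sketch of the idea, assuming the public jax.core.get_aval mirrors the internal jax._src.core.get_aval used throughout this diff:

    import jax
    import jax.numpy as jnp

    x = jnp.arange(4.0)

    # Previously, internal code wrote core.raise_to_shaped(core.get_aval(x));
    # get_aval alone already returns a ShapedArray describing x.
    aval = jax.core.get_aval(x)
    print(type(aval).__name__, aval.shape, aval.dtype)  # ShapedArray (4,) float32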


@@ -28,7 +28,6 @@ ShapedArray = core.ShapedArray
 AbstractToken = core.AbstractToken
 abstract_token = core.abstract_token
 canonicalize_shape = core.canonicalize_shape
-raise_to_shaped = core.raise_to_shaped

 numpy_scalar_types: set[type] = {  # pylint: disable=g-bare-generic
     dtypes.int4, np.int8, np.int16, np.int32, np.int64,


@@ -19,7 +19,7 @@ from typing import Any, TypeVar
 from jax._src import core
 from jax._src import traceback_util
-from jax._src.core import Primitive, valid_jaxtype, raise_to_shaped, get_aval
+from jax._src.core import Primitive, valid_jaxtype, get_aval
 from jax._src.tree_util import register_pytree_node, tree_map
 from jax._src.typing import Array, ArrayLike
 from jax._src.util import safe_map
@@ -51,7 +51,7 @@ def zeros_like_aval(aval: core.AbstractValue) -> Array:
 aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

 def zeros_like_jaxval(val):
-  return zeros_like_aval(core.raise_to_shaped(core.get_aval(val)))
+  return zeros_like_aval(core.get_aval(val))

 def instantiate(z: Zero | Array) -> Array:
   if isinstance(z, Zero):
@@ -67,7 +67,7 @@ class Zero:
     return f'Zero({self.aval})'

   @staticmethod
   def from_primal_value(val: Any) -> Zero:
-    return Zero(raise_to_shaped(get_aval(val)).to_tangent_aval())
+    return Zero(get_aval(val).to_tangent_aval())

 register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))
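Zero.from_primal_value builds a symbolic zero whose aval is the tangent-space counterpart of the primal's aval. A hedged sketch of that mapping, assuming to_tangent_aval is reachable through the public jax.core.get_aval:

    import jax
    import jax.numpy as jnp

    # Float primals keep their dtype in tangent space; integer primals map to float0.
    f_aval = jax.core.get_aval(jnp.ones((3,))).to_tangent_aval()
    i_aval = jax.core.get_aval(jnp.arange(3)).to_tangent_aval()
    print(f_aval)  # a float32[3] ShapedArray
    print(i_aval)  # a float0[3] ShapedArray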


@@ -2356,7 +2356,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
         f"len(devices) = {len(devices)}.")

   def _device_put_sharded(*xs):
-    avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs]
+    avals = [core.get_aval(x) for x in xs]
     if not all(a1 == a2 for a1, a2 in zip(avals[:-1], avals[1:])):
       a1, a2 = next((a1, a2) for a1, a2 in zip(avals[:-1], avals[1:])
                     if a1 != a2)
@@ -2418,7 +2418,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
         "a non-empty sequence.")

   def _device_put_replicated(x):
     aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
-                              core.raise_to_shaped(core.get_aval(x)))
+                              core.get_aval(x))
     assert isinstance(aval, ShapedArray)
     sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
     if config.pmap_no_rank_reduction.value:
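The internal helpers above back jax.device_put_sharded and jax.device_put_replicated; a quick usage sketch of the public entry point, which runs on however many devices are visible:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((2, 2))
    # Copies x to every visible device, stacked along a new leading axis.
    replicated = jax.device_put_replicated(x, jax.devices())
    print(replicated.shape)  # (len(jax.devices()), 2, 2)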


@@ -587,8 +587,7 @@ def _dtype(x):
 def _shaped_abstractify_slow(x):
   try:
-    return core.raise_to_shaped(
-        x if isinstance(x, core.AbstractValue) else core.get_aval(x))
+    return x if isinstance(x, core.AbstractValue) else core.get_aval(x)
   except TypeError:
     pass
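_shaped_abstractify_slow is part of the path that turns user arguments into avals for tracing; jax.eval_shape is a convenient public view of the same abstraction step. A small sketch:

    import jax
    import jax.numpy as jnp

    # Abstract evaluation only: no FLOPs, just shape/dtype propagation.
    out = jax.eval_shape(jnp.outer, jnp.zeros(3), jnp.zeros(4))
    print(out.shape, out.dtype)  # (3, 4) float32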


@@ -387,9 +387,6 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
   error = _reduce_any_error(error)
   return error, out_vals

-def get_shaped_aval(val):
-  return core.raise_to_shaped(core.get_aval(val))
-
 def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
                    error: Error, *args) -> tuple[Error, list[core.Value]]:
   err_vals, err_tree = jtu.tree_flatten(error)
@@ -760,7 +757,7 @@ def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
   # Get the error-effects out of all branches so the cond can be called with
   # a merged error with all these effects.
   err_vals, err_tree = jtu.tree_flatten(error)
-  in_avals = map(get_shaped_aval, [*err_vals, *ops])
+  in_avals = map(core.get_aval, [*err_vals, *ops])
   def get_error_effects_from_jaxpr(jxpr):
     _, _, effects = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
                                             *in_avals)
@@ -770,7 +767,7 @@ def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
   err_vals, err_tree = jtu.tree_flatten(merged_error)

   # Update branch jaxprs to be checkified jaxprs.
-  in_avals = map(get_shaped_aval, [*err_vals, *ops])
+  in_avals = map(core.get_aval, [*err_vals, *ops])
   new_branches, out_trees, _ = unzip3(
       jaxpr_to_checkify_jaxpr(
           jxpr, enabled_errors, err_tree, *in_avals) for jxpr in branches)
@@ -792,11 +789,11 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                      num_consts, num_carry, linear, unroll, _split_transpose):

   consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
-  xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs]
+  xs_mapped = [core.mapped_aval(length, 0, core.get_aval(val)) for val in xs]
   # Query body effects to create a merged error containing all effects (such
   # that in and out carried error are of the same type).
   err_vals, err_tree = jtu.tree_flatten(error)
-  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
+  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
   _, _, effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                           err_tree, *new_in_aval)
@@ -804,7 +801,7 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
   err_vals, err_tree = jtu.tree_flatten(merged_error)

   # Create checked-jaxpr, with the needed pre-processing on the inputs.
-  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
+  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
   checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                         err_tree, *new_in_aval)
@@ -840,7 +837,7 @@ def checkify_while_body_jaxpr(
                                         *body_jaxpr.in_avals])
   closed_jaxpr = pe.close_jaxpr(jaxpr)
   err_vals, err_tree = jtu.tree_flatten(error)
-  err_vals = map(get_shaped_aval, err_vals)
+  err_vals = map(core.get_aval, err_vals)
   flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
   jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
       closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
@@ -882,7 +879,7 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
   checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

   cond_in_flat = [*err_vals, *c_consts, *carry]
-  cond_in_flat = map(get_shaped_aval, cond_in_flat)
+  cond_in_flat = map(core.get_aval, cond_in_flat)
   checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
                                                      err_tree, *cond_in_flat)
   compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
@@ -906,7 +903,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
   # jaxpr to checked_jaxpr
   err_vals, err_tree = jtu.tree_flatten(error)
   new_vals_in = [*err_vals, *vals_in]
-  in_avals = tuple(map(get_shaped_aval, new_vals_in))
+  in_avals = tuple(map(core.get_aval, new_vals_in))
   checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                        err_tree, *in_avals)
@@ -942,7 +939,7 @@ error_checks[pjit.pjit_p] = pjit_error_check
 def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
   err_vals, err_tree = jtu.tree_flatten(error)
   new_vals_in = [*err_vals, *vals_in]
-  in_avals = tuple(map(get_shaped_aval, new_vals_in))
+  in_avals = tuple(map(core.get_aval, new_vals_in))
   checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(
       pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals)
   checked_jaxpr, () = checked_jaxpr_.jaxpr, checked_jaxpr_.consts
@@ -963,7 +960,7 @@ def shard_map_error_check(
   # Replicated sharding for in errors.
   new_in_names = (*([{}] * num_error_vals), *in_names)
   new_vals_in = [*err_vals, *vals_in]
-  in_avals = list(map(get_shaped_aval, new_vals_in))
+  in_avals = list(map(core.get_aval, new_vals_in))
   for i, v in enumerate(in_avals):
     if not (sharder := core.shard_aval_handlers.get(type(v))):
       raise ValueError(f'Unsupported aval type: {type(v)}')
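These rules implement the error plumbing behind the public checkify transformation; a hedged usage sketch of that API (the function and check message are illustrative):

    import jax.numpy as jnp
    from jax.experimental import checkify

    def safe_log(x):
      checkify.check(jnp.all(x > 0), "x must be positive")
      return jnp.log(x)

    err, out = checkify.checkify(safe_log)(jnp.array([1.0, 2.0]))
    err.throw()  # no-op here; raises if the check failed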


@@ -149,7 +149,7 @@ class custom_vmap:
                        "using def_vmap.")
     args_flat, in_tree = tree_flatten(args)
     flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
+    in_avals = [core.get_aval(x) for x in args_flat]
     debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
     closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
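custom_vmap lets a function override how it is batched under vmap. A minimal hedged sketch of the public decorator; the rule below simply recomputes the function elementwise and reports the output as batched:

    import jax
    import jax.numpy as jnp

    @jax.custom_batching.custom_vmap
    def scaled_sin(x, scale):
      return jnp.sin(x) * scale

    @scaled_sin.def_vmap
    def scaled_sin_vmap_rule(axis_size, in_batched, xs, scales):
      # Elementwise recomputation; the output is batched along the mapped axis.
      return jnp.sin(xs) * scales, True

    ys = jax.vmap(scaled_sin)(jnp.arange(3.0), jnp.ones(3))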


@@ -1051,7 +1051,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
   from the closure.
   """
   flat_args, in_tree = tree_flatten(example_args)
-  in_avals = tuple(map(abstractify, flat_args))
+  in_avals = tuple(map(core.get_aval, flat_args))
   if config.check_tracer_leaks.value:
     return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
   else:
@@ -1111,9 +1111,6 @@ def partition_list(choice, lst):
     return [next(i2 if snd else i1) for snd in which]
   return out, merge

-def abstractify(x):
-  return core.get_aval(x)
-

 ### Custom transposition
@@ -1209,8 +1206,8 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
   f_in_tree = treedef_tuple((res_tree, lin_tree))
   f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

-  res_avals = map(abstractify, operands_res)
-  lin_avals = map(abstractify, operands_lin)
+  res_avals = map(core.get_aval, operands_res)
+  lin_avals = map(core.get_aval, operands_lin)
   f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
   f_jaxpr = _close_jaxpr(f_jaxpr)
   out_avals = f_jaxpr.out_avals
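closure_convert, public as jax.closure_convert, hoists array values out of a function's closure so they can be treated as explicit arguments. A usage sketch, assuming the closed-over value is a JAX array:

    import jax
    import jax.numpy as jnp

    a = jnp.arange(3.0)

    def fn(x):
      return a * x  # closes over `a`

    converted_fn, consts = jax.closure_convert(fn, jnp.zeros(3))
    # The hoisted constants are passed explicitly after the original arguments.
    print(converted_fn(jnp.ones(3), *consts))  # [0. 1. 2.]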


@@ -455,7 +455,7 @@ class custom_partitioning:
     f_, dyn_args = lu.wrap_init(self.fun), args
     args_flat, in_tree = tree_util.tree_flatten(dyn_args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
+    in_avals = [core.get_aval(x) for x in args_flat]
     debug = pe.debug_info(self.fun, in_tree, out_tree, False,
                           "custom_partitioning")
     mesh = mesh_lib.thread_resources.env.physical_mesh


@@ -518,7 +518,7 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
   """Returns the shape and dtype of a jax.Array or a j"""
   if isinstance(a, jax.ShapeDtypeStruct):
     return a.shape, a.dtype
-  aval = core.raise_to_shaped(core.get_aval(a))
+  aval = core.get_aval(a)
   return aval.shape, aval.dtype


@@ -1504,7 +1504,7 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
   """Returns the shape and dtype of a jax.Array or a j"""
   if isinstance(a, jax.ShapeDtypeStruct):
     return a.shape, a.dtype
-  aval = core.raise_to_shaped(core.get_aval(a))
+  aval = core.get_aval(a)
   return aval.shape, aval.dtype


@@ -388,7 +388,7 @@ def ffi_call(
         f"custom_call_api_version < 4; got {custom_call_api_version}.")

   def wrapped(*args: ArrayLike, **kwargs: Any):
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
+    in_avals = [core.get_aval(x) for x in args]
     if input_layouts is None:
       static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals))


@@ -37,9 +37,6 @@ map, unsafe_map = safe_map, map
 effects.control_flow_allowed_effects.add_type(lax.InOutFeedEffect)

-def _abstractify(x):
-  return core.raise_to_shaped(core.get_aval(x))
-
 def _typecheck_param(prim, param, name, msg_required, pred):
   if not pred:
     msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
@@ -91,7 +88,7 @@ def _initial_style_jaxprs_with_common_consts(
     return [], [], []
   jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
-  all_const_avals = [map(_abstractify, consts) for consts in all_consts]
+  all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
   # If we get a `Ref` in the consts, we know it must come from an outer
   # `run_state`. We also know if shouldn't be boxed up in another tracer.
   # We assert that it is in fact a DynamicJaxprTracer


@@ -49,7 +49,6 @@ from jax._src.lib.mlir.dialects import hlo
 import numpy as np

 from jax._src.lax.control_flow.common import (
-    _abstractify,
     _avals_short,
     _check_tree_and_avals,
     _initial_style_jaxprs_with_common_consts,
@@ -135,7 +134,7 @@ def switch(index, branches: Sequence[Callable], *operands,
     return branches[int(index)](*operands)

   ops, ops_tree = tree_flatten(operands)
-  ops_avals = tuple(map(_abstractify, ops))
+  ops_avals = tuple(map(core.get_aval, ops))

   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       branches, ops_tree, ops_avals, primitive_name='switch')
@@ -227,7 +226,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
       return false_fun(*operands)

   ops, ops_tree = tree_flatten(operands)
-  ops_avals = tuple(map(_abstractify, ops))
+  ops_avals = tuple(map(core.get_aval, ops))

   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       (true_fun, false_fun), ops_tree, ops_avals, 'cond')
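The tracing helpers above feed lax.switch and lax.cond; a brief usage sketch of those public entry points:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.float32(3.0)
    branches = [lambda v: v - 1.0, lambda v: v + 1.0, lambda v: v * 2.0]
    y = lax.switch(1, branches, x)                     # picks branches[1] -> 4.0
    z = lax.cond(x > 0, lambda v: v, lambda v: -v, x)  # abs-like -> 3.0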


@@ -44,7 +44,7 @@ from jax._src.typing import Array
 from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
                            split_list, split_dict, weakref_lru_cache)
 from jax._src.lax.control_flow import loops
-from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
+from jax._src.lax.control_flow.common import _initial_style_jaxpr
 import numpy as np

 ## JAX utilities
@@ -196,7 +196,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
     init_flat = tree_leaves(init)
     _, in_tree = tree_flatten((init, xs))

-    carry_avals = tuple(map(_abstractify, init_flat))
+    carry_avals = tuple(map(core.get_aval, init_flat))
     jaxpr, _, out_tree = _initial_style_jaxpr(
         f, in_tree, carry_avals + x_avals, "scan")
     return jaxpr, out_tree


@@ -47,7 +47,7 @@ from jax._src.lax import lax
 from jax._src.lax import slicing
 from jax._src.lax import windowed_reductions
 from jax._src.lax.control_flow.common import (
-    _abstractify, _avals_short, _initial_style_jaxpr,
+    _avals_short, _initial_style_jaxpr,
     _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
     _typecheck_param)
 from jax._src.lax.other import logaddexp
@@ -275,7 +275,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
   init_flat, init_tree = tree_flatten(init)
   in_flat, in_tree = tree_flatten((init, xs))

-  carry_avals = tuple(_map(_abstractify, init_flat))
+  carry_avals = tuple(_map(core.get_aval, init_flat))
   jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs(
       f, in_tree, (*carry_avals, *x_avals), "scan")
   out_tree_children = out_tree.children()
@@ -361,7 +361,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
             if p else 'the input carry')
   leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry)
   paths, in_carry_flat = unzip2(leaves_and_paths)
-  in_avals = _map(_abstractify, in_carry_flat)
+  in_avals = _map(core.get_aval, in_carry_flat)
   if in_carry_tree != out_carry_tree:
     try:
       out_carry = tree_unflatten(out_carry_tree, out_avals)
@@ -1321,7 +1321,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
   def _create_jaxpr(init_val):
     init_vals, in_tree = tree_flatten((init_val,))
-    init_avals = tuple(_map(_abstractify, init_vals))
+    init_avals = tuple(_map(core.get_aval, init_vals))
     cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
         cond_fun, in_tree, init_avals, "while_cond")
     body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
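scan and while_loop are the public faces of these tracing paths; a short usage sketch:

    import jax.numpy as jnp
    from jax import lax

    def cumsum_step(carry, x):
      carry = carry + x
      return carry, carry          # (next carry, per-step output)

    total, partials = lax.scan(cumsum_step, jnp.float32(0.0), jnp.arange(5.0))

    count = lax.while_loop(lambda i: i < 10, lambda i: i + 1, 0)  # -> 10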


@@ -32,7 +32,6 @@ from jax._src.util import split_list, safe_map
 import numpy as np

 from jax._src.lax.control_flow.common import (
-    _abstractify,
     _check_tree,
     _initial_style_jaxpr,
 )
@@ -87,7 +86,7 @@ def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
   implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
   """
   guess_flat, in_args_tree = tree_flatten((initial_guess,))
-  guess_avals = tuple(_map(_abstractify, guess_flat))
+  guess_avals = tuple(_map(core.get_aval, guess_flat))
   f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
       f, in_args_tree, guess_avals)
@@ -230,7 +229,7 @@ def custom_linear_solve(
     transpose_solve = solve

   b_flat, in_args_tree = tree_flatten((b,))
-  b_avals = tuple(_map(_abstractify, b_flat))
+  b_avals = tuple(_map(core.get_aval, b_flat))

   tree, = treedef_children(in_args_tree)
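custom_root and custom_linear_solve wrap a user-supplied solver so it can be differentiated implicitly. A hedged sketch of custom_linear_solve with a symmetric operator; the dense solve stands in for a real iterative solver:

    import jax.numpy as jnp
    from jax import lax

    A = jnp.array([[2.0, 0.0], [0.0, 4.0]])
    matvec = lambda v: A @ v
    solve = lambda mv, b: jnp.linalg.solve(A, b)   # ignores mv; illustrative only
    x = lax.custom_linear_solve(matvec, jnp.ones(2), solve, symmetric=True)
    # x solves A @ x = [1., 1.]  ->  [0.5, 0.25]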


@@ -139,7 +139,7 @@ def roll(
 @roll_p.def_abstract_eval
 def _roll_abstract_eval(x, shift, **_):
   del shift
-  return jax_core.raise_to_shaped(x)
+  return x

 def _roll_lowering_rule(


@@ -112,7 +112,6 @@ def _pad_values_to_block_dimension(value,
   return value

 def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
-  scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals)
   return tuple(
       primitives.uninitialized_value(a.shape, a.dtype) for a in scratch_avals
   )
@@ -1151,7 +1150,7 @@ def checkify_pallas_kernel_body_jaxpr(
     grid_mapping: GridMapping) -> tuple[
         jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]:
   err_vals, err_tree = tree_util.tree_flatten(error)
-  err_vals = map(checkify.get_shaped_aval, err_vals)
+  err_vals = map(jax_core.get_aval, err_vals)
   flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]

   with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
@@ -1274,13 +1273,13 @@ def pallas_call_checkify_rule(error: checkify.Error,
       closed_jaxpr, enabled_errors, error, grid_mapping)
   error = error._add_placeholder_effects(error_effects)
   err_vals, err_in_tree = jax.tree.flatten(error)
-  shaped_err_avals = map(checkify.get_shaped_aval, err_vals)
+  shaped_err_avals = map(jax_core.get_aval, err_vals)

   # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
   # all enabled errors removed, but have the error as inputs and return values.
   input_avals = [v.aval for v in jaxpr.invars]
   num_err_vals = len(err_vals)
-  shaped_input_avals = tuple(jax_core.raise_to_shaped(x) for x in input_avals)
+  shaped_input_avals = tuple(input_avals)
   checkify_in_avals = [*shaped_err_avals,
                        *shaped_input_avals]
   closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
@@ -1416,8 +1415,7 @@ def _trace_kernel_to_jaxpr(
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                                                      kernel_avals, debug)
   if consts:
-    consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
-                    for c in consts]
+    consts_avals = [jax_core.get_aval(c) for c in consts]
    if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):
      raise ValueError(
          f"The kernel function in the pallas_call {name_and_src_info} "
@@ -1804,8 +1802,7 @@ def pallas_call(
   def wrapped(*args):
     flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
     in_paths, flat_args = unzip2(flat_args_with_paths)
-    flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
-                          for a in flat_args)
+    flat_in_avals = tuple(jax_core.get_aval(a) for a in flat_args)

     flat_out_avals = tuple(_convert_out_shape_to_aval(v)
                            for v in flat_out_shapes)
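pallas_call is the user entry point whose tracing and checkify plumbing is touched above; a minimal kernel sketch, run in interpret mode so it works without a GPU/TPU (the interpret flag is an assumption about the version in use):

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def add_one_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...] + 1.0

    x = jnp.arange(8, dtype=jnp.float32)
    y = pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x)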


@@ -2351,8 +2351,7 @@ def _pjit_transpose(cts_in, *primals_in,
       *prune_type(ad.UndefinedPrimal, in_layouts, primals_in),
       *prune_type(ad.Zero, out_layouts, cts_in)
   )
-  global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct))
-                              for ct in primals_and_nz_cts_in)
+  global_cts_in_avals = tuple(core.get_aval(ct) for ct in primals_and_nz_cts_in)

   transpose_jaxpr, attrs_tracked = _pjit_transpose_trace(
       body, global_cts_in_avals)
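_pjit_transpose is, roughly, what runs when a jit-compiled computation is transposed during reverse-mode differentiation; a tiny sketch of the user-level behaviour it supports:

    import jax
    import jax.numpy as jnp

    f = jax.jit(lambda x: jnp.sum(jnp.sin(x)))
    g = jax.grad(f)(jnp.arange(3.0))   # cos(x), computed by differentiating through jit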


@@ -1249,7 +1249,7 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space):
 random_gamma_p = core.Primitive('random_gamma')
 random_gamma_p.def_impl(_gamma_impl)
-random_gamma_p.def_abstract_eval(lambda key, a, **_: core.raise_to_shaped(a))
+random_gamma_p.def_abstract_eval(lambda key, a, **_: a)
 ad.defjvp2(
     random_gamma_p, None,
     lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
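random_gamma_p backs the public sampler jax.random.gamma; a quick usage sketch:

    import jax

    key = jax.random.key(0)
    samples = jax.random.gamma(key, a=2.0, shape=(3,))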


@@ -414,7 +414,7 @@ _ref_type_aval_mappings: dict[
 def _default_value_to_ref_aval(x: Any) -> tuple[AbstractRef, Array]:
   # Default type mapping just creates an AbstractRef from the array's aval.
-  aval = core.raise_to_shaped(core.get_aval(x))
+  aval = core.get_aval(x)
   return AbstractRef(aval), x


@@ -79,7 +79,7 @@ def hoist_consts_to_refs(
   def val_to_ref_aval(x) -> AbstractRef:
-    aval = core.raise_to_shaped(core.get_aval(x))
+    aval = core.get_aval(x)
     if type(aval) is not core.ShapedArray:
       raise TypeError(f"can't make ref from {x}")
     return AbstractRef(aval)


@@ -57,7 +57,7 @@ def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str):
   frame = trace.frame

   def new_tracer(x):
-    aval = core.raise_to_shaped(core.get_aval(x))
+    aval = core.get_aval(x)
     tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current())
     var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval)
     frame.attrs_vars.append(var)
@@ -151,7 +151,7 @@ ad.JVPTrace.process_getattr = _getattr_jvp
 def linearize(f, *primals, attrs: list[tuple[Any, str]] = []):
   attr_primals = [jax_getattr(o, a) for o, a in attrs]
-  attr_avals = [core.raise_to_shaped(core.get_aval(p)) for p in attr_primals]
+  attr_avals = [core.get_aval(p) for p in attr_primals]
   primals_flat, in_tree = tree_flatten(primals)
   tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children()))
   f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
@@ -207,7 +207,7 @@ def vjp(f, *primals, attrs: list[tuple[Any, str]] = []):
   f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
   primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
       f_, *attr_primals, *primals_flat)
-  attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).to_tangent_aval()
+  attr_avals = [core.get_aval(jax_getattr(o, a)).to_tangent_aval()
                 for o, a in attrs_out]
   f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()),
                     attrs, attrs_out)


@@ -619,7 +619,7 @@ class GraphSerializationImpl(SerializationImpl):
     args_specs_flat, self.in_tree = tree_util.tree_flatten(
         (self.args_specs, self.kwargs_specs))
     self.args_avals_flat = tuple(
-        map(lambda a: core.raise_to_shaped(core.get_aval(a)), args_specs_flat))
+        map(core.get_aval, args_specs_flat))
     dim_vars = shape_poly.all_dim_vars(self.args_avals_flat)
     dim_values, _ = _interpret_fun_jax(
         partial(shape_poly.compute_dim_vars_from_arg_shapes,


@@ -737,7 +737,7 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues):
   args = spvalues_to_arrays(spenv, spvalues)
   args_flat, in_tree = tree_flatten(args)
-  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
+  avals_flat = [core.get_aval(arg) for arg in args_flat]
   sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
   sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
   assert out_tree is not None
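_sparsify_jaxpr supports the sparsify transform, which reinterprets dense jaxprs over BCOO inputs; a hedged usage sketch:

    import jax.numpy as jnp
    from jax.experimental import sparse

    M = sparse.BCOO.fromdense(jnp.eye(3))
    matvec = sparse.sparsify(lambda m, v: m @ v)
    print(matvec(M, jnp.arange(3.0)))  # [0. 1. 2.]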