Mirror of https://github.com/ROCm/jax.git, synced 2025-04-18 21:06:06 +00:00

commit ea63aeab01
Merge pull request #25442 from jakevdp:raise-to-shaped
PiperOrigin-RevId: 705556199
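Background for the diff below, as a hedged sketch (this note is not part of the
commit): by the time of this change, core.get_aval returns avals already at the
ShapedArray level (the ConcreteArray level was removed earlier in the same
series), so core.raise_to_shaped had become an identity wrapper. Every hunk
below therefore just drops the redundant call. A minimal illustration, assuming
a recent JAX checkout; the variable x and the print line are illustrative only:

  import jax.numpy as jnp
  from jax._src import core  # internal module, used here only for illustration

  x = jnp.ones((2, 3))
  # Old spelling, removed throughout this commit:
  #   aval = core.raise_to_shaped(core.get_aval(x))
  # New spelling; get_aval already yields a ShapedArray-level aval:
  aval = core.get_aval(x)
  print(type(aval).__name__, aval.shape, aval.dtype)  # ShapedArray (2, 3) float32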
@@ -28,7 +28,6 @@ ShapedArray = core.ShapedArray
 AbstractToken = core.AbstractToken
 abstract_token = core.abstract_token
 canonicalize_shape = core.canonicalize_shape
-raise_to_shaped = core.raise_to_shaped

 numpy_scalar_types: set[type] = {  # pylint: disable=g-bare-generic
     dtypes.int4, np.int8, np.int16, np.int32, np.int64,
@@ -19,7 +19,7 @@ from typing import Any, TypeVar

 from jax._src import core
 from jax._src import traceback_util
-from jax._src.core import Primitive, valid_jaxtype, raise_to_shaped, get_aval
+from jax._src.core import Primitive, valid_jaxtype, get_aval
 from jax._src.tree_util import register_pytree_node, tree_map
 from jax._src.typing import Array, ArrayLike
 from jax._src.util import safe_map
@@ -51,7 +51,7 @@ def zeros_like_aval(aval: core.AbstractValue) -> Array:
 aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

 def zeros_like_jaxval(val):
-  return zeros_like_aval(core.raise_to_shaped(core.get_aval(val)))
+  return zeros_like_aval(core.get_aval(val))

 def instantiate(z: Zero | Array) -> Array:
   if isinstance(z, Zero):
@@ -67,7 +67,7 @@ class Zero:
     return f'Zero({self.aval})'
   @staticmethod
   def from_primal_value(val: Any) -> Zero:
-    return Zero(raise_to_shaped(get_aval(val)).to_tangent_aval())
+    return Zero(get_aval(val).to_tangent_aval())

 register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))

@@ -2356,7 +2356,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):  #
                      f"len(devices) = {len(devices)}.")

   def _device_put_sharded(*xs):
-    avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs]
+    avals = [core.get_aval(x) for x in xs]
     if not all(a1 == a2 for a1, a2 in zip(avals[:-1], avals[1:])):
       a1, a2 = next((a1, a2) for a1, a2 in zip(avals[:-1], avals[1:])
                     if a1 != a2)
@@ -2418,7 +2418,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
                      "a non-empty sequence.")
   def _device_put_replicated(x):
     aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
-                              core.raise_to_shaped(core.get_aval(x)))
+                              core.get_aval(x))
     assert isinstance(aval, ShapedArray)
     sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
     if config.pmap_no_rank_reduction.value:
@@ -587,8 +587,7 @@ def _dtype(x):

 def _shaped_abstractify_slow(x):
   try:
-    return core.raise_to_shaped(
-        x if isinstance(x, core.AbstractValue) else core.get_aval(x))
+    return x if isinstance(x, core.AbstractValue) else core.get_aval(x)
   except TypeError:
     pass

@@ -387,9 +387,6 @@ def default_checkify_rule(primitive: core.Primitive, error: Error,
   error = _reduce_any_error(error)
   return error, out_vals

-def get_shaped_aval(val):
-  return core.raise_to_shaped(core.get_aval(val))
-
 def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
                    error: Error, *args) -> tuple[Error, list[core.Value]]:
   err_vals, err_tree = jtu.tree_flatten(error)
@@ -760,7 +757,7 @@ def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
   # Get the error-effects out of all branches so the cond can be called with
   # a merged error with all these effects.
   err_vals, err_tree = jtu.tree_flatten(error)
-  in_avals = map(get_shaped_aval, [*err_vals, *ops])
+  in_avals = map(core.get_aval, [*err_vals, *ops])
   def get_error_effects_from_jaxpr(jxpr):
     _, _, effects = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
                                             *in_avals)
@@ -770,7 +767,7 @@ def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
   err_vals, err_tree = jtu.tree_flatten(merged_error)

   # Update branch jaxprs to be checkified jaxprs.
-  in_avals = map(get_shaped_aval, [*err_vals, *ops])
+  in_avals = map(core.get_aval, [*err_vals, *ops])
   new_branches, out_trees, _ = unzip3(
       jaxpr_to_checkify_jaxpr(
           jxpr, enabled_errors, err_tree, *in_avals) for jxpr in branches)
@@ -792,11 +789,11 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                      num_consts, num_carry, linear, unroll, _split_transpose):

   consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
-  xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs]
+  xs_mapped = [core.mapped_aval(length, 0, core.get_aval(val)) for val in xs]
   # Query body effects to create a merged error containing all effects (such
   # that in and out carried error are of the same type).
   err_vals, err_tree = jtu.tree_flatten(error)
-  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
+  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
   _, _, effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                           err_tree, *new_in_aval)

@@ -804,7 +801,7 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
   err_vals, err_tree = jtu.tree_flatten(merged_error)

   # Create checked-jaxpr, with the needed pre-processing on the inputs.
-  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
+  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
   checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                         err_tree, *new_in_aval)

@@ -840,7 +837,7 @@ def checkify_while_body_jaxpr(
                                   *body_jaxpr.in_avals])
   closed_jaxpr = pe.close_jaxpr(jaxpr)
   err_vals, err_tree = jtu.tree_flatten(error)
-  err_vals = map(get_shaped_aval, err_vals)
+  err_vals = map(core.get_aval, err_vals)
   flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
   jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
       closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)
@@ -882,7 +879,7 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
   checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

   cond_in_flat = [*err_vals, *c_consts, *carry]
-  cond_in_flat = map(get_shaped_aval, cond_in_flat)
+  cond_in_flat = map(core.get_aval, cond_in_flat)
   checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
                                                      err_tree, *cond_in_flat)
   compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)
@@ -906,7 +903,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
   # jaxpr to checked_jaxpr
   err_vals, err_tree = jtu.tree_flatten(error)
   new_vals_in = [*err_vals, *vals_in]
-  in_avals = tuple(map(get_shaped_aval, new_vals_in))
+  in_avals = tuple(map(core.get_aval, new_vals_in))
   checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                        err_tree, *in_avals)

@@ -942,7 +939,7 @@ error_checks[pjit.pjit_p] = pjit_error_check
 def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
   err_vals, err_tree = jtu.tree_flatten(error)
   new_vals_in = [*err_vals, *vals_in]
-  in_avals = tuple(map(get_shaped_aval, new_vals_in))
+  in_avals = tuple(map(core.get_aval, new_vals_in))
   checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(
       pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals)
   checked_jaxpr, () = checked_jaxpr_.jaxpr, checked_jaxpr_.consts
@@ -963,7 +960,7 @@ def shard_map_error_check(
   # Replicated sharding for in errors.
   new_in_names = (*([{}] * num_error_vals), *in_names)
   new_vals_in = [*err_vals, *vals_in]
-  in_avals = list(map(get_shaped_aval, new_vals_in))
+  in_avals = list(map(core.get_aval, new_vals_in))
   for i, v in enumerate(in_avals):
     if not (sharder := core.shard_aval_handlers.get(type(v))):
       raise ValueError(f'Unsupported aval type: {type(v)}')
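Aside (illustrative, not part of the commit): the checkify.get_shaped_aval
helper deleted above reduces to core.get_aval itself, so downstream code that
imported it can keep working with a one-line shim such as the hypothetical
sketch below:

  from jax._src import core

  def get_shaped_aval(val):
    # Hypothetical compatibility shim: core.get_aval now returns a
    # ShapedArray-level aval directly, so no raising step is needed.
    return core.get_aval(val)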
@@ -149,7 +149,7 @@ class custom_vmap:
                          "using def_vmap.")
     args_flat, in_tree = tree_flatten(args)
     flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
+    in_avals = [core.get_aval(x) for x in args_flat]
     debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
     closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
@@ -1051,7 +1051,7 @@ def closure_convert(fun: Callable, *example_args) -> tuple[Callable, list[Any]]:
   from the closure.
   """
   flat_args, in_tree = tree_flatten(example_args)
-  in_avals = tuple(map(abstractify, flat_args))
+  in_avals = tuple(map(core.get_aval, flat_args))
   if config.check_tracer_leaks.value:
     return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
   else:
@@ -1111,9 +1111,6 @@ def partition_list(choice, lst):
     return [next(i2 if snd else i1) for snd in which]
   return out, merge

-def abstractify(x):
-  return core.get_aval(x)
-

 ### Custom transposition

@@ -1209,8 +1206,8 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
   f_in_tree = treedef_tuple((res_tree, lin_tree))
   f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

-  res_avals = map(abstractify, operands_res)
-  lin_avals = map(abstractify, operands_lin)
+  res_avals = map(core.get_aval, operands_res)
+  lin_avals = map(core.get_aval, operands_lin)
   f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
   f_jaxpr = _close_jaxpr(f_jaxpr)
   out_avals = f_jaxpr.out_avals
@@ -455,7 +455,7 @@ class custom_partitioning:
     f_, dyn_args = lu.wrap_init(self.fun), args
     args_flat, in_tree = tree_util.tree_flatten(dyn_args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
+    in_avals = [core.get_aval(x) for x in args_flat]
     debug = pe.debug_info(self.fun, in_tree, out_tree, False,
                           "custom_partitioning")
     mesh = mesh_lib.thread_resources.env.physical_mesh
@@ -518,7 +518,7 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
   """Returns the shape and dtype of a jax.Array or a j"""
   if isinstance(a, jax.ShapeDtypeStruct):
     return a.shape, a.dtype
-  aval = core.raise_to_shaped(core.get_aval(a))
+  aval = core.get_aval(a)
   return aval.shape, aval.dtype

@@ -1504,7 +1504,7 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
   """Returns the shape and dtype of a jax.Array or a j"""
   if isinstance(a, jax.ShapeDtypeStruct):
     return a.shape, a.dtype
-  aval = core.raise_to_shaped(core.get_aval(a))
+  aval = core.get_aval(a)
   return aval.shape, aval.dtype

@@ -388,7 +388,7 @@ def ffi_call(
         f"custom_call_api_version < 4; got {custom_call_api_version}.")

   def wrapped(*args: ArrayLike, **kwargs: Any):
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
+    in_avals = [core.get_aval(x) for x in args]

     if input_layouts is None:
       static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals))
@@ -37,9 +37,6 @@ map, unsafe_map = safe_map, map
 effects.control_flow_allowed_effects.add_type(lax.InOutFeedEffect)


-def _abstractify(x):
-  return core.raise_to_shaped(core.get_aval(x))
-
 def _typecheck_param(prim, param, name, msg_required, pred):
   if not pred:
     msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '
@@ -91,7 +88,7 @@ def _initial_style_jaxprs_with_common_consts(
     return [], [], []

   jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
-  all_const_avals = [map(_abstractify, consts) for consts in all_consts]
+  all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
   # If we get a `Ref` in the consts, we know it must come from an outer
   # `run_state`. We also know if shouldn't be boxed up in another tracer.
   # We assert that it is in fact a DynamicJaxprTracer
@@ -49,7 +49,6 @@ from jax._src.lib.mlir.dialects import hlo
 import numpy as np

 from jax._src.lax.control_flow.common import (
-    _abstractify,
     _avals_short,
     _check_tree_and_avals,
     _initial_style_jaxprs_with_common_consts,
@@ -135,7 +134,7 @@ def switch(index, branches: Sequence[Callable], *operands,
     return branches[int(index)](*operands)

   ops, ops_tree = tree_flatten(operands)
-  ops_avals = tuple(map(_abstractify, ops))
+  ops_avals = tuple(map(core.get_aval, ops))

   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       branches, ops_tree, ops_avals, primitive_name='switch')
@@ -227,7 +226,7 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
     return false_fun(*operands)

   ops, ops_tree = tree_flatten(operands)
-  ops_avals = tuple(map(_abstractify, ops))
+  ops_avals = tuple(map(core.get_aval, ops))

   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       (true_fun, false_fun), ops_tree, ops_avals, 'cond')
@@ -44,7 +44,7 @@ from jax._src.typing import Array
 from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
                            split_list, split_dict, weakref_lru_cache)
 from jax._src.lax.control_flow import loops
-from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
+from jax._src.lax.control_flow.common import _initial_style_jaxpr
 import numpy as np

 ## JAX utilities
@@ -196,7 +196,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
   init_flat = tree_leaves(init)
   _, in_tree = tree_flatten((init, xs))

-  carry_avals = tuple(map(_abstractify, init_flat))
+  carry_avals = tuple(map(core.get_aval, init_flat))
   jaxpr, _, out_tree = _initial_style_jaxpr(
       f, in_tree, carry_avals + x_avals, "scan")
   return jaxpr, out_tree
@@ -47,7 +47,7 @@ from jax._src.lax import lax
 from jax._src.lax import slicing
 from jax._src.lax import windowed_reductions
 from jax._src.lax.control_flow.common import (
-    _abstractify, _avals_short, _initial_style_jaxpr,
+    _avals_short, _initial_style_jaxpr,
     _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
     _typecheck_param)
 from jax._src.lax.other import logaddexp
@@ -275,7 +275,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
   init_flat, init_tree = tree_flatten(init)
   in_flat, in_tree = tree_flatten((init, xs))

-  carry_avals = tuple(_map(_abstractify, init_flat))
+  carry_avals = tuple(_map(core.get_aval, init_flat))
   jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs(
       f, in_tree, (*carry_avals, *x_avals), "scan")
   out_tree_children = out_tree.children()
@@ -361,7 +361,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
           if p else 'the input carry')
   leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry)
   paths, in_carry_flat = unzip2(leaves_and_paths)
-  in_avals = _map(_abstractify, in_carry_flat)
+  in_avals = _map(core.get_aval, in_carry_flat)
   if in_carry_tree != out_carry_tree:
     try:
       out_carry = tree_unflatten(out_carry_tree, out_avals)
@@ -1321,7 +1321,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],

   def _create_jaxpr(init_val):
     init_vals, in_tree = tree_flatten((init_val,))
-    init_avals = tuple(_map(_abstractify, init_vals))
+    init_avals = tuple(_map(core.get_aval, init_vals))
     cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
         cond_fun, in_tree, init_avals, "while_cond")
     body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
@@ -32,7 +32,6 @@ from jax._src.util import split_list, safe_map
 import numpy as np

 from jax._src.lax.control_flow.common import (
-    _abstractify,
     _check_tree,
     _initial_style_jaxpr,
 )
@@ -87,7 +86,7 @@ def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
   implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
   """
   guess_flat, in_args_tree = tree_flatten((initial_guess,))
-  guess_avals = tuple(_map(_abstractify, guess_flat))
+  guess_avals = tuple(_map(core.get_aval, guess_flat))
   f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
       f, in_args_tree, guess_avals)

@@ -230,7 +229,7 @@ def custom_linear_solve(
     transpose_solve = solve

   b_flat, in_args_tree = tree_flatten((b,))
-  b_avals = tuple(_map(_abstractify, b_flat))
+  b_avals = tuple(_map(core.get_aval, b_flat))

   tree, = treedef_children(in_args_tree)

@@ -139,7 +139,7 @@ def roll(
 @roll_p.def_abstract_eval
 def _roll_abstract_eval(x, shift, **_):
   del shift
-  return jax_core.raise_to_shaped(x)
+  return x


 def _roll_lowering_rule(
@@ -112,7 +112,6 @@ def _pad_values_to_block_dimension(value,
   return value

 def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]:
-  scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals)
   return tuple(
       primitives.uninitialized_value(a.shape, a.dtype) for a in scratch_avals
   )
@@ -1151,7 +1150,7 @@ def checkify_pallas_kernel_body_jaxpr(
     grid_mapping: GridMapping) -> tuple[
       jax_core.ClosedJaxpr, tree_util.PyTreeDef, set[checkify.ErrorEffect]]:
   err_vals, err_tree = tree_util.tree_flatten(error)
-  err_vals = map(checkify.get_shaped_aval, err_vals)
+  err_vals = map(jax_core.get_aval, err_vals)
   flat_err_and_in_vals = [*err_vals, *body_jaxpr.in_avals]

   with pallas_core.tracing_grid_env(grid_mapping.grid, ()):
@@ -1274,13 +1273,13 @@ def pallas_call_checkify_rule(error: checkify.Error,
       closed_jaxpr, enabled_errors, error, grid_mapping)
   error = error._add_placeholder_effects(error_effects)
   err_vals, err_in_tree = jax.tree.flatten(error)
-  shaped_err_avals = map(checkify.get_shaped_aval, err_vals)
+  shaped_err_avals = map(jax_core.get_aval, err_vals)

   # Trace the kernel jaxpr to get a checkified jaxpr. This jaxpr will have
   # all enabled errors removed, but have the error as inputs and return values.
   input_avals = [v.aval for v in jaxpr.invars]
   num_err_vals = len(err_vals)
-  shaped_input_avals = tuple(jax_core.raise_to_shaped(x) for x in input_avals)
+  shaped_input_avals = tuple(input_avals)
   checkify_in_avals = [*shaped_err_avals,
                        *shaped_input_avals]
   closed_kernel_jaxpr = pe.close_jaxpr(jaxpr)
@@ -1416,8 +1415,7 @@ def _trace_kernel_to_jaxpr(
   jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                                                    kernel_avals, debug)
   if consts:
-    consts_avals = [jax_core.raise_to_shaped(jax_core.get_aval(c))
-                    for c in consts]
+    consts_avals = [jax_core.get_aval(c) for c in consts]
     if any(not isinstance(aval, state.AbstractRef) for aval in consts_avals):
       raise ValueError(
           f"The kernel function in the pallas_call {name_and_src_info} "
@@ -1804,8 +1802,7 @@ def pallas_call(
   def wrapped(*args):
     flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
     in_paths, flat_args = unzip2(flat_args_with_paths)
-    flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
-                          for a in flat_args)
+    flat_in_avals = tuple(jax_core.get_aval(a) for a in flat_args)

     flat_out_avals = tuple(_convert_out_shape_to_aval(v)
                            for v in flat_out_shapes)
@@ -2351,8 +2351,7 @@ def _pjit_transpose(cts_in, *primals_in,
       *prune_type(ad.UndefinedPrimal, in_layouts, primals_in),
       *prune_type(ad.Zero, out_layouts, cts_in)
   )
-  global_cts_in_avals = tuple(core.raise_to_shaped(core.get_aval(ct))
-                              for ct in primals_and_nz_cts_in)
+  global_cts_in_avals = tuple(core.get_aval(ct) for ct in primals_and_nz_cts_in)

   transpose_jaxpr, attrs_tracked = _pjit_transpose_trace(
       body, global_cts_in_avals)
@@ -1249,7 +1249,7 @@ def _gamma_batching_rule(batched_args, batch_dims, *, log_space):

 random_gamma_p = core.Primitive('random_gamma')
 random_gamma_p.def_impl(_gamma_impl)
-random_gamma_p.def_abstract_eval(lambda key, a, **_: core.raise_to_shaped(a))
+random_gamma_p.def_abstract_eval(lambda key, a, **_: a)
 ad.defjvp2(
     random_gamma_p, None,
     lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
@@ -414,7 +414,7 @@ _ref_type_aval_mappings: dict[

 def _default_value_to_ref_aval(x: Any) -> tuple[AbstractRef, Array]:
   # Default type mapping just creates an AbstractRef from the array's aval.
-  aval = core.raise_to_shaped(core.get_aval(x))
+  aval = core.get_aval(x)
   return AbstractRef(aval), x

@@ -79,7 +79,7 @@ def hoist_consts_to_refs(


 def val_to_ref_aval(x) -> AbstractRef:
-  aval = core.raise_to_shaped(core.get_aval(x))
+  aval = core.get_aval(x)
   if type(aval) is not core.ShapedArray:
     raise TypeError(f"can't make ref from {x}")
   return AbstractRef(aval)
@@ -57,7 +57,7 @@ def _ensure_tracked(trace: pe.DynamicJaxprTrace, obj: Any, attr: str):
   frame = trace.frame

   def new_tracer(x):
-    aval = core.raise_to_shaped(core.get_aval(x))
+    aval = core.get_aval(x)
     tracer = pe.DynamicJaxprTracer(trace, aval, pe.source_info_util.current())
     var = frame.tracer_to_var[id(tracer)] = frame.newvar(aval)
     frame.attrs_vars.append(var)
@@ -151,7 +151,7 @@ ad.JVPTrace.process_getattr = _getattr_jvp

 def linearize(f, *primals, attrs: list[tuple[Any, str]] = []):
   attr_primals = [jax_getattr(o, a) for o, a in attrs]
-  attr_avals = [core.raise_to_shaped(core.get_aval(p)) for p in attr_primals]
+  attr_avals = [core.get_aval(p) for p in attr_primals]
   primals_flat, in_tree = tree_flatten(primals)
   tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children()))
   f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
@@ -207,7 +207,7 @@ def vjp(f, *primals, attrs: list[tuple[Any, str]] = []):
   f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
   primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
       f_, *attr_primals, *primals_flat)
-  attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).to_tangent_aval()
+  attr_avals = [core.get_aval(jax_getattr(o, a)).to_tangent_aval()
                 for o, a in attrs_out]
   f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()),
                     attrs, attrs_out)
@@ -619,7 +619,7 @@ class GraphSerializationImpl(SerializationImpl):
     args_specs_flat, self.in_tree = tree_util.tree_flatten(
         (self.args_specs, self.kwargs_specs))
     self.args_avals_flat = tuple(
-        map(lambda a: core.raise_to_shaped(core.get_aval(a)), args_specs_flat))
+        map(core.get_aval, args_specs_flat))
     dim_vars = shape_poly.all_dim_vars(self.args_avals_flat)
     dim_values, _ = _interpret_fun_jax(
         partial(shape_poly.compute_dim_vars_from_arg_shapes,
@@ -737,7 +737,7 @@ def _sparsify_jaxpr(spenv, jaxpr, *spvalues):

   args = spvalues_to_arrays(spenv, spvalues)
   args_flat, in_tree = tree_flatten(args)
-  avals_flat = [core.raise_to_shaped(core.get_aval(arg)) for arg in args_flat]
+  avals_flat = [core.get_aval(arg) for arg in args_flat]
   sp_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped, avals_flat)
   sp_jaxpr = pe.ClosedJaxpr(sp_jaxpr, consts)
   assert out_tree is not None