[dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit

Matthew Johnson 2023-03-22 20:54:45 -07:00
parent f981243af5
commit 7743fcd758
5 changed files with 650 additions and 532 deletions

View File

@@ -308,7 +308,7 @@ def jit(
         out_shardings=out_shardings, static_argnums=static_argnums,
         static_argnames=static_argnames, donate_argnums=donate_argnums,
         device=device, backend=backend, keep_unused=keep_unused,
-        inline=inline, resource_env=None)
+        inline=inline, resource_env=None, abstracted_axes=abstracted_axes)
     return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
 
   has_explicit_sharding = pjit._pjit_explicit_sharding(
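
Note: with abstracted_axes now forwarded into PjitInfo, the jit path can stage functions whose input axes are abstracted to symbolic sizes. A minimal sketch of the user-facing surface (experimental: it requires the jax_dynamic_shapes flag, and this commit only addresses the staging-to-jaxpr step):

    from functools import partial
    import jax
    import jax.numpy as jnp

    jax.config.update('jax_dynamic_shapes', True)  # experimental flag

    # Abstract the leading axis of x as a symbolic size 'n', so a single
    # trace can cover inputs of any length along that axis.
    @partial(jax.jit, abstracted_axes=({0: 'n'},))
    def f(x):
      return x.sum()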

View File

@@ -4710,4 +4710,21 @@ class BIntRules:
       return core.DArray(aval, buf)
     return handler
 
+  @staticmethod
+  def global_sharded_result_handler(aval, out_sharding, committed,
+                                    is_out_sharding_from_xla):
+    phys_aval, = BIntRules.physical_avals(aval)
+    phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
+
+    if not dispatch.is_single_device_sharding(out_sharding):
+      raise NotImplementedError  # TODO(mattjj)
+    else:
+      phys_sharding = out_sharding
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
+                                      is_out_sharding_from_xla)
+
+    def handler(bufs):
+      return core.DArray(aval, phys_handler(bufs))
+    return handler
+
 core.bint._rules = BIntRules
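
Note: the new global_sharded_result_handler follows the same physical/logical split as the single-device handler above it: result buffers are materialized with the handler registered for the physical ShapedArray representation, then re-tagged with the logical bint-shaped aval. A stand-alone sketch of that wrapping pattern (the DArray stand-in below is illustrative, not the JAX class):

    from dataclasses import dataclass
    from typing import Any, Callable

    @dataclass
    class DArray:  # stand-in for core.DArray: a logical aval plus physical data
      aval: Any
      data: Any

    def wrap_physical_handler(logical_aval: Any,
                              phys_handler: Callable[[Any], Any]):
      # phys_handler builds the physical array from raw buffers; the wrapper
      # re-tags that result with the logical aval.
      def handler(bufs):
        return DArray(logical_aval, phys_handler(bufs))
      return handler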

View File

@@ -31,9 +31,9 @@ from jax.errors import JAXTypeError
 from jax.interpreters import partial_eval as pe
 from jax.interpreters import xla
 from jax._src.interpreters.pxla import PartitionSpec
-from jax.tree_util import (
+from jax._src.tree_util import (
     tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
-    treedef_tuple)
+    treedef_tuple, broadcast_prefix, all_leaves)
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
@@ -66,6 +66,9 @@ from jax._src.util import (
     distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
     merge_lists)
 
+map, unsafe_map = safe_map, map
+zip, unsafe_zip = safe_zip, zip
+
 traceback_util.register_exclusion(__file__)
 
 class _FromGdaSingleton:
@@ -162,7 +165,7 @@ def _get_arg_names(fun, in_tree, args_flat):
 def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
                                       arg_names):
   arg_list = []
-  for a, n in safe_zip(args_flat, arg_names):
+  for a, n in zip(args_flat, arg_names):
     da = a.sharding._device_assignment if hasattr(a, 'sharding') else None
     arg_list.append((n, da, shaped_abstractify(a)))
@@ -312,9 +315,6 @@ def _resolve_axis_resources_and_shardings_arg(
 def pre_infer_params(fun, in_shardings, out_shardings,
                      donate_argnums, static_argnums, static_argnames, device,
                      backend, abstracted_axes):
-  # TODO(yashkatariya, mattjj): Remove when pjit supports dynamic shapes.
-  if jax.config.jax_dynamic_shapes:
-    raise ValueError("Dynamic shapes is not supported with pjit yet.")
   if abstracted_axes and not jax.config.jax_dynamic_shapes:
     raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
@@ -414,12 +414,13 @@ class PjitInfo(NamedTuple):
   keep_unused: bool
   inline: bool
   resource_env: Any
+  abstracted_axes: Optional[Any]
 
 def common_infer_params(pjit_info_args, *args, **kwargs):
   (fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames,
    donate_argnums, device, backend, keep_unused, inline,
-   resource_env) = pjit_info_args
+   resource_env, abstracted_axes) = pjit_info_args
 
   if kwargs and not _is_unspecified(user_in_shardings):
     raise ValueError(
@@ -435,6 +436,8 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
         "Mesh context manager should not be used with jit when backend or "
         "device is also specified as an argument to jit.")
 
+  axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
+
   jit_name = 'jit' if resource_env is None else 'pjit'
   dbg = debug_info(jit_name, fun, args, kwargs, static_argnums, static_argnames)
   f = lu.wrap_init(fun)
@@ -448,10 +451,10 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
   # leads to wrong expansion.
   if kwargs:
     f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
-    args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs))
+    explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs))
     flat_fun, out_tree = flatten_fun(f, in_tree)
   else:
-    args_flat, in_tree = tree_flatten(dyn_args)
+    explicit_args, in_tree = tree_flatten(dyn_args)
     flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
     dyn_kwargs = ()
   del kwargs
@@ -459,7 +462,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
   if donate_argnums and not jax.config.jax_debug_nans:
     donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
   else:
-    donated_invars = (False,) * len(args_flat)
+    donated_invars = (False,) * len(explicit_args)
 
   # If backend or device is set as an arg on jit, then resolve them to
   # in_shardings and out_shardings as if user passed in in_shardings
@@ -475,25 +478,37 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
   del user_in_shardings, user_out_shardings
 
-  global_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
+  if config.jax_dynamic_shapes:
+    in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
+    in_avals = tuple([a for a, e in in_type if e])
+  else:
+    in_type = in_avals = tuple(shaped_abstractify(a) for a in explicit_args)
+
   canonicalized_in_shardings_flat = _process_in_axis_resources(
-      hashable_pytree(in_shardings), global_in_avals, in_tree, resource_env)
+      hashable_pytree(in_shardings), in_avals, in_tree, resource_env)
 
   jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
-      flat_fun, hashable_pytree(out_shardings), global_in_avals, dbg,
+      flat_fun, hashable_pytree(out_shardings), in_type, dbg,
       HashableFunction(out_tree, closure=()),
       HashableFunction(res_paths, closure=()))
 
   if any(_is_from_gda(i) for i in canonicalized_in_shardings_flat):
     canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
-        canonicalized_in_shardings_flat, args_flat)
+        canonicalized_in_shardings_flat, explicit_args)
 
+  assert len(explicit_args) == len(canonicalized_in_shardings_flat)
-  assert len(args_flat) == len(canonicalized_in_shardings_flat)
 
+  if config.jax_dynamic_shapes:
+    implicit_args = _extract_implicit_args(in_type, explicit_args)
+  else:
+    implicit_args = []
+  args_flat = [*implicit_args, *explicit_args]
 
-  canonicalized_in_shardings_flat = (
-      _UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat
-  donated_invars = (False,) * len(consts) + donated_invars
+  num_extra_args = len(implicit_args) + len(consts)
+  canonicalized_in_shardings_flat = \
+      (_UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
+  donated_invars = (False,) * num_extra_args + donated_invars
+  assert (len(canonicalized_in_shardings_flat) == len(donated_invars) ==
+          len(consts) + len(args_flat))
 
   # in_shardings and out_shardings here are all GSPMDSharding.
   params = dict(
@@ -506,9 +521,48 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
       keep_unused=keep_unused,
       inline=inline,
   )
-  return (consts + args_flat, global_in_avals, params, in_tree, out_tree(),
+  return (consts + args_flat, in_type, params, in_tree, out_tree(),
           donate_argnums)
 
+def _extract_implicit_args(
+    in_type: Sequence[Tuple[core.AbstractValue, bool]],
+    explicit_args: Sequence[Any]
+) -> Sequence[core.Tracer]:
+  """
+  Given an input type and explicitly-passed arguments (per the user-facing API
+  calling convention), extract implicit axis size arguments from shapes of
+  explicit arguments (for the trace-time / jaxpr-level calling convention).
+  """
+  # First, using `in_type` construct a list to represent the full argument list,
+  # leaving the implicit arguments as None placeholders for now.
+  explicit_args_ = iter(explicit_args)
+  args = [next(explicit_args_) if expl else None for _, expl in in_type]
+  assert next(explicit_args_, None) is None
+  del explicit_args, explicit_args_
+
+  # Next, populate the implicit arguments using the DBIdxs in `in_type`.
+  for i, (aval, explicit) in enumerate(in_type):
+    if not explicit or not isinstance(aval, core.DShapedArray):
+      continue  # can't populate an implicit argument
+    arg = args[i]
+    assert arg is not None
+    for d1, d2 in zip(aval.shape, arg.aval.shape):
+      if isinstance(d1, core.DBIdx):
+        if args[d1.val] is None:
+          args[d1.val] = d2
+        assert core.same_referent(args[d1.val], d2)
+  assert all(x is not None for x in args)
+  return [x for x, (_, e) in zip(args, in_type) if not e]  # type: ignore
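
Note: a self-contained toy of the extraction logic above, with stand-ins for the JAX internals (DBIdx and the arg.aval.shape attribute are mimicked, and in_type entries here pair a bare shape with the explicit flag rather than a full aval):

    from dataclasses import dataclass
    from typing import Any, Tuple

    @dataclass(frozen=True)
    class DBIdx:
      val: int  # de Bruijn-style index: "this axis size is argument `val`"

    @dataclass
    class Arr:
      shape: Tuple[Any, ...]  # stand-in for arg.aval.shape

    def extract_implicit(in_type, explicit_args):
      it = iter(explicit_args)
      args = [next(it) if expl else None for _, expl in in_type]
      for i, (shape, expl) in enumerate(in_type):
        if not expl:
          continue
        for d1, d2 in zip(shape, args[i].shape):
          if isinstance(d1, DBIdx) and args[d1.val] is None:
            args[d1.val] = d2  # recover the implicit size from the shape
      return [a for a, (_, e) in zip(args, in_type) if not e]

    # For f(x) with x: f32[n], the trace-level convention is (n, x); the size
    # n is recovered from x's shape rather than passed by the caller:
    in_type = [((), False),          # implicit scalar size argument n
               ((DBIdx(0),), True)]  # explicit array; its axis 0 is input 0
    print(extract_implicit(in_type, [Arr((3,))]))  # -> [3]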
+def _flat_axes_specs(abstracted_axes, *args, **kwargs
+                     ) -> Optional[List[pe.AbstractedAxesSpec]]:
+  if abstracted_axes is None: return None
+  if kwargs: raise NotImplementedError
+  def ax_leaf(l):
+    return (isinstance(l, dict) and all_leaves(l.values()) or
+            isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
+  return broadcast_prefix(abstracted_axes, args, ax_leaf)
+
 # in_shardings and out_shardings can't be None as the default value
 # because `None` means that the input is fully replicated.
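
Note: _flat_axes_specs treats abstracted_axes as a pytree prefix of the positional arguments, and ax_leaf admits two spec-leaf forms. A sketch of the accepted forms (the lambdas are illustrative; requires jax_dynamic_shapes):

    import jax

    jax.config.update('jax_dynamic_shapes', True)

    # dict form: axis index -> axis name; unlisted axes stay concrete.
    f1 = jax.jit(lambda x: x * 2., abstracted_axes=({0: 'n'},))

    # tuple form: one entry per axis, None marks a concrete axis.
    f2 = jax.jit(lambda x: x * 2., abstracted_axes=(('n', None),))

    # A single spec leaf broadcasts over every argument leaf, so both leaves
    # of a dict argument share the symbolic size 'n':
    f3 = jax.jit(lambda t: t['a'].sum() + t['b'].sum(),
                 abstracted_axes={0: 'n'})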
@@ -683,7 +737,8 @@ def pjit(
         out_shardings=out_shardings, static_argnums=static_argnums,
         static_argnames=static_argnames, donate_argnums=donate_argnums,
         device=device, backend=backend, keep_unused=keep_unused,
-        inline=inline, resource_env=resource_env)
+        inline=inline, resource_env=resource_env,
+        abstracted_axes=abstracted_axes)
     return common_infer_params(pjit_info_args, *args, **kwargs)
 
   has_explicit_sharding = _pjit_explicit_sharding(
@@ -800,38 +855,44 @@ class PytreeLeaf:
 
 @lru_cache(maxsize=4096)
-def _process_in_axis_resources(in_shardings_thunk, global_in_avals,
+def _process_in_axis_resources(in_shardings_thunk, in_type,
                                in_tree, resource_env):
   orig_in_shardings = in_shardings_thunk()
   # Only do this if original in_shardings are unspecified. If they are
   # FROM_GDA or AUTO, go via flatten_axis_resources.
   if _is_unspecified(orig_in_shardings):
-    in_shardings_flat = (orig_in_shardings,) * len(global_in_avals)
+    in_shardings_flat = (orig_in_shardings,) * len(in_type)
   else:
     in_shardings_flat = flatten_axis_resources(
         "pjit in_shardings", in_tree, orig_in_shardings,
         tupled_args=True)
-  pjit_check_aval_sharding(in_shardings_flat, global_in_avals, "pjit arguments",
-                           allow_uneven_sharding=False)
+  if not config.jax_dynamic_shapes:
+    pjit_check_aval_sharding(in_shardings_flat, in_type,
+                             "pjit arguments", allow_uneven_sharding=False)
   # TODO(yashkatariya): Only check for is_auto or _is_unspecified when
   # FROM_GDA is removed.
   canonicalized_shardings = tuple(
       i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
-      for i, aval in safe_zip(in_shardings_flat, global_in_avals))
+      for i, aval in zip(in_shardings_flat, in_type))
   return canonicalized_shardings
 
 @lu.cache
-def _create_pjit_jaxpr(fun, global_in_avals, debug_info, out_paths):
+def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths):
   with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                                  "for pjit in {elapsed_time} sec",
                                  event=dispatch.JAXPR_TRACE_EVENT):
     pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
-    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
-        fun, global_in_avals, debug_info=pe_debug)
+    if config.jax_dynamic_shapes:
+      jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
+          lu.annotate(fun, in_type), debug_info=pe_debug)
+    else:
+      jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
+          fun, in_type, debug_info=pe_debug)
 
-  jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
+  if not config.jax_dynamic_shapes:
+    jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
 
   if any(isinstance(c, core.Tracer) for c in consts):
     closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
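
Note: in_type is the trace-level input type consumed by trace_to_jaxpr_dynamic2: a tuple of (aval, explicit) pairs in which implicit axis-size arguments precede the arrays whose DShapedArray shapes point back at them via DBIdx. A sketch built from the internal constructors (internal API of this era; constructor signatures assumed):

    import numpy as np
    from jax._src import core

    n_aval = core.ShapedArray((), np.dtype('int32'))                   # size n
    x_aval = core.DShapedArray((core.DBIdx(0),), np.dtype('float32'))  # f32[n]
    in_type = ((n_aval, False),  # implicit: recovered from x's shape
               (x_aval, True))   # explicit: the user-passed array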
@@ -844,7 +905,7 @@ def _create_pjit_jaxpr(fun, global_in_avals, debug_info, out_paths):
 
 @lru_cache(maxsize=4096)
 def _check_and_canonicalize_out_shardings(
-    out_shardings_thunk, out_tree, global_out_avals):
+    out_shardings_thunk, out_tree, out_type):
   orig_out_shardings = out_shardings_thunk()
   # TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources
   # instead. This condition exists because flatten_axis_resources passes in an
@@ -852,28 +913,29 @@ def _check_and_canonicalize_out_shardings(
   # pytrees (which shouldn't exist but they do).
   if (_is_unspecified(orig_out_shardings) or
       isinstance(orig_out_shardings, XLACompatibleSharding)):
-    out_shardings_flat = (orig_out_shardings,) * len(global_out_avals)
+    out_shardings_flat = (orig_out_shardings,) * len(out_type)
   else:
     out_shardings_flat = flatten_axis_resources(
         "pjit out_shardings", out_tree(), orig_out_shardings,
         tupled_args=False)
-  pjit_check_aval_sharding(out_shardings_flat, global_out_avals, "pjit outputs",
-                           allow_uneven_sharding=False)
+  if not config.jax_dynamic_shapes:
+    pjit_check_aval_sharding(out_shardings_flat, out_type, "pjit outputs",
+                             allow_uneven_sharding=False)
   canonicalized_out_shardings_flat = tuple(
       o if _is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim)
-      for o, aval in safe_zip(out_shardings_flat, global_out_avals)
+      for o, aval in zip(out_shardings_flat, out_type)
   )
   return canonicalized_out_shardings_flat
 
-def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, debug_info, out_tree,
+def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info, out_tree,
                 result_paths):
-  jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
-      fun, global_in_avals, debug_info, result_paths)
+  jaxpr, final_consts, out_type = _create_pjit_jaxpr(
+      fun, in_type, debug_info, result_paths)
   canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings(
-      out_shardings_thunk, out_tree, tuple(global_out_avals))
+      out_shardings_thunk, out_tree, tuple(out_type))
   # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
   return jaxpr, final_consts, canonicalized_out_shardings_flat
@@ -1118,7 +1180,7 @@ def _resolve_in_shardings(
       (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))
 
   resolved_in_shardings = []
-  for arg, pjit_in_s in safe_zip(args, pjit_in_shardings):
+  for arg, pjit_in_s in zip(args, pjit_in_shardings):
     arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True))
                         if hasattr(arg, 'sharding') else (_UNSPECIFIED, False))
     if _is_unspecified(pjit_in_s):
@@ -1201,7 +1263,7 @@ def _pjit_call_impl(*args, jaxpr,
   distributed_debug_log(("Running pjit'd function", name),
                         ("in_shardings", in_shardings),
                         ("out_shardings", out_shardings),
-                        ("abstract args", list(map(xla.abstractify, args))),
+                        ("abstract args", map(xla.abstractify, args)),
                         ("fingerprint", fingerprint))
   try:
     return compiled.unsafe_call(*args)
@@ -1255,7 +1317,7 @@ class SameDeviceAssignmentTuple:
     return (all(pxla.are_op_shardings_equal(s._op_sharding, o._op_sharding)  # pytype: disable=attribute-error
                 if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding)
                 else s == o
-                for s, o in safe_zip(self.shardings, other.shardings)) and
+                for s, o in zip(self.shardings, other.shardings)) and
             self.device_assignment == other.device_assignment)
@@ -1341,10 +1403,43 @@ def pjit_staging_rule(trace, *args, **params):
       all(_is_unspecified(o) for o in params["out_shardings"])):
     jaxpr = params['jaxpr']
     return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
+  elif config.jax_dynamic_shapes:
+    source_info = source_info_util.current()
+    out_tracers = []
+    for aval in _out_type(params['jaxpr']):
+      if type(aval) is core.DShapedArray:
+        shape = [args[d.val] if type(d) is core.InDBIdx else
+                 out_tracers[d.val] if type(d) is core.OutDBIdx else
+                 d for d in aval.shape]
+        aval = aval.update(shape=tuple(core.get_referent(d) for d in shape))
+      out_tracers.append(pe.DynamicJaxprTracer(trace, aval, source_info))
+    eqn = core.new_jaxpr_eqn(
+        map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params,
+        params['jaxpr'].effects, source_info)
+    trace.frame.add_eqn(eqn)
+    return out_tracers
   else:
     return trace.default_process_primitive(pjit_p, args, params)
 pe.custom_staging_rules[pjit_p] = pjit_staging_rule
 
+# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them,
+# since it's actually not possible in general to infer the type from the term
+def _out_type(jaxpr: core.ClosedJaxpr) -> List[core.AbstractValue]:
+  out = []
+  in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)}
+  out_idx = {x: i for i, x in enumerate(jaxpr.jaxpr.outvars)
+             if type(x) is core.Var}
+  for x in jaxpr.jaxpr.outvars:
+    aval = x.aval
+    if type(aval) is core.DShapedArray:
+      shape = [core.InDBIdx(in_idx[d]) if d in in_idx else
+               core.OutDBIdx(out_idx[d]) if d in out_idx else
+               d for d in x.aval.shape]
+      aval = aval.update(shape=tuple(shape))
+    out.append(aval)
+  return out
+
 def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
   return core._check_call(ctx_factory, pjit_p, in_atoms,
                           dict(params, call_jaxpr=jaxpr.jaxpr))
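
Note: _out_type's out_idx must be built over the jaxpr's outvars, not its invars (fixed above), since OutDBIdx indexes earlier outputs. The staging rule then rebuilds output avals by resolving de Bruijn indices against values in scope: InDBIdx points into the equation's inputs, OutDBIdx into outputs already emitted. A self-contained toy of that substitution (stand-in classes, not the JAX internals):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class InDBIdx:
      val: int  # axis size comes from equation input `val`

    @dataclass(frozen=True)
    class OutDBIdx:
      val: int  # axis size comes from earlier equation output `val`

    def resolve_shape(shape, args, outs):
      # Replace indices with the size-representing values in scope, leaving
      # concrete dimensions untouched.
      return tuple(args[d.val] if isinstance(d, InDBIdx) else
                   outs[d.val] if isinstance(d, OutDBIdx) else d
                   for d in shape)

    # Output aval with shape (InDBIdx(0), OutDBIdx(1), 4): axis 0 is the size
    # passed as input 0, axis 1 was produced as output 1, axis 2 is concrete.
    print(resolve_shape((InDBIdx(0), OutDBIdx(1), 4),
                        args=['n'], outs=['k', 'm']))  # -> ('n', 'm', 4)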
@@ -1360,14 +1455,14 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, resource_env, donated_invars,
                    keep_unused, inline):
   effects = list(ctx.tokens_in.effects())
-  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
+  output_types = map(mlir.aval_to_ir_types, ctx.avals_out)
   output_types = [mlir.token_type()] * len(effects) + output_types
   flat_output_types = util.flatten(output_types)
 
   arg_shardings = [None if _is_unspecified(i) else i._to_xla_op_sharding(aval.ndim)
-                   for aval, i in safe_zip(ctx.avals_in, in_shardings)]
+                   for aval, i in zip(ctx.avals_in, in_shardings)]
   result_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
-                      for aval, o in safe_zip(ctx.avals_out, out_shardings)]
+                      for aval, o in zip(ctx.avals_out, out_shardings)]
 
   # TODO(b/228598865): inlined calls cannot have shardings set directly on the
   # inputs or outputs because they are lost during MLIR->HLO conversion.
@@ -1381,7 +1476,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
     call = func_dialect.CallOp(flat_output_types,
                                ir.FlatSymbolRefAttr.get(func.name.value),
                                mlir.flatten_lowering_ir_args(args))
-  out_nodes = util.unflatten(call.results, safe_map(len, output_types))
+  out_nodes = util.unflatten(call.results, map(len, output_types))
   tokens, out_nodes = split_list(out_nodes, [len(effects)])
   tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens)))
   ctx.set_tokens_out(tokens_out)
@@ -1486,7 +1581,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
 def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr,
                      fwds_known: Tuple[Optional[int]]) -> core.ClosedJaxpr:
   updated_jaxpr = known_jaxpr.jaxpr.replace(
-      outvars=[x for x, i in safe_zip(known_jaxpr.jaxpr.outvars, fwds_known)
+      outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, fwds_known)
                if i is None])
   return known_jaxpr.replace(jaxpr=updated_jaxpr)
@@ -1505,7 +1600,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   num_residuals = len(res_avals)
 
   def keep_where(l, should_keep):
-    return tuple(x for x, keep in zip(l, should_keep) if keep)
+    return tuple(x for x, keep in unsafe_zip(l, should_keep) if keep)
 
   residual_shardings = (_UNSPECIFIED,) * num_residuals
   # Compute the known outputs
@@ -1526,7 +1621,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   known_user_out_shardings = keep_where(known_params['out_shardings'], known_outs)
   fwds_known_user = [
       fwd if _is_unspecified(os) else None
-      for os, fwd in safe_zip(known_user_out_shardings,
+      for os, fwd in zip(known_user_out_shardings,
                               fwds_known[:len(known_user_out_shardings)])]
   fwds_known = fwds_known_user + fwds_known[len(known_user_out_shardings):]
   del fwds_known_user
@@ -1534,7 +1629,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   # Remove forwarded outvars and out_shardings
   known_params['jaxpr'] = _known_jaxpr_fwd(known_params['jaxpr'], tuple(fwds_known))
   known_out_shardings = tuple(
-      s for s, i in safe_zip(known_params['out_shardings'], fwds_known) if i is None)
+      s for s, i in zip(known_params['out_shardings'], fwds_known) if i is None)
   known_params['out_shardings'] = known_out_shardings
   del known_out_shardings
@@ -1694,7 +1789,7 @@ def dce_jaxpr_pjit_rule(used_outputs: List[bool], eqn: core.JaxprEqn
       eqn.params['jaxpr'], tuple(used_outputs))
 
   def keep_where(xs, keeps):
-    return tuple(x for x, keep in safe_zip(xs, keeps) if keep)
+    return tuple(x for x, keep in zip(xs, keeps) if keep)
 
   eqn_params = eqn.params
   new_params = dict(
@@ -1847,7 +1942,7 @@ def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
   outs = [sharding_constraint_p.bind(xf, sharding=to_gspmd_sharding(i, xf.ndim),
                                      resource_env=resource_env,
                                      unconstrained_dims=ud)
-          for xf, i, ud in safe_zip(x_flat, shardings_flat, unconstrained_dims)]
+          for xf, i, ud in zip(x_flat, shardings_flat, unconstrained_dims)]
   return tree_unflatten(tree, outs)
 
 def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims):
@@ -1983,7 +2078,7 @@ def _maybe_replace_from_gda_with_pspec(
       return to_gspmd_sharding(gda_sharding, ndim)
 
   out = []
-  for in_sharding_flat, arg in safe_zip(in_shardings_flat, args_flat):
+  for in_sharding_flat, arg in zip(in_shardings_flat, args_flat):
     if is_auto(in_sharding_flat):
       out.append(in_sharding_flat)
     elif isinstance(arg, array.ArrayImpl):

View File

@@ -1732,6 +1732,7 @@ class DynamicJaxprTrace(core.Trace):
     tracer = self.frame.constid_to_tracer.get(id(c))
     if tracer is None:
       aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
+      aval = self._lift_tracers_in_aval(aval)
       tracer = self._new_const(aval, c)
     return tracer
@@ -1820,8 +1821,8 @@
     for aval, _ in out_type:
       if type(aval) is DShapedArray:
         shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else
-                out_tracers[d.val] if type(d) is OutDBIdx else
-                d for d in aval.shape]
+                 out_tracers[d.val] if type(d) is OutDBIdx else
+                 d for d in aval.shape]
         aval = aval.update(shape=tuple(get_referent(d) for d in shape))
       out_tracers.append(DynamicJaxprTracer(self, aval, source_info))
     invars = map(self.getvar, in_tracers)
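
Note: aside from the indentation-only fix, this method mirrors the pjit staging rule above: InDBIdx is resolved against [*consts, *in_tracers], OutDBIdx against the out_tracers built so far, and get_referent canonicalizes each dimension to the underlying size value. The new _lift_tracers_in_aval call handles constants whose avals themselves mention tracers, which can arise under dynamic shapes when a closed-over array's shape contains an abstracted size. A hypothetical sketch of such a program (illustrative only; requires jax_dynamic_shapes):

    from functools import partial
    import jax
    import jax.numpy as jnp

    jax.config.update('jax_dynamic_shapes', True)

    @partial(jax.jit, abstracted_axes=({0: 'n'},))
    def f(x):
      ones = jnp.ones(x.shape[0])            # value whose aval mentions the size n
      return jax.jit(lambda y: y + ones)(x)  # inner jit closes over `ones` as a const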

File diff suppressed because it is too large