Mirror of https://github.com/ROCm/jax.git

[dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit

commit 7743fcd758 (parent f981243af5)
@@ -308,7 +308,7 @@ def jit(
         out_shardings=out_shardings, static_argnums=static_argnums,
         static_argnames=static_argnames, donate_argnums=donate_argnums,
         device=device, backend=backend, keep_unused=keep_unused,
-        inline=inline, resource_env=None)
+        inline=inline, resource_env=None, abstracted_axes=abstracted_axes)
     return pjit.common_infer_params(pjit_info_args, *args, **kwargs)

   has_explicit_sharding = pjit._pjit_explicit_sharding(
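
For context, a minimal sketch of how the abstracted_axes argument threaded through here is meant to be spelled at the user level. This example is not part of the diff; it assumes the experimental jax_dynamic_shapes mode, and the exact spec format may vary between versions.

# Hypothetical usage sketch: mark the leading axis of the first argument as an
# abstract size 'n' so the function is traced polymorphically over that size.
from functools import partial
import jax
import jax.numpy as jnp

jax.config.update("jax_dynamic_shapes", True)  # experimental flag

@partial(jax.jit, abstracted_axes=({0: 'n'},))
def f(x):
  return jnp.sum(x)

print(f(jnp.arange(3.0)))  # traced once with an abstract leading axis size
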
@@ -4710,4 +4710,21 @@ class BIntRules:
       return core.DArray(aval, buf)
     return handler

+  @staticmethod
+  def global_sharded_result_handler(aval, out_sharding, committed,
+                                    is_out_sharding_from_xla):
+    phys_aval, = BIntRules.physical_avals(aval)
+    phys_handler_maker = pxla.global_result_handlers[core.ShapedArray]
+
+    if not dispatch.is_single_device_sharding(out_sharding):
+      raise NotImplementedError  # TODO(mattjj)
+    else:
+      phys_sharding = out_sharding
+    phys_handler = phys_handler_maker(phys_aval, phys_sharding, committed,
+                                      is_out_sharding_from_xla)
+
+    def handler(bufs):
+      return core.DArray(aval, phys_handler(bufs))
+    return handler
+
 core.bint._rules = BIntRules
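
The new global_sharded_result_handler follows the usual two-stage result-handler pattern: build a handler for the physical (ShapedArray) representation once, then wrap its output in the logical bint-typed DArray at call time. A simplified, self-contained sketch of that closure pattern, using stand-in names rather than the jax internals:

# Simplified illustration of the handler-factory closure pattern; DArray and
# the handler maker here are stand-ins, not jax's real classes.
from typing import Any, Callable, Sequence

def make_logical_result_handler(
    logical_aval: Any,
    physical_aval: Any,
    phys_handler_maker: Callable[[Any], Callable[[Sequence[Any]], Any]],
):
  # Build the physical handler once, ahead of time.
  phys_handler = phys_handler_maker(physical_aval)

  def handler(bufs):
    # At call time, wrap the physical result in the logical container.
    return ("DArray", logical_aval, phys_handler(bufs))
  return handler

# Example: the "physical" handler just collects raw buffers into a list.
h = make_logical_result_handler("bint[4]{3}", "i32[]", lambda aval: list)
print(h([7]))  # ('DArray', 'bint[4]{3}', [7])
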
jax/_src/pjit.py (199 lines changed)
@@ -31,9 +31,9 @@ from jax.errors import JAXTypeError
 from jax.interpreters import partial_eval as pe
 from jax.interpreters import xla
 from jax._src.interpreters.pxla import PartitionSpec
-from jax.tree_util import (
+from jax._src.tree_util import (
     tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
-    treedef_tuple)
+    treedef_tuple, broadcast_prefix, all_leaves)

 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
@@ -66,6 +66,9 @@ from jax._src.util import (
     distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
     merge_lists)

+map, unsafe_map = safe_map, map
+zip, unsafe_zip = safe_zip, zip
+
 traceback_util.register_exclusion(__file__)

 class _FromGdaSingleton:
@@ -162,7 +165,7 @@ def _get_arg_names(fun, in_tree, args_flat):
 def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
                                       arg_names):
   arg_list = []
-  for a, n in safe_zip(args_flat, arg_names):
+  for a, n in zip(args_flat, arg_names):
     da = a.sharding._device_assignment if hasattr(a, 'sharding') else None
     arg_list.append((n, da, shaped_abstractify(a)))

@@ -312,9 +315,6 @@ def _resolve_axis_resources_and_shardings_arg(
 def pre_infer_params(fun, in_shardings, out_shardings,
                      donate_argnums, static_argnums, static_argnames, device,
                      backend, abstracted_axes):
-  # TODO(yashkatariya, mattjj): Remove when pjit supports dynamic shapes.
-  if jax.config.jax_dynamic_shapes:
-    raise ValueError("Dynamic shapes is not supported with pjit yet.")
   if abstracted_axes and not jax.config.jax_dynamic_shapes:
     raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")

@@ -414,12 +414,13 @@ class PjitInfo(NamedTuple):
   keep_unused: bool
   inline: bool
   resource_env: Any
+  abstracted_axes: Optional[Any]


 def common_infer_params(pjit_info_args, *args, **kwargs):
   (fun, user_in_shardings, user_out_shardings, static_argnums, static_argnames,
    donate_argnums, device, backend, keep_unused, inline,
-   resource_env) = pjit_info_args
+   resource_env, abstracted_axes) = pjit_info_args

   if kwargs and not _is_unspecified(user_in_shardings):
     raise ValueError(
@@ -435,6 +436,8 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
         "Mesh context manager should not be used with jit when backend or "
         "device is also specified as an argument to jit.")

+  axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
+
   jit_name = 'jit' if resource_env is None else 'pjit'
   dbg = debug_info(jit_name, fun, args, kwargs, static_argnums, static_argnames)
   f = lu.wrap_init(fun)
@@ -448,10 +451,10 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
   # leads to wrong expansion.
   if kwargs:
     f, dyn_kwargs = argnames_partial_except(f, static_argnames, kwargs)
-    args_flat, in_tree = tree_flatten((dyn_args, dyn_kwargs))
+    explicit_args, in_tree = tree_flatten((dyn_args, dyn_kwargs))
     flat_fun, out_tree = flatten_fun(f, in_tree)
   else:
-    args_flat, in_tree = tree_flatten(dyn_args)
+    explicit_args, in_tree = tree_flatten(dyn_args)
     flat_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
     dyn_kwargs = ()
   del kwargs
@@ -459,7 +462,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
   if donate_argnums and not jax.config.jax_debug_nans:
     donated_invars = donation_vector(donate_argnums, dyn_args, dyn_kwargs)
   else:
-    donated_invars = (False,) * len(args_flat)
+    donated_invars = (False,) * len(explicit_args)

   # If backend or device is set as an arg on jit, then resolve them to
   # in_shardings and out_shardings as if user passed in in_shardings
@@ -475,25 +478,37 @@ def common_infer_params(pjit_info_args, *args, **kwargs):

   del user_in_shardings, user_out_shardings

-  global_in_avals = tuple(shaped_abstractify(a) for a in args_flat)
+  if config.jax_dynamic_shapes:
+    in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
+    in_avals = tuple([a for a, e in in_type if e])
+  else:
+    in_type = in_avals = tuple(shaped_abstractify(a) for a in explicit_args)

   canonicalized_in_shardings_flat = _process_in_axis_resources(
-      hashable_pytree(in_shardings), global_in_avals, in_tree, resource_env)
+      hashable_pytree(in_shardings), in_avals, in_tree, resource_env)

   jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
-      flat_fun, hashable_pytree(out_shardings), global_in_avals, dbg,
+      flat_fun, hashable_pytree(out_shardings), in_type, dbg,
       HashableFunction(out_tree, closure=()),
       HashableFunction(res_paths, closure=()))

   if any(_is_from_gda(i) for i in canonicalized_in_shardings_flat):
     canonicalized_in_shardings_flat = _maybe_replace_from_gda_with_pspec(
-        canonicalized_in_shardings_flat, args_flat)
+        canonicalized_in_shardings_flat, explicit_args)
+  assert len(explicit_args) == len(canonicalized_in_shardings_flat)

-  assert len(args_flat) == len(canonicalized_in_shardings_flat)
+  if config.jax_dynamic_shapes:
+    implicit_args = _extract_implicit_args(in_type, explicit_args)
+  else:
+    implicit_args = []
+  args_flat = [*implicit_args, *explicit_args]

-  canonicalized_in_shardings_flat = (
-      _UNSPECIFIED,) * len(consts) + canonicalized_in_shardings_flat
-  donated_invars = (False,) * len(consts) + donated_invars
+  num_extra_args = len(implicit_args) + len(consts)
+  canonicalized_in_shardings_flat = \
+      (_UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
+  donated_invars = (False,) * num_extra_args + donated_invars
   assert (len(canonicalized_in_shardings_flat) == len(donated_invars) ==
           len(consts) + len(args_flat))

   # in_shardings and out_shardings here are all GSPMDSharding.
   params = dict(
@@ -506,9 +521,48 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
       keep_unused=keep_unused,
       inline=inline,
   )
-  return (consts + args_flat, global_in_avals, params, in_tree, out_tree(),
+  return (consts + args_flat, in_type, params, in_tree, out_tree(),
           donate_argnums)

+def _extract_implicit_args(
+    in_type: Sequence[Tuple[core.AbstractValue, bool]],
+    explicit_args: Sequence[Any]
+  ) -> Sequence[core.Tracer]:
+  """
+  Given an input type and explicitly-passed arguments (per the user-facing API
+  calling convention), extract implicit axis size arguments from shapes of
+  explicit arguments (for the trace-time / jaxpr-level calling convention).
+  """
+  # First, using `in_type` construct a list to represent the full argument list,
+  # leaving the implicit arguments as None placeholders for now.
+  explicit_args_ = iter(explicit_args)
+  args = [next(explicit_args_) if expl else None for _, expl in in_type]
+  assert next(explicit_args_, None) is None
+  del explicit_args, explicit_args_
+
+  # Next, populate the implicit arguments using the DBIdxs in `in_type`.
+  for i, (aval, explicit) in enumerate(in_type):
+    if not explicit or not isinstance(aval, core.DShapedArray):
+      continue  # can't populate an implicit argument
+    arg = args[i]
+    assert arg is not None
+    for d1, d2 in zip(aval.shape, arg.aval.shape):
+      if isinstance(d1, core.DBIdx):
+        if args[d1.val] is None:
+          args[d1.val] = d2
+        assert core.same_referent(args[d1.val], d2)
+  assert all(x is not None for x in args)
+  return [x for x, (_, e) in zip(args, in_type) if not e]  # type: ignore
+
+def _flat_axes_specs(abstracted_axes, *args, **kwargs
+                     ) -> Optional[List[pe.AbstractedAxesSpec]]:
+  if abstracted_axes is None: return None
+  if kwargs: raise NotImplementedError
+  def ax_leaf(l):
+    return (isinstance(l, dict) and all_leaves(l.values()) or
+            isinstance(l, tuple) and all_leaves(l, lambda x: x is None))
+  return broadcast_prefix(abstracted_axes, args, ax_leaf)
+
+
 # in_shardings and out_shardings can't be None as the default value
 # because `None` means that the input is fully replicated.
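
To make the DBIdx bookkeeping in _extract_implicit_args concrete, here is a simplified, self-contained re-implementation over plain shape tuples and numpy arrays. The DBIdx class below is a stand-in for core.DBIdx, not the jax internals; the real version works over jax avals and additionally asserts that repeated references resolve to the same size object.

# Simplified stand-alone model of the extraction above: an input type is a list
# of (shape, explicit) pairs, where a DBIdx(i) entry in an explicit shape says
# "the size of this axis is the i-th (implicit) argument".
from dataclasses import dataclass
import numpy as np

@dataclass(frozen=True)
class DBIdx:  # stand-in for core.DBIdx
  val: int

def extract_implicit_args(in_type, explicit_args):
  it = iter(explicit_args)
  args = [next(it) if explicit else None for _, explicit in in_type]
  for (shape, explicit), arg in zip(in_type, args):
    if not explicit:
      continue
    for d_spec, d_size in zip(shape, np.shape(arg)):
      if isinstance(d_spec, DBIdx) and args[d_spec.val] is None:
        args[d_spec.val] = d_size   # recover the implicit axis size
  return [a for a, (_, explicit) in zip(args, in_type) if not explicit]

# One implicit size argument at position 0, one explicit array of shape (n,):
in_type = [((), False), ((DBIdx(0),), True)]
print(extract_implicit_args(in_type, [np.arange(5)]))  # -> [5]
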
@@ -683,7 +737,8 @@ def pjit(
         out_shardings=out_shardings, static_argnums=static_argnums,
         static_argnames=static_argnames, donate_argnums=donate_argnums,
         device=device, backend=backend, keep_unused=keep_unused,
-        inline=inline, resource_env=resource_env)
+        inline=inline, resource_env=resource_env,
+        abstracted_axes=abstracted_axes)
     return common_infer_params(pjit_info_args, *args, **kwargs)

   has_explicit_sharding = _pjit_explicit_sharding(
@@ -800,38 +855,44 @@ class PytreeLeaf:


 @lru_cache(maxsize=4096)
-def _process_in_axis_resources(in_shardings_thunk, global_in_avals,
+def _process_in_axis_resources(in_shardings_thunk, in_type,
                                in_tree, resource_env):
   orig_in_shardings = in_shardings_thunk()
   # Only do this if original in_shardings are unspecified. If they are
   # FROM_GDA or AUTO, go via flatten_axis_resources.
   if _is_unspecified(orig_in_shardings):
-    in_shardings_flat = (orig_in_shardings,) * len(global_in_avals)
+    in_shardings_flat = (orig_in_shardings,) * len(in_type)
   else:
     in_shardings_flat = flatten_axis_resources(
         "pjit in_shardings", in_tree, orig_in_shardings,
         tupled_args=True)

-  pjit_check_aval_sharding(in_shardings_flat, global_in_avals, "pjit arguments",
-                           allow_uneven_sharding=False)
+  if not config.jax_dynamic_shapes:
+    pjit_check_aval_sharding(in_shardings_flat, in_type,
+                             "pjit arguments", allow_uneven_sharding=False)
   # TODO(yashkatariya): Only check for is_auto or _is_unspecified when
   # FROM_GDA is removed.
   canonicalized_shardings = tuple(
       i if _is_unspecified_or_from_gda_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
-      for i, aval in safe_zip(in_shardings_flat, global_in_avals))
+      for i, aval in zip(in_shardings_flat, in_type))
   return canonicalized_shardings


 @lu.cache
-def _create_pjit_jaxpr(fun, global_in_avals, debug_info, out_paths):
+def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths):
   with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                                  "for pjit in {elapsed_time} sec",
                                  event=dispatch.JAXPR_TRACE_EVENT):
     pe_debug = debug_info and pe.debug_info_final(fun, debug_info.traced_for)
-    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
-        fun, global_in_avals, debug_info=pe_debug)
+    if config.jax_dynamic_shapes:
+      jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic2(
+          lu.annotate(fun, in_type), debug_info=pe_debug)
+    else:
+      jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
+          fun, in_type, debug_info=pe_debug)

-  jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())
+  if not config.jax_dynamic_shapes:
+    jaxpr = jaxpr_debug_info(jaxpr, debug_info, out_paths())

   if any(isinstance(c, core.Tracer) for c in consts):
     closed_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr))
@@ -844,7 +905,7 @@ def _create_pjit_jaxpr(fun, global_in_avals, debug_info, out_paths):

 @lru_cache(maxsize=4096)
 def _check_and_canonicalize_out_shardings(
-    out_shardings_thunk, out_tree, global_out_avals):
+    out_shardings_thunk, out_tree, out_type):
   orig_out_shardings = out_shardings_thunk()
   # TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources
   # instead. This condition exists because flatten_axis_resources passes in an
@@ -852,28 +913,29 @@ def _check_and_canonicalize_out_shardings(
   # pytrees (which shouldn't exist but they do).
   if (_is_unspecified(orig_out_shardings) or
       isinstance(orig_out_shardings, XLACompatibleSharding)):
-    out_shardings_flat = (orig_out_shardings,) * len(global_out_avals)
+    out_shardings_flat = (orig_out_shardings,) * len(out_type)
   else:
     out_shardings_flat = flatten_axis_resources(
         "pjit out_shardings", out_tree(), orig_out_shardings,
         tupled_args=False)

-  pjit_check_aval_sharding(out_shardings_flat, global_out_avals, "pjit outputs",
-                           allow_uneven_sharding=False)
+  if not config.jax_dynamic_shapes:
+    pjit_check_aval_sharding(out_shardings_flat, out_type, "pjit outputs",
+                             allow_uneven_sharding=False)

   canonicalized_out_shardings_flat = tuple(
       o if _is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim)
-      for o, aval in safe_zip(out_shardings_flat, global_out_avals)
+      for o, aval in zip(out_shardings_flat, out_type)
   )
   return canonicalized_out_shardings_flat


-def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, debug_info, out_tree,
+def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info, out_tree,
                 result_paths):
-  jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
-      fun, global_in_avals, debug_info, result_paths)
+  jaxpr, final_consts, out_type = _create_pjit_jaxpr(
+      fun, in_type, debug_info, result_paths)
   canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings(
-      out_shardings_thunk, out_tree, tuple(global_out_avals))
+      out_shardings_thunk, out_tree, tuple(out_type))
   # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
   return jaxpr, final_consts, canonicalized_out_shardings_flat

@@ -1118,7 +1180,7 @@ def _resolve_in_shardings(
       (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat)))

   resolved_in_shardings = []
-  for arg, pjit_in_s in safe_zip(args, pjit_in_shardings):
+  for arg, pjit_in_s in zip(args, pjit_in_shardings):
     arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True))
                         if hasattr(arg, 'sharding') else (_UNSPECIFIED, False))
     if _is_unspecified(pjit_in_s):
@@ -1201,7 +1263,7 @@ def _pjit_call_impl(*args, jaxpr,
   distributed_debug_log(("Running pjit'd function", name),
                         ("in_shardings", in_shardings),
                         ("out_shardings", out_shardings),
-                        ("abstract args", list(map(xla.abstractify, args))),
+                        ("abstract args", map(xla.abstractify, args)),
                         ("fingerprint", fingerprint))
   try:
     return compiled.unsafe_call(*args)
@@ -1255,7 +1317,7 @@ class SameDeviceAssignmentTuple:
     return (all(pxla.are_op_shardings_equal(s._op_sharding, o._op_sharding)  # pytype: disable=attribute-error
                 if isinstance(s, GSPMDSharding) and isinstance(o, GSPMDSharding)
                 else s == o
-                for s, o in safe_zip(self.shardings, other.shardings)) and
+                for s, o in zip(self.shardings, other.shardings)) and
             self.device_assignment == other.device_assignment)


@@ -1341,10 +1403,43 @@ def pjit_staging_rule(trace, *args, **params):
       all(_is_unspecified(o) for o in params["out_shardings"])):
     jaxpr = params['jaxpr']
     return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
+  elif config.jax_dynamic_shapes:
+    source_info = source_info_util.current()
+    out_tracers = []
+    for aval in _out_type(params['jaxpr']):
+      if type(aval) is core.DShapedArray:
+        shape = [args[d.val] if type(d) is core.InDBIdx else
+                 out_tracers[d.val] if type(d) is core.OutDBIdx else
+                 d for d in aval.shape]
+        aval = aval.update(shape=tuple(core.get_referent(d) for d in shape))
+      out_tracers.append(pe.DynamicJaxprTracer(trace, aval, source_info))
+    eqn = core.new_jaxpr_eqn(
+        map(trace.getvar, args), map(trace.makevar, out_tracers), pjit_p, params,
+        params['jaxpr'].effects, source_info)
+    trace.frame.add_eqn(eqn)
+    return out_tracers
   else:
     return trace.default_process_primitive(pjit_p, args, params)
 pe.custom_staging_rules[pjit_p] = pjit_staging_rule

+# TODO(mattjj): remove/trivialize this when jaxprs have type annotation on them,
+# since it's actually not possible in general to infer the type from the term
+def _out_type(jaxpr: core.ClosedJaxpr) -> List[core.AbstractValue]:
+  out = []
+  in_idx = {v: i for i, v in enumerate(jaxpr.jaxpr.invars)}
+  out_idx = {x: i for i, x in enumerate(jaxpr.jaxpr.invars)
+             if type(x) is core.Var}
+  for x in jaxpr.jaxpr.outvars:
+    aval = x.aval
+    if type(aval) is core.DShapedArray:
+      shape = [core.InDBIdx(in_idx[d]) if d in in_idx else
+               core.OutDBIdx(out_idx[d]) if d in out_idx else
+               d for d in x.aval.shape]
+      aval = aval.update(shape=tuple(shape))
+    out.append(aval)
+  return out
+

 def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
   return core._check_call(ctx_factory, pjit_p, in_atoms,
                           dict(params, call_jaxpr=jaxpr.jaxpr))
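
The dynamic-shapes branch of pjit_staging_rule and the _out_type helper describe output shapes with de Bruijn-style indices: InDBIdx(i) points at the i-th input of the call and OutDBIdx(i) at the i-th output produced so far. A simplified, self-contained sketch of that resolution step, using stand-in classes rather than jax.core's:

# Simplified model of the shape-resolution loop in the staging rule; the
# InDBIdx/OutDBIdx classes here are stand-ins, not the jax internals.
from dataclasses import dataclass
from typing import Any, Sequence, Tuple

@dataclass(frozen=True)
class InDBIdx:
  val: int

@dataclass(frozen=True)
class OutDBIdx:
  val: int

def resolve_shape(spec: Tuple[Any, ...],
                  call_args: Sequence[Any],
                  outs_so_far: Sequence[Any]) -> Tuple[Any, ...]:
  resolved = []
  for d in spec:
    if isinstance(d, InDBIdx):
      resolved.append(call_args[d.val])      # size comes from a call argument
    elif isinstance(d, OutDBIdx):
      resolved.append(outs_so_far[d.val])    # size comes from an earlier output
    else:
      resolved.append(d)                     # already a concrete size
  return tuple(resolved)

# An output typed as f32[InDBIdx(0), 3] for a call with arguments (n, x)
# resolves to shape (n, 3).
print(resolve_shape((InDBIdx(0), 3), call_args=["n", "x"], outs_so_far=[]))
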
@@ -1360,14 +1455,14 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
                    out_shardings, resource_env, donated_invars,
                    keep_unused, inline):
   effects = list(ctx.tokens_in.effects())
-  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
+  output_types = map(mlir.aval_to_ir_types, ctx.avals_out)
   output_types = [mlir.token_type()] * len(effects) + output_types
   flat_output_types = util.flatten(output_types)

   arg_shardings = [None if _is_unspecified(i) else i._to_xla_op_sharding(aval.ndim)
-                   for aval, i in safe_zip(ctx.avals_in, in_shardings)]
+                   for aval, i in zip(ctx.avals_in, in_shardings)]
   result_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
-                      for aval, o in safe_zip(ctx.avals_out, out_shardings)]
+                      for aval, o in zip(ctx.avals_out, out_shardings)]

   # TODO(b/228598865): inlined calls cannot have shardings set directly on the
   # inputs or outputs because they are lost during MLIR->HLO conversion.
@@ -1381,7 +1476,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
   call = func_dialect.CallOp(flat_output_types,
                              ir.FlatSymbolRefAttr.get(func.name.value),
                              mlir.flatten_lowering_ir_args(args))
-  out_nodes = util.unflatten(call.results, safe_map(len, output_types))
+  out_nodes = util.unflatten(call.results, map(len, output_types))
   tokens, out_nodes = split_list(out_nodes, [len(effects)])
   tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens)))
   ctx.set_tokens_out(tokens_out)
@@ -1486,7 +1581,7 @@ ad.primitive_jvps[pjit_p] = _pjit_jvp
 def _known_jaxpr_fwd(known_jaxpr: core.ClosedJaxpr,
                      fwds_known: Tuple[Optional[int]]) -> core.ClosedJaxpr:
   updated_jaxpr = known_jaxpr.jaxpr.replace(
-      outvars=[x for x, i in safe_zip(known_jaxpr.jaxpr.outvars, fwds_known)
+      outvars=[x for x, i in zip(known_jaxpr.jaxpr.outvars, fwds_known)
               if i is None])
   return known_jaxpr.replace(jaxpr=updated_jaxpr)

@@ -1505,7 +1600,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   num_residuals = len(res_avals)

   def keep_where(l, should_keep):
-    return tuple(x for x, keep in zip(l, should_keep) if keep)
+    return tuple(x for x, keep in unsafe_zip(l, should_keep) if keep)

   residual_shardings = (_UNSPECIFIED,) * num_residuals
   # Compute the known outputs
@@ -1526,7 +1621,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   known_user_out_shardings = keep_where(known_params['out_shardings'], known_outs)
   fwds_known_user = [
       fwd if _is_unspecified(os) else None
-      for os, fwd in safe_zip(known_user_out_shardings,
+      for os, fwd in zip(known_user_out_shardings,
                               fwds_known[:len(known_user_out_shardings)])]
   fwds_known = fwds_known_user + fwds_known[len(known_user_out_shardings):]
   del fwds_known_user
@@ -1534,7 +1629,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   # Remove forwarded outvars and out_shardings
   known_params['jaxpr'] = _known_jaxpr_fwd(known_params['jaxpr'], tuple(fwds_known))
   known_out_shardings = tuple(
-      s for s, i in safe_zip(known_params['out_shardings'], fwds_known) if i is None)
+      s for s, i in zip(known_params['out_shardings'], fwds_known) if i is None)
   known_params['out_shardings'] = known_out_shardings
   del known_out_shardings

@@ -1694,7 +1789,7 @@ def dce_jaxpr_pjit_rule(used_outputs: List[bool], eqn: core.JaxprEqn
       eqn.params['jaxpr'], tuple(used_outputs))

   def keep_where(xs, keeps):
-    return tuple(x for x, keep in safe_zip(xs, keeps) if keep)
+    return tuple(x for x, keep in zip(xs, keeps) if keep)

   eqn_params = eqn.params
   new_params = dict(
@@ -1847,7 +1942,7 @@ def with_sharding_constraint(x, axis_resources=_UNSPECIFIED,
   outs = [sharding_constraint_p.bind(xf, sharding=to_gspmd_sharding(i, xf.ndim),
                                      resource_env=resource_env,
                                      unconstrained_dims=ud)
-          for xf, i, ud in safe_zip(x_flat, shardings_flat, unconstrained_dims)]
+          for xf, i, ud in zip(x_flat, shardings_flat, unconstrained_dims)]
   return tree_unflatten(tree, outs)

 def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims):
@@ -1983,7 +2078,7 @@ def _maybe_replace_from_gda_with_pspec(
     return to_gspmd_sharding(gda_sharding, ndim)

   out = []
-  for in_sharding_flat, arg in safe_zip(in_shardings_flat, args_flat):
+  for in_sharding_flat, arg in zip(in_shardings_flat, args_flat):
     if is_auto(in_sharding_flat):
       out.append(in_sharding_flat)
     elif isinstance(arg, array.ArrayImpl):
@@ -1732,6 +1732,7 @@ class DynamicJaxprTrace(core.Trace):
     tracer = self.frame.constid_to_tracer.get(id(c))
     if tracer is None:
       aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
+      aval = self._lift_tracers_in_aval(aval)
       tracer = self._new_const(aval, c)
     return tracer

@@ -1820,8 +1821,8 @@ class DynamicJaxprTrace(core.Trace):
       for aval, _ in out_type:
         if type(aval) is DShapedArray:
           shape = [[*consts, *in_tracers][d.val] if type(d) is InDBIdx else
-                   out_tracers[d.val] if type(d) is OutDBIdx else
-                   d for d in aval.shape]
+                   out_tracers[d.val] if type(d) is OutDBIdx else
+                   d for d in aval.shape]
           aval = aval.update(shape=tuple(get_referent(d) for d in shape))
         out_tracers.append(DynamicJaxprTracer(self, aval, source_info))
       invars = map(self.getvar, in_tracers)
File diff suppressed because it is too large