Use trace_to_jaxpr_dynamic for the apply_primitive path. trace_to_jaxpr_final is only for final style primitives. Also do some cleanup.

PiperOrigin-RevId: 586106427
This commit is contained in:
Yash Katariya 2023-11-28 14:35:00 -08:00 committed by jax authors
parent a66fea78b0
commit cb7c2ed848
4 changed files with 21 additions and 30 deletions

View File

@ -139,8 +139,8 @@ def _trace_to_jaxpr(fun, in_avals, api_name, fun_name):
with log_elapsed_time(
"Finished tracing + transforming {fun_name} in {elapsed_time} sec",
fun_name=util.wrap_name(fun_name, api_name), event=JAXPR_TRACE_EVENT):
jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, in_avals)
return core.ClosedJaxpr(jaxpr, consts)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return core.ClosedJaxpr(jaxpr, consts), tuple(out_avals)
@util.cache()
@ -158,10 +158,11 @@ def xla_primitive_callable(
donated_invars = (False,) * len(in_avals)
wrapped_fun = lu.wrap_init(prim_fun)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
closed_jaxpr = _trace_to_jaxpr(flat_fun, in_avals, 'jit', prim.name)
closed_jaxpr, out_avals = _trace_to_jaxpr(flat_fun, in_avals, 'jit', prim.name)
computation = sharded_lowering(
closed_jaxpr, prim.name, donated_invars, keep_unused=False,
inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
inline=True, in_avals=in_avals, out_avals=out_avals,
in_shardings=orig_in_shardings.shardings,
lowering_parameters=mlir.LoweringParameters())
compiled = computation.compile()
if config.disable_jit.value:
@ -179,19 +180,17 @@ def xla_primitive_callable(
def sharded_lowering(
closed_jaxpr: core.ClosedJaxpr, name: str, donated_invars: Sequence[bool],
keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
out_avals: tuple[core.AbstractValue, ...],
in_shardings: Sequence[Sharding | None],
lowering_parameters: mlir.LoweringParameters
) -> pxla.MeshComputation:
in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]
# Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know
# the number of output avals at this stage. lower_sharding_computation will
# apply it to all out_avals.
return pxla.lower_sharding_computation(
closed_jaxpr, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
closed_jaxpr, 'jit', name, in_shardings_unspec,
(UNSPECIFIED,) * len(out_avals), donated_invars,
in_avals, keep_unused=keep_unused, inline=inline,
devices_from_context=None, lowering_parameters=lowering_parameters,
in_layouts=(None,) * len(in_avals), out_layouts=None)
in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals))
def simple_impl(prim):

View File

@ -87,7 +87,7 @@ from jax.experimental import pjit
from jax._src import core
from jax._src import test_util as jtu
from jax._src.interpreters import pxla
from jax._src.sharding_impls import UNSPECIFIED
from jax._src import xla_bridge as xb
@ -321,8 +321,8 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
in_avals=tuple(in_avals),
out_tree=out_tree,
out_avals=tuple(out_avals),
in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
in_shardings=(UNSPECIFIED,) * len(in_avals),
out_shardings=(UNSPECIFIED,) * len(out_avals),
lowering_platforms=(data.platform,),
ordered_effects=(),
unordered_effects=(),

View File

@ -67,9 +67,8 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified,
AUTO, UnspecifiedValue, UNSPECIFIED,
get_array_mapping as _get_array_mapping, is_auto, is_unspecified,
is_unspecified_or_auto
AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
is_unspecified, is_unspecified_or_auto
)
from jax._src.util import (safe_map, safe_zip, partition_list,
wrap_name, tuple_delete, distributed_debug_log,
@ -1889,7 +1888,7 @@ def lower_sharding_computation(
api_name: str,
fun_name: str,
in_shardings: Sequence[MaybeSharding],
out_shardings: Sequence[MaybeSharding] | UnspecifiedValue,
out_shardings: Sequence[MaybeSharding],
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
*,
@ -1898,7 +1897,7 @@ def lower_sharding_computation(
devices_from_context: Sequence[xc.Device] | None = None,
lowering_parameters: mlir.LoweringParameters,
in_layouts: MaybeLayout,
out_layouts: Optional[MaybeLayout],
out_layouts: MaybeLayout,
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
@ -1908,9 +1907,8 @@ def lower_sharding_computation(
the singleton UNSPECIFIED to all out_avals.
"""
# 1. Trace to jaxpr and preprocess/verify it
auto_spmd_lowering = (
check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore
auto_spmd_lowering = check_if_any_auto(
it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore
all_args_info = AllArgsInfo(global_in_avals, in_shardings,
closed_jaxpr.jaxpr.debug_info)
@ -1923,12 +1921,6 @@ def lower_sharding_computation(
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)
if is_unspecified(out_shardings):
out_shardings = (UNSPECIFIED,) * len(global_out_avals)
if out_layouts is None:
out_layouts = (None,) * len(global_out_avals)
assert isinstance(out_shardings, tuple)
assert isinstance(out_layouts, tuple)
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
len(out_shardings), len(out_layouts), len(global_out_avals))
@ -1978,7 +1970,7 @@ def lower_sharding_computation(
# 2. Build up the HLO
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
semantic_out_shardings = SemanticallyEqualShardings(out_shardings) # type: ignore
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,

View File

@ -702,9 +702,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
tiling_method=tiling_method,
lowering_parameters=lowering_parameters)
else:
closed_jaxpr = dispatch._trace_to_jaxpr(f, in_avals, 'jit', name)
closed_jaxpr, out_avals = dispatch._trace_to_jaxpr(f, in_avals, 'jit', name)
return dispatch.sharded_lowering(
closed_jaxpr, name, donated_invars, True, False, in_avals,
closed_jaxpr, name, donated_invars, True, False, in_avals, out_avals,
(None,) * len(in_avals), lowering_parameters=lowering_parameters)