diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 5cb547348..299a01a74 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -139,8 +139,8 @@ def _trace_to_jaxpr(fun, in_avals, api_name, fun_name):
   with log_elapsed_time(
       "Finished tracing + transforming {fun_name} in {elapsed_time} sec",
       fun_name=util.wrap_name(fun_name, api_name), event=JAXPR_TRACE_EVENT):
-    jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, in_avals)
-  return core.ClosedJaxpr(jaxpr, consts)
+    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
+  return core.ClosedJaxpr(jaxpr, consts), tuple(out_avals)
 
 
 @util.cache()
@@ -158,10 +158,11 @@ def xla_primitive_callable(
   donated_invars = (False,) * len(in_avals)
   wrapped_fun = lu.wrap_init(prim_fun)
   flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
-  closed_jaxpr = _trace_to_jaxpr(flat_fun, in_avals, 'jit', prim.name)
+  closed_jaxpr, out_avals = _trace_to_jaxpr(flat_fun, in_avals, 'jit', prim.name)
   computation = sharded_lowering(
       closed_jaxpr, prim.name, donated_invars, keep_unused=False,
-      inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
+      inline=True, in_avals=in_avals, out_avals=out_avals,
+      in_shardings=orig_in_shardings.shardings,
       lowering_parameters=mlir.LoweringParameters())
   compiled = computation.compile()
   if config.disable_jit.value:
@@ -179,19 +180,17 @@
 def sharded_lowering(
     closed_jaxpr: core.ClosedJaxpr, name: str, donated_invars: Sequence[bool],
     keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
+    out_avals: tuple[core.AbstractValue, ...],
     in_shardings: Sequence[Sharding | None],
     lowering_parameters: mlir.LoweringParameters
 ) -> pxla.MeshComputation:
   in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]
-
-  # Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know
-  # the number of output avals at this stage. lower_sharding_computation will
-  # apply it to all out_avals.
   return pxla.lower_sharding_computation(
-      closed_jaxpr, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
+      closed_jaxpr, 'jit', name, in_shardings_unspec,
+      (UNSPECIFIED,) * len(out_avals), donated_invars,
       in_avals, keep_unused=keep_unused, inline=inline,
       devices_from_context=None, lowering_parameters=lowering_parameters,
-      in_layouts=(None,) * len(in_avals), out_layouts=None)
+      in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals))
 
 
 def simple_impl(prim):
diff --git a/jax/_src/internal_test_util/export_back_compat_test_util.py b/jax/_src/internal_test_util/export_back_compat_test_util.py
index 42757ed50..00e1dc370 100644
--- a/jax/_src/internal_test_util/export_back_compat_test_util.py
+++ b/jax/_src/internal_test_util/export_back_compat_test_util.py
@@ -87,7 +87,7 @@
 from jax.experimental import pjit
 from jax._src import core
 from jax._src import test_util as jtu
-from jax._src.interpreters import pxla
+from jax._src.sharding_impls import UNSPECIFIED
 from jax._src import xla_bridge as xb
 
 
@@ -321,8 +321,8 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
     in_avals=tuple(in_avals),
     out_tree=out_tree,
     out_avals=tuple(out_avals),
-    in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
-    out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
+    in_shardings=(UNSPECIFIED,) * len(in_avals),
+    out_shardings=(UNSPECIFIED,) * len(out_avals),
     lowering_platforms=(data.platform,),
     ordered_effects=(),
     unordered_effects=(),
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index c005e937f..a94fd780f 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -67,9 +67,8 @@ from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified,
-    AUTO, UnspecifiedValue, UNSPECIFIED,
-    get_array_mapping as _get_array_mapping, is_auto, is_unspecified,
-    is_unspecified_or_auto
+    AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
+    is_unspecified, is_unspecified_or_auto
 )
 from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
                            tuple_delete, distributed_debug_log,
@@ -1889,7 +1888,7 @@ def lower_sharding_computation(
     api_name: str,
     fun_name: str,
     in_shardings: Sequence[MaybeSharding],
-    out_shardings: Sequence[MaybeSharding] | UnspecifiedValue,
+    out_shardings: Sequence[MaybeSharding],
     donated_invars: Sequence[bool],
     global_in_avals: Sequence[core.ShapedArray],
     *,
@@ -1898,7 +1897,7 @@ def lower_sharding_computation(
     devices_from_context: Sequence[xc.Device] | None = None,
     lowering_parameters: mlir.LoweringParameters,
     in_layouts: MaybeLayout,
-    out_layouts: Optional[MaybeLayout],
+    out_layouts: MaybeLayout,
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.
 
@@ -1908,9 +1907,8 @@
   the singleton UNSPECIFIED to all out_avals.
   """
   # 1. Trace to jaxpr and preprocess/verify it
-  auto_spmd_lowering = (
-      check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
-      check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings])))  # type: ignore
+  auto_spmd_lowering = check_if_any_auto(
+      it.chain.from_iterable([in_shardings, out_shardings]))  # type: ignore
 
   all_args_info = AllArgsInfo(global_in_avals, in_shardings,
                               closed_jaxpr.jaxpr.debug_info)
@@ -1923,12 +1921,6 @@
   in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
   in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)
 
-  if is_unspecified(out_shardings):
-    out_shardings = (UNSPECIFIED,) * len(global_out_avals)
-  if out_layouts is None:
-    out_layouts = (None,) * len(global_out_avals)
-  assert isinstance(out_shardings, tuple)
-  assert isinstance(out_layouts, tuple)
   assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
       len(out_shardings), len(out_layouts), len(global_out_avals))
 
@@ -1978,7 +1970,7 @@
 
   # 2. Build up the HLO
   semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
-  semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
+  semantic_out_shardings = SemanticallyEqualShardings(out_shardings)  # type: ignore
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index ec5d638c1..d41c32974 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -702,9 +702,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
         tiling_method=tiling_method,
         lowering_parameters=lowering_parameters)
   else:
-    closed_jaxpr = dispatch._trace_to_jaxpr(f, in_avals, 'jit', name)
+    closed_jaxpr, out_avals = dispatch._trace_to_jaxpr(f, in_avals, 'jit', name)
     return dispatch.sharded_lowering(
-        closed_jaxpr, name, donated_invars, True, False, in_avals,
+        closed_jaxpr, name, donated_invars, True, False, in_avals, out_avals,
         (None,) * len(in_avals), lowering_parameters=lowering_parameters)
 
 