From 81aee237d80247b39f28de1310f9eaf8e728f39f Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 27 Nov 2023 18:00:22 -0800 Subject: [PATCH] Simply lower_sharding_computation signature by always taking a closed jaxpr as input. For apply_primitive do the tracing to jaxpr in dispatch.py PiperOrigin-RevId: 585810475 --- jax/_src/dispatch.py | 20 ++++++++++++++++--- jax/_src/interpreters/pxla.py | 37 ++++++++++------------------------- jax/_src/maps.py | 5 +++-- 3 files changed, 30 insertions(+), 32 deletions(-) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 52ee0ed58..5cb547348 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -45,6 +45,7 @@ from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import xla from jax._src.interpreters import pxla +from jax._src.interpreters import partial_eval as pe from jax._src.lib import xla_client as xc from jax._src.monitoring import record_event_duration_secs from jax._src.partition_spec import PartitionSpec @@ -131,6 +132,17 @@ def apply_primitive(prim, *args, **params): return compiled_fun(*args) +# No need to cache here because there is a cache on xla_primitive_callable. +# If that cache is broken, a new function will be created which will always +# break the cache on this function. +def _trace_to_jaxpr(fun, in_avals, api_name, fun_name): + with log_elapsed_time( + "Finished tracing + transforming {fun_name} in {elapsed_time} sec", + fun_name=util.wrap_name(fun_name, api_name), event=JAXPR_TRACE_EVENT): + jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, in_avals) + return core.ClosedJaxpr(jaxpr, consts) + + @util.cache() def xla_primitive_callable( prim: core.Primitive, in_avals: tuple[core.AbstractValue, ...], in_tree, @@ -142,11 +154,13 @@ def xla_primitive_callable( return out else: return out, + donated_invars = (False,) * len(in_avals) wrapped_fun = lu.wrap_init(prim_fun) flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree) + closed_jaxpr = _trace_to_jaxpr(flat_fun, in_avals, 'jit', prim.name) computation = sharded_lowering( - flat_fun, prim.name, donated_invars, keep_unused=False, + closed_jaxpr, prim.name, donated_invars, keep_unused=False, inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings, lowering_parameters=mlir.LoweringParameters()) compiled = computation.compile() @@ -163,7 +177,7 @@ def xla_primitive_callable( def sharded_lowering( - fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool], + closed_jaxpr: core.ClosedJaxpr, name: str, donated_invars: Sequence[bool], keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...], in_shardings: Sequence[Sharding | None], lowering_parameters: mlir.LoweringParameters @@ -174,7 +188,7 @@ def sharded_lowering( # the number of output avals at this stage. lower_sharding_computation will # apply it to all out_avals. return pxla.lower_sharding_computation( - fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars, + closed_jaxpr, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars, in_avals, keep_unused=keep_unused, inline=inline, devices_from_context=None, lowering_parameters=lowering_parameters, in_layouts=(None,) * len(in_avals), out_layouts=None) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 35efac070..96624feb1 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1665,16 +1665,6 @@ def _get_and_check_device_assignment( MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue] -def cache_wrap(fn): - _wrapped_with_lu_cache = lu.cache(fn) - _wrapped_with_weakref_lru_cache = weakref_lru_cache(fn) - def wrapped(f, *args, **kwargs): - if isinstance(f, lu.WrappedFun): - return _wrapped_with_lu_cache(f, *args, **kwargs) - else: - return _wrapped_with_weakref_lru_cache(f, *args, **kwargs) - return wrapped - def prune_unused_inputs( jaxpr: core.Jaxpr, @@ -1686,22 +1676,15 @@ def prune_unused_inputs( return new_jaxpr, kept_const_idx, kept_var_idx -@cache_wrap -def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name, - keep_unused, donated_invars, auto_spmd_lowering): +@weakref_lru_cache +def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name, + keep_unused, donated_invars, auto_spmd_lowering): name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) - if isinstance(fun_or_jaxpr, lu.WrappedFun): - with dispatch.log_elapsed_time( - "Finished tracing + transforming {fun_name} in {elapsed_time} sec", - fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT): - jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final( - fun_or_jaxpr, global_in_avals) - else: - assert isinstance(fun_or_jaxpr, core.ClosedJaxpr) - jaxpr = fun_or_jaxpr.jaxpr - global_out_avals = fun_or_jaxpr.out_avals - consts = fun_or_jaxpr.consts + assert isinstance(closed_jaxpr, core.ClosedJaxpr) + jaxpr = closed_jaxpr.jaxpr + global_out_avals = closed_jaxpr.out_avals + consts = closed_jaxpr.consts if (keep_unused or auto_spmd_lowering or any(hasattr(a, "shape") and not core.is_constant_shape(a.shape) @@ -1894,7 +1877,7 @@ MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]] @profiler.annotate_function def lower_sharding_computation( - fun_or_jaxpr: lu.WrappedFun | core.ClosedJaxpr, + closed_jaxpr: core.ClosedJaxpr, api_name: str, fun_name: str, in_shardings: Sequence[MaybeSharding], @@ -1922,8 +1905,8 @@ def lower_sharding_computation( check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore (closed_jaxpr, global_in_avals, global_out_avals, donated_invars, - kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce( - fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused, + kept_var_idx, name_stack) = _dce_jaxpr( + closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused, donated_invars, auto_spmd_lowering) jaxpr = closed_jaxpr.jaxpr in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 86d3bd740..ec5d638c1 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -702,9 +702,10 @@ def make_xmap_callable(fun: lu.WrappedFun, tiling_method=tiling_method, lowering_parameters=lowering_parameters) else: + closed_jaxpr = dispatch._trace_to_jaxpr(f, in_avals, 'jit', name) return dispatch.sharded_lowering( - f, name, donated_invars, True, False, in_avals, (None,) * len(in_avals), - lowering_parameters=lowering_parameters) + closed_jaxpr, name, donated_invars, True, False, in_avals, + (None,) * len(in_avals), lowering_parameters=lowering_parameters) class EvaluationPlan(NamedTuple):