Simplify lower_sharding_computation's signature by always taking a closed jaxpr as input. For apply_primitive, do the tracing to jaxpr in dispatch.py.

PiperOrigin-RevId: 585810475
Yash Katariya authored 2023-11-27 18:00:22 -08:00; committed by jax authors
parent 2ed0fc4d1c
commit 81aee237d8
3 changed files with 30 additions and 32 deletions
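
For context, a minimal sketch (public API only) of what "tracing to a closed jaxpr" produces. jax.make_jaxpr is the public counterpart of the internal pe.trace_to_jaxpr_final call that this commit wraps in a _trace_to_jaxpr helper in dispatch.py:

# Minimal sketch, public API only: jax.make_jaxpr traces a Python callable to a
# core.ClosedJaxpr, the object lower_sharding_computation now always receives.
import jax
import jax.numpy as jnp

def f(x, y):
  return jnp.sin(x) @ y

closed = jax.make_jaxpr(f)(jnp.ones((2, 3)), jnp.ones((3, 4)))
print(type(closed))        # a ClosedJaxpr
print(closed.jaxpr)        # the underlying jaxpr
print(closed.consts)       # captured constants
print(closed.out_avals)    # abstract output values (shape/dtype)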

jax/_src/dispatch.py

@@ -45,6 +45,7 @@ from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.interpreters import pxla
+from jax._src.interpreters import partial_eval as pe
 from jax._src.lib import xla_client as xc
 from jax._src.monitoring import record_event_duration_secs
 from jax._src.partition_spec import PartitionSpec
@@ -131,6 +132,17 @@ def apply_primitive(prim, *args, **params):
   return compiled_fun(*args)
 
 
+# No need to cache here because there is a cache on xla_primitive_callable.
+# If that cache is broken, a new function will be created which will always
+# break the cache on this function.
+def _trace_to_jaxpr(fun, in_avals, api_name, fun_name):
+  with log_elapsed_time(
+      "Finished tracing + transforming {fun_name} in {elapsed_time} sec",
+      fun_name=util.wrap_name(fun_name, api_name), event=JAXPR_TRACE_EVENT):
+    jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, in_avals)
+  return core.ClosedJaxpr(jaxpr, consts)
+
+
 @util.cache()
 def xla_primitive_callable(
     prim: core.Primitive, in_avals: tuple[core.AbstractValue, ...], in_tree,
@@ -142,11 +154,13 @@ def xla_primitive_callable(
       return out
     else:
       return out,
+
   donated_invars = (False,) * len(in_avals)
   wrapped_fun = lu.wrap_init(prim_fun)
   flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
+  closed_jaxpr = _trace_to_jaxpr(flat_fun, in_avals, 'jit', prim.name)
   computation = sharded_lowering(
-      flat_fun, prim.name, donated_invars, keep_unused=False,
+      closed_jaxpr, prim.name, donated_invars, keep_unused=False,
       inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
       lowering_parameters=mlir.LoweringParameters())
   compiled = computation.compile()
@@ -163,7 +177,7 @@ def xla_primitive_callable(
 def sharded_lowering(
-    fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
+    closed_jaxpr: core.ClosedJaxpr, name: str, donated_invars: Sequence[bool],
     keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
     in_shardings: Sequence[Sharding | None],
     lowering_parameters: mlir.LoweringParameters
@@ -174,7 +188,7 @@ def sharded_lowering(
   # the number of output avals at this stage. lower_sharding_computation will
   # apply it to all out_avals.
   return pxla.lower_sharding_computation(
-      fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
+      closed_jaxpr, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
       in_avals, keep_unused=keep_unused, inline=inline,
       devices_from_context=None, lowering_parameters=lowering_parameters,
       in_layouts=(None,) * len(in_avals), out_layouts=None)
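
The "no need to cache here" comment above leans on xla_primitive_callable itself being wrapped in util.cache, so _trace_to_jaxpr runs at most once per cache key. A rough stand-alone illustration of that two-level relationship (hypothetical names; functools.lru_cache standing in for jax's internal util.cache):

# Rough illustration only, hypothetical names: when the outer entry point is
# cached, the helper it calls is effectively cached too, so wrapping the helper
# in its own cache would be redundant.
import functools

def trace_helper(key):          # plays the role of _trace_to_jaxpr
  print(f"tracing {key}")       # side effect to observe how often it runs
  return key * 2

@functools.lru_cache(maxsize=None)
def primitive_callable(key):    # plays the role of xla_primitive_callable
  return trace_helper(key)

primitive_callable(3)           # prints "tracing 3"
primitive_callable(3)           # cache hit: the helper is not called again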

jax/_src/interpreters/pxla.py

@@ -1665,16 +1665,6 @@ def _get_and_check_device_assignment(
 MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]
 
 
-def cache_wrap(fn):
-  _wrapped_with_lu_cache = lu.cache(fn)
-  _wrapped_with_weakref_lru_cache = weakref_lru_cache(fn)
-  def wrapped(f, *args, **kwargs):
-    if isinstance(f, lu.WrappedFun):
-      return _wrapped_with_lu_cache(f, *args, **kwargs)
-    else:
-      return _wrapped_with_weakref_lru_cache(f, *args, **kwargs)
-  return wrapped
-
 
 def prune_unused_inputs(
     jaxpr: core.Jaxpr,
@@ -1686,22 +1676,15 @@ def prune_unused_inputs(
   return new_jaxpr, kept_const_idx, kept_var_idx
 
 
-@cache_wrap
-def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
-                            keep_unused, donated_invars, auto_spmd_lowering):
+@weakref_lru_cache
+def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
+               keep_unused, donated_invars, auto_spmd_lowering):
   name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
 
-  if isinstance(fun_or_jaxpr, lu.WrappedFun):
-    with dispatch.log_elapsed_time(
-        "Finished tracing + transforming {fun_name} in {elapsed_time} sec",
-        fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT):
-      jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
-          fun_or_jaxpr, global_in_avals)
-  else:
-    assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
-    jaxpr = fun_or_jaxpr.jaxpr
-    global_out_avals = fun_or_jaxpr.out_avals
-    consts = fun_or_jaxpr.consts
+  assert isinstance(closed_jaxpr, core.ClosedJaxpr)
+  jaxpr = closed_jaxpr.jaxpr
+  global_out_avals = closed_jaxpr.out_avals
+  consts = closed_jaxpr.consts
 
   if (keep_unused or auto_spmd_lowering or
       any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
@@ -1894,7 +1877,7 @@ MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]]
 
 @profiler.annotate_function
 def lower_sharding_computation(
-    fun_or_jaxpr: lu.WrappedFun | core.ClosedJaxpr,
+    closed_jaxpr: core.ClosedJaxpr,
     api_name: str,
     fun_name: str,
     in_shardings: Sequence[MaybeSharding],
@@ -1922,8 +1905,8 @@ def lower_sharding_computation(
       check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings])))  # type: ignore
 
   (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
-   kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce(
-      fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
+   kept_var_idx, name_stack) = _dce_jaxpr(
+      closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
       donated_invars, auto_spmd_lowering)
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
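
With lower_sharding_computation pinned to a ClosedJaxpr input, the deleted cache_wrap shim (which picked lu.cache for lu.WrappedFun inputs and weakref_lru_cache otherwise) is no longer needed; a single @weakref_lru_cache decorator suffices. A rough standard-library sketch of weakref-keyed memoization, just to illustrate the idea (not jax's actual weakref_lru_cache, which also bounds the cache LRU-style):

# Sketch only: memoize on a weakly referenced first argument so cached results
# disappear when the key object is garbage collected.
import weakref

def weakref_memo(fn):
  memo = weakref.WeakKeyDictionary()   # entries die with their key object
  def wrapper(obj, *args):
    per_obj = memo.setdefault(obj, {})
    if args not in per_obj:
      per_obj[args] = fn(obj, *args)
    return per_obj[args]
  return wrapper

class Key:                             # any hashable, weak-referenceable object
  pass

@weakref_memo
def lowered(key, flag):
  print("computing")                   # runs once per live (key, args) pair
  return flag

k = Key()
lowered(k, True)    # prints "computing"
lowered(k, True)    # cached: no print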

jax/_src/maps.py

@@ -702,9 +702,10 @@ def make_xmap_callable(fun: lu.WrappedFun,
         tiling_method=tiling_method,
         lowering_parameters=lowering_parameters)
   else:
+    closed_jaxpr = dispatch._trace_to_jaxpr(f, in_avals, 'jit', name)
     return dispatch.sharded_lowering(
-        f, name, donated_invars, True, False, in_avals, (None,) * len(in_avals),
-        lowering_parameters=lowering_parameters)
+        closed_jaxpr, name, donated_invars, True, False, in_avals,
+        (None,) * len(in_avals), lowering_parameters=lowering_parameters)
 
 
 class EvaluationPlan(NamedTuple):