mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Simply lower_sharding_computation signature by always taking a closed jaxpr as input. For apply_primitive do the tracing to jaxpr in dispatch.py
PiperOrigin-RevId: 585810475
This commit is contained in:
parent
2ed0fc4d1c
commit
81aee237d8
@ -45,6 +45,7 @@ from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.monitoring import record_event_duration_secs
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -131,6 +132,17 @@ def apply_primitive(prim, *args, **params):
|
||||
return compiled_fun(*args)
|
||||
|
||||
|
||||
# No need to cache here because there is a cache on xla_primitive_callable.
|
||||
# If that cache is broken, a new function will be created which will always
|
||||
# break the cache on this function.
|
||||
def _trace_to_jaxpr(fun, in_avals, api_name, fun_name):
|
||||
with log_elapsed_time(
|
||||
"Finished tracing + transforming {fun_name} in {elapsed_time} sec",
|
||||
fun_name=util.wrap_name(fun_name, api_name), event=JAXPR_TRACE_EVENT):
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, in_avals)
|
||||
return core.ClosedJaxpr(jaxpr, consts)
|
||||
|
||||
|
||||
@util.cache()
|
||||
def xla_primitive_callable(
|
||||
prim: core.Primitive, in_avals: tuple[core.AbstractValue, ...], in_tree,
|
||||
@ -142,11 +154,13 @@ def xla_primitive_callable(
|
||||
return out
|
||||
else:
|
||||
return out,
|
||||
|
||||
donated_invars = (False,) * len(in_avals)
|
||||
wrapped_fun = lu.wrap_init(prim_fun)
|
||||
flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
|
||||
closed_jaxpr = _trace_to_jaxpr(flat_fun, in_avals, 'jit', prim.name)
|
||||
computation = sharded_lowering(
|
||||
flat_fun, prim.name, donated_invars, keep_unused=False,
|
||||
closed_jaxpr, prim.name, donated_invars, keep_unused=False,
|
||||
inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
|
||||
lowering_parameters=mlir.LoweringParameters())
|
||||
compiled = computation.compile()
|
||||
@ -163,7 +177,7 @@ def xla_primitive_callable(
|
||||
|
||||
|
||||
def sharded_lowering(
|
||||
fun: lu.WrappedFun, name: str, donated_invars: Sequence[bool],
|
||||
closed_jaxpr: core.ClosedJaxpr, name: str, donated_invars: Sequence[bool],
|
||||
keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
|
||||
in_shardings: Sequence[Sharding | None],
|
||||
lowering_parameters: mlir.LoweringParameters
|
||||
@ -174,7 +188,7 @@ def sharded_lowering(
|
||||
# the number of output avals at this stage. lower_sharding_computation will
|
||||
# apply it to all out_avals.
|
||||
return pxla.lower_sharding_computation(
|
||||
fun, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
|
||||
closed_jaxpr, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
|
||||
in_avals, keep_unused=keep_unused, inline=inline,
|
||||
devices_from_context=None, lowering_parameters=lowering_parameters,
|
||||
in_layouts=(None,) * len(in_avals), out_layouts=None)
|
||||
|
@ -1665,16 +1665,6 @@ def _get_and_check_device_assignment(
|
||||
|
||||
MaybeSharding = Union[sharding_impls.XLACompatibleSharding, UnspecifiedValue]
|
||||
|
||||
def cache_wrap(fn):
|
||||
_wrapped_with_lu_cache = lu.cache(fn)
|
||||
_wrapped_with_weakref_lru_cache = weakref_lru_cache(fn)
|
||||
def wrapped(f, *args, **kwargs):
|
||||
if isinstance(f, lu.WrappedFun):
|
||||
return _wrapped_with_lu_cache(f, *args, **kwargs)
|
||||
else:
|
||||
return _wrapped_with_weakref_lru_cache(f, *args, **kwargs)
|
||||
return wrapped
|
||||
|
||||
|
||||
def prune_unused_inputs(
|
||||
jaxpr: core.Jaxpr,
|
||||
@ -1686,22 +1676,15 @@ def prune_unused_inputs(
|
||||
return new_jaxpr, kept_const_idx, kept_var_idx
|
||||
|
||||
|
||||
@cache_wrap
|
||||
def _trace_to_jaxpr_and_dce(fun_or_jaxpr, global_in_avals, api_name, fun_name,
|
||||
keep_unused, donated_invars, auto_spmd_lowering):
|
||||
@weakref_lru_cache
|
||||
def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
|
||||
keep_unused, donated_invars, auto_spmd_lowering):
|
||||
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
|
||||
|
||||
if isinstance(fun_or_jaxpr, lu.WrappedFun):
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished tracing + transforming {fun_name} in {elapsed_time} sec",
|
||||
fun_name=str(name_stack), event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun_or_jaxpr, global_in_avals)
|
||||
else:
|
||||
assert isinstance(fun_or_jaxpr, core.ClosedJaxpr)
|
||||
jaxpr = fun_or_jaxpr.jaxpr
|
||||
global_out_avals = fun_or_jaxpr.out_avals
|
||||
consts = fun_or_jaxpr.consts
|
||||
assert isinstance(closed_jaxpr, core.ClosedJaxpr)
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
global_out_avals = closed_jaxpr.out_avals
|
||||
consts = closed_jaxpr.consts
|
||||
|
||||
if (keep_unused or auto_spmd_lowering or
|
||||
any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
|
||||
@ -1894,7 +1877,7 @@ MaybeLayout = Sequence[Optional[Union[XLACompatibleLayout, LayoutRequest]]]
|
||||
|
||||
@profiler.annotate_function
|
||||
def lower_sharding_computation(
|
||||
fun_or_jaxpr: lu.WrappedFun | core.ClosedJaxpr,
|
||||
closed_jaxpr: core.ClosedJaxpr,
|
||||
api_name: str,
|
||||
fun_name: str,
|
||||
in_shardings: Sequence[MaybeSharding],
|
||||
@ -1922,8 +1905,8 @@ def lower_sharding_computation(
|
||||
check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore
|
||||
|
||||
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
|
||||
kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce(
|
||||
fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
|
||||
kept_var_idx, name_stack) = _dce_jaxpr(
|
||||
closed_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
|
||||
donated_invars, auto_spmd_lowering)
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
|
||||
|
@ -702,9 +702,10 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
tiling_method=tiling_method,
|
||||
lowering_parameters=lowering_parameters)
|
||||
else:
|
||||
closed_jaxpr = dispatch._trace_to_jaxpr(f, in_avals, 'jit', name)
|
||||
return dispatch.sharded_lowering(
|
||||
f, name, donated_invars, True, False, in_avals, (None,) * len(in_avals),
|
||||
lowering_parameters=lowering_parameters)
|
||||
closed_jaxpr, name, donated_invars, True, False, in_avals,
|
||||
(None,) * len(in_avals), lowering_parameters=lowering_parameters)
|
||||
|
||||
|
||||
class EvaluationPlan(NamedTuple):
|
||||
|
Loading…
x
Reference in New Issue
Block a user