Use trace_to_jaxpr_dynamic for the apply_primitive path. trace_to_jaxpr_final
is only for final-style primitives. Also do some cleanup.

PiperOrigin-RevId: 586106427

commit cb7c2ed848
parent a66fea78b0
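
For context: per the commit message, trace_to_jaxpr_final belongs to the
final-style tracing machinery, which is the wrong tool for the apply_primitive
path; trace_to_jaxpr_dynamic traces initial-style and also reports the output
avals. A minimal sketch of the reworked helper, paraphrasing the first hunk
below (names and imports as they appear in the diff, not verbatim source):

from jax._src import core
from jax._src.interpreters import partial_eval as pe

def _trace_to_jaxpr_sketch(fun, in_avals):
  # trace_to_jaxpr_dynamic returns (jaxpr, out_avals, consts). The out_avals,
  # previously discarded with `_`, are now handed back to the caller so
  # lowering can build per-output sharding/layout tuples.
  jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
  return core.ClosedJaxpr(jaxpr, consts), tuple(out_avals)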
@@ -139,8 +139,8 @@ def _trace_to_jaxpr(fun, in_avals, api_name, fun_name):
   with log_elapsed_time(
       "Finished tracing + transforming {fun_name} in {elapsed_time} sec",
       fun_name=util.wrap_name(fun_name, api_name), event=JAXPR_TRACE_EVENT):
-    jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, in_avals)
-  return core.ClosedJaxpr(jaxpr, consts)
+    jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
+  return core.ClosedJaxpr(jaxpr, consts), tuple(out_avals)


 @util.cache()
@@ -158,10 +158,11 @@ def xla_primitive_callable(
   donated_invars = (False,) * len(in_avals)
   wrapped_fun = lu.wrap_init(prim_fun)
   flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
-  closed_jaxpr = _trace_to_jaxpr(flat_fun, in_avals, 'jit', prim.name)
+  closed_jaxpr, out_avals = _trace_to_jaxpr(flat_fun, in_avals, 'jit', prim.name)
   computation = sharded_lowering(
       closed_jaxpr, prim.name, donated_invars, keep_unused=False,
-      inline=True, in_avals=in_avals, in_shardings=orig_in_shardings.shardings,
+      inline=True, in_avals=in_avals, out_avals=out_avals,
+      in_shardings=orig_in_shardings.shardings,
       lowering_parameters=mlir.LoweringParameters())
   compiled = computation.compile()
   if config.disable_jit.value:
@@ -179,19 +180,17 @@ def xla_primitive_callable(
 def sharded_lowering(
     closed_jaxpr: core.ClosedJaxpr, name: str, donated_invars: Sequence[bool],
     keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
+    out_avals: tuple[core.AbstractValue, ...],
     in_shardings: Sequence[Sharding | None],
     lowering_parameters: mlir.LoweringParameters
 ) -> pxla.MeshComputation:
   in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]

-  # Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know
-  # the number of output avals at this stage. lower_sharding_computation will
-  # apply it to all out_avals.
   return pxla.lower_sharding_computation(
-      closed_jaxpr, 'jit', name, in_shardings_unspec, UNSPECIFIED, donated_invars,
+      closed_jaxpr, 'jit', name, in_shardings_unspec,
+      (UNSPECIFIED,) * len(out_avals), donated_invars,
       in_avals, keep_unused=keep_unused, inline=inline,
       devices_from_context=None, lowering_parameters=lowering_parameters,
-      in_layouts=(None,) * len(in_avals), out_layouts=None)
+      in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals))


 def simple_impl(prim):
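
The per-output tuples are the point of threading out_avals through: previously
the number of outputs was unknown at this call site, so a singleton UNSPECIFIED
was passed and lower_sharding_computation broadcast it across all out_avals
(per the comment deleted above). With out_avals in hand, the caller sizes the
tuples itself. A minimal illustration, assuming only what the diff shows:

from jax._src.sharding_impls import UNSPECIFIED

def per_output_placeholders(out_avals):
  # One placeholder sharding and layout per output aval, replacing the old
  # singleton-UNSPECIFIED broadcast inside lower_sharding_computation.
  return (UNSPECIFIED,) * len(out_avals), (None,) * len(out_avals)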
@@ -87,7 +87,7 @@ from jax.experimental import pjit

 from jax._src import core
 from jax._src import test_util as jtu
-from jax._src.interpreters import pxla
+from jax._src.sharding_impls import UNSPECIFIED
 from jax._src import xla_bridge as xb


@@ -321,8 +321,8 @@ data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
     in_avals=tuple(in_avals),
     out_tree=out_tree,
     out_avals=tuple(out_avals),
-    in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
-    out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
+    in_shardings=(UNSPECIFIED,) * len(in_avals),
+    out_shardings=(UNSPECIFIED,) * len(out_avals),
     lowering_platforms=(data.platform,),
     ordered_effects=(),
     unordered_effects=(),
@@ -67,9 +67,8 @@ from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified,
-    AUTO, UnspecifiedValue, UNSPECIFIED,
-    get_array_mapping as _get_array_mapping, is_auto, is_unspecified,
-    is_unspecified_or_auto
+    AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
+    is_unspecified, is_unspecified_or_auto
 )
 from jax._src.util import (safe_map, safe_zip, partition_list,
                            wrap_name, tuple_delete, distributed_debug_log,
@@ -1889,7 +1888,7 @@ def lower_sharding_computation(
     api_name: str,
     fun_name: str,
     in_shardings: Sequence[MaybeSharding],
-    out_shardings: Sequence[MaybeSharding] | UnspecifiedValue,
+    out_shardings: Sequence[MaybeSharding],
     donated_invars: Sequence[bool],
     global_in_avals: Sequence[core.ShapedArray],
     *,
@@ -1898,7 +1897,7 @@ def lower_sharding_computation(
     devices_from_context: Sequence[xc.Device] | None = None,
     lowering_parameters: mlir.LoweringParameters,
     in_layouts: MaybeLayout,
-    out_layouts: Optional[MaybeLayout],
+    out_layouts: MaybeLayout,
 ) -> MeshComputation:
   """Lowers a computation to XLA. It can take arbitrary shardings as input.

||||
@ -1908,9 +1907,8 @@ def lower_sharding_computation(
|
||||
the singleton UNSPECIFIED to all out_avals.
|
||||
"""
|
||||
# 1. Trace to jaxpr and preprocess/verify it
|
||||
auto_spmd_lowering = (
|
||||
check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
|
||||
check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore
|
||||
auto_spmd_lowering = check_if_any_auto(
|
||||
it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore
|
||||
|
||||
all_args_info = AllArgsInfo(global_in_avals, in_shardings,
|
||||
closed_jaxpr.jaxpr.debug_info)
|
||||
@@ -1923,12 +1921,6 @@ def lower_sharding_computation(
   in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
   in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)

-  if is_unspecified(out_shardings):
-    out_shardings = (UNSPECIFIED,) * len(global_out_avals)
-  if out_layouts is None:
-    out_layouts = (None,) * len(global_out_avals)
-  assert isinstance(out_shardings, tuple)
-  assert isinstance(out_layouts, tuple)
   assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
       len(out_shardings), len(out_layouts), len(global_out_avals))
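
Since every caller now passes concrete, correctly sized tuples, the
normalization deleted above is dead code: out_shardings can no longer be the
UNSPECIFIED singleton and out_layouts can no longer be None, so only the
length check survives. The tightened contract as a hedged sketch (hypothetical
helper, mirroring the surviving assert):

def check_output_contract(out_shardings, out_layouts, global_out_avals):
  # Callers of lower_sharding_computation now guarantee aligned tuples.
  assert isinstance(out_shardings, tuple) and isinstance(out_layouts, tuple)
  assert len(out_shardings) == len(out_layouts) == len(global_out_avals)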
@@ -1978,7 +1970,7 @@ def lower_sharding_computation(

   # 2. Build up the HLO
   semantic_in_shardings = SemanticallyEqualShardings(in_shardings)  # type: ignore
-  semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
+  semantic_out_shardings = SemanticallyEqualShardings(out_shardings)  # type: ignore
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
@@ -702,9 +702,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
         tiling_method=tiling_method,
         lowering_parameters=lowering_parameters)
   else:
-    closed_jaxpr = dispatch._trace_to_jaxpr(f, in_avals, 'jit', name)
+    closed_jaxpr, out_avals = dispatch._trace_to_jaxpr(f, in_avals, 'jit', name)
     return dispatch.sharded_lowering(
-        closed_jaxpr, name, donated_invars, True, False, in_avals,
+        closed_jaxpr, name, donated_invars, True, False, in_avals, out_avals,
         (None,) * len(in_avals), lowering_parameters=lowering_parameters)