Add the function name, the Jaxpr, and lowering platforms to Lowered.

These changes ensure that `Lowered` carries all the information needed for
export and serialization.
They prepare for a cleanup of the export and serialization APIs that will
integrate them with the AOT APIs. In particular, exporting will start from a
`Lowered` object and will no longer include its own lowering code.

We add the lowered function name and the Jaxpr to `Lowered` (as the
attributes `_fun_name` and `_jaxpr`),
and we add the tuple of lowering platforms (as `Lowered._lowering._platforms`).

The function name is useful for better error messages when exporting and
serializing. The Jaxpr is useful for also exporting the VJP of the function
and obtaining an `Exported` that can be differentiated.
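
As a concrete illustration of what this exposes, here is a minimal sketch using the `jit` path. Note that `_fun_name`, `_jaxpr`, and `_lowering._platforms` are private attributes introduced by this change, so this is an assumption about the post-commit behavior rather than a stable API:

import jax
import jax.numpy as jnp

def square(x):
  return x * x

# AOT lowering through the existing API; after this change the Lowered object
# also records the function name, the closed Jaxpr, and the lowering platforms.
lowered = jax.jit(square).lower(jnp.ones((4,), jnp.float32))

print(lowered._fun_name)             # "square"
print(lowered._jaxpr is not None)    # True: a core.ClosedJaxpr on the jit path
print(lowered._lowering._platforms)  # e.g. ("cpu",) when lowering on CPU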
George Necula 2024-05-14 17:09:04 +03:00
parent 4fae9aa160
commit bb4c073574
6 changed files with 178 additions and 141 deletions

View File

@ -1846,7 +1846,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
devices, backend, axis_size, args, kwargs)
abstract_args = list(map(shaped_abstractify, p.flat_args))
computation = pxla.lower_parallel_callable(
computation, closed_jaxpr = pxla.lower_parallel_callable(
p.flat_fun, backend, axis_name,
axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
devices=p.devices,
@ -1858,7 +1858,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
avals=abstract_args,
lowering_parameters=lowering_parameters)
return stages.Lowered.from_flat_info(
computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())
computation, p.in_tree, abstract_args, donate_tuple, p.out_tree(),
fun_name=p.flat_fun.__name__, jaxpr=closed_jaxpr)
return lower
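
For the `pmap` path above, the name now comes from the flattened function and the Jaxpr from the new second return value of `lower_parallel_callable`. A small sketch of the effect, assuming at least one local device (the exact name of the flat function is an implementation detail):

import jax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(float(n * 3)).reshape((n, 3))

lowered = jax.pmap(lambda x: x * 2.0, axis_name="i").lower(xs)
print(lowered._fun_name)           # e.g. "<lambda>", taken from flat_fun.__name__
print(lowered._jaxpr is not None)  # True: the ClosedJaxpr returned by lower_parallel_callable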

View File

@ -556,7 +556,7 @@ def parallel_callable(fun: lu.WrappedFun,
donated_invars: Sequence[bool],
is_explicit_global_axis_size: bool,
*avals):
pmap_computation = lower_parallel_callable(
pmap_computation, _ = lower_parallel_callable(
fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars,
is_explicit_global_axis_size, avals,
@ -679,7 +679,7 @@ def lower_parallel_callable(
is_explicit_global_axis_size: bool,
avals: Sequence[core.AbstractValue],
*,
lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
lowering_parameters: mlir.LoweringParameters) -> tuple[PmapComputation, core.ClosedJaxpr]:
# Determine global_axis_size for use in AxisEnv.
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
# if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@ -761,6 +761,7 @@ def lower_parallel_callable(
tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
backend.platform)
module_name = f"pmap_{fun.__name__}"
platforms = lowering_parameters.platforms or (backend.platform,)
with maybe_extend_axis_env(axis_name, global_axis_size, None):
ordered_effects = list(
effects.ordered_effects.filter_in(closed_jaxpr.effects))
@ -776,7 +777,7 @@ def lower_parallel_callable(
closed_jaxpr,
ordered_effects=ordered_effects,
backend_or_name=backend,
platforms=lowering_parameters.platforms or (backend.platform,),
platforms=platforms,
axis_context=sharding_impls.ReplicaAxisContext(axis_env),
name_stack=name_stack,
donated_args=donated_invars,
@ -787,14 +788,16 @@ def lower_parallel_callable(
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=replicas.num_global_replicas,
lowering_parameters=lowering_parameters)
return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
return PmapComputation(lowering_result.module,
platforms=platforms,
pci=pci, replicas=replicas,
shards=shards, tuple_args=tuple_args,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
keepalive=lowering_result.keepalive,
host_callbacks=lowering_result.host_callbacks,
jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
shape_poly_state=lowering_result.shape_poly_state)
shape_poly_state=lowering_result.shape_poly_state), closed_jaxpr
def _pmap_unmap_shaped_array(
@ -907,10 +910,13 @@ class UnloadedPmapExecutable:
host_callbacks: list[Any],
keepalive: Any,
jaxpr_debug_info: core.JaxprDebugInfo,
platforms: Sequence[str],
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
compiler_options=None):
del platforms
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
hlo = mlir.refine_polymorphic_shapes(hlo)
devices = pci.devices
if devices is None:
if shards.num_global_shards > xb.device_count(pci.backend):
@ -1941,7 +1947,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
"The following ordered effects are not supported for "
f"more than 1 device: {unsupported_effects}")
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
with dispatch.log_elapsed_time(
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
@ -2141,6 +2146,7 @@ def lower_sharding_computation(
for js, source_info in util.stable_unique(jaxpr_sharding))),
devices_from_context)
platforms = lowering_parameters.platforms or (backend.platform,)
# TODO(yashkatariya): Enable this when offload APIs are stable.
# transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))
@ -2204,6 +2210,7 @@ def lower_sharding_computation(
kept_var_idx=kept_var_idx,
mut=mut,
backend=backend,
platforms=platforms,
device_assignment=da_object,
committed=committed,
in_layouts=in_layouts,
@ -2244,6 +2251,7 @@ def lower_mesh_computation(
lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])
platforms = lowering_parameters.platforms or (backend.platform,)
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
global_axis_sizes = mesh.shape
@ -2352,7 +2360,7 @@ def lower_mesh_computation(
closed_jaxpr,
ordered_effects=ordered_effects,
backend_or_name=backend,
platforms=lowering_parameters.platforms or (backend.platform,),
platforms=platforms,
axis_context=axis_ctx,
name_stack=name_stack,
donated_args=donated_invars,
@ -2382,6 +2390,7 @@ def lower_mesh_computation(
keepalive=lowering_result.keepalive,
kept_var_idx=set(range(len(global_in_avals))),
backend=backend,
platforms=platforms,
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
committed=True,
in_layouts=(None,) * len(global_in_avals),
@ -2394,10 +2403,14 @@ class MeshComputation(stages.XlaLowering):
_executable: MeshExecutable | None
def __init__(self, name: str, hlo: ir.Module,
donated_invars: Sequence[bool], **compile_args):
donated_invars: Sequence[bool],
platforms: Sequence[str] | None = None, # None only for backwards
# compatibility with PartIR
**compile_args):
self._name = name
self._hlo = hlo
self._donated_invars = donated_invars
self._platforms = platforms
self.compile_args = compile_args
self._executable = None
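
The lowering platforms default to the backend platform but can be requested explicitly; the tuple is now stored on the computation itself (`_platforms`), so later stages such as export no longer recompute it. A sketch, assuming the private `_experimental_lowering_parameters` argument to `.lower()` (used by the export code later in this commit) is available in this JAX version:

import jax
import jax.numpy as jnp
from jax._src.interpreters import mlir

# Default: platforms falls back to the current backend's platform.
lowered = jax.jit(lambda x: x + 1).lower(jnp.float32(0.0))
print(lowered._lowering._platforms)  # e.g. ("cpu",)

# Explicit request for a multi-platform lowering; the tuple is recorded as-is.
lowered2 = jax.jit(lambda x: x + 1).lower(
    jnp.float32(0.0),
    _experimental_lowering_parameters=mlir.LoweringParameters(
        platforms=("cpu", "tpu")))
print(lowered2._lowering._platforms)  # ("cpu", "tpu")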

View File

@ -617,7 +617,7 @@ def xmap(fun: Callable,
'_experimental_lowering_platform', mlir.LoweringParameters())
fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
avals_flat = [shaped_abstractify(arg) for arg in args_flat]
computation = make_xmap_callable(
computation, jaxpr = make_xmap_callable(
fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
params['resource_env'], params['backend'], params['spmd_in_axes'],
@ -628,7 +628,7 @@ def xmap(fun: Callable,
in_avals = in_tree.unflatten(avals_flat)
return stages.Lowered.from_flat_info(
computation, in_tree, in_avals, donate_argnums, out_tree(),
no_kwargs=True)
no_kwargs=True, fun_name=params['name'], jaxpr=jaxpr)
fun_mapped.lower = lower
return type_cast(stages.Wrapped, fun_mapped)
@ -637,11 +637,12 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
global_axis_sizes, axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk):
in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
xmap_callable = make_xmap_callable(
computation, _ = make_xmap_callable(
fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
axis_resources, resource_env, backend,
spmd_in_axes, spmd_out_axes_thunk,
mlir.LoweringParameters(), *in_avals).compile().unsafe_call
mlir.LoweringParameters(), *in_avals)
xmap_callable = computation.compile().unsafe_call
distributed_debug_log(("Running xmapped function", name),
("python function", fun.f),
("mesh", resource_env.physical_mesh),
@ -708,7 +709,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
in_shardings, out_shardings, donated_invars,
use_spmd_lowering, in_avals,
tiling_method=tiling_method,
lowering_parameters=lowering_parameters)
lowering_parameters=lowering_parameters), jaxpr
else:
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals)
return pxla.lower_sharding_computation(
@ -716,7 +717,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
(UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
(None,) * len(in_avals), (None,) * len(out_avals),
donated_invars, keep_unused=True, inline=False,
devices_from_context=None, lowering_parameters=lowering_parameters)
devices_from_context=None, lowering_parameters=lowering_parameters), jaxpr
class EvaluationPlan(NamedTuple):

View File

@ -469,7 +469,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
return stages.Lowered.from_flat_info(
lowering, in_tree, flat_global_in_avals, donate_argnums,
out_tree)
out_tree, fun_name=params["name"], jaxpr=params["jaxpr"])
@api_boundary
def eval_shape(*args, **kwargs):

View File

@ -601,23 +601,29 @@ class Lowered(Stage):
querying properties of lowered computations across JAX's various
lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.).
"""
__slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"]
__slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs", "_fun_name", "_jaxpr"]
_lowering: XlaLowering
args_info: Any # PyTree of ArgInfo
out_tree: tree_util.PyTreeDef
_lowering: XlaLowering
_no_kwargs: bool
_fun_name: str
_jaxpr: core.ClosedJaxpr | None # Can be None when this class is constructed
# outside of JAX core.
def __init__(
self,
lowering: XlaLowering,
args_info, # PyTree of ArgInfo
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False):
no_kwargs: bool = False,
fun_name: str = "unknown",
jaxpr: core.ClosedJaxpr | None = None):
self._lowering = lowering
self._no_kwargs = no_kwargs
self.args_info = args_info
self.out_tree = out_tree
self._fun_name = fun_name
self._jaxpr = jaxpr
@classmethod
def from_flat_info(cls,
@ -626,7 +632,9 @@ class Lowered(Stage):
in_avals,
donate_argnums: tuple[int, ...],
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False):
no_kwargs: bool = False,
fun_name: str = "unknown",
jaxpr: core.ClosedJaxpr | None = None):
"""Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.
Args:
@ -635,12 +643,14 @@ class Lowered(Stage):
no_kwargs: If ``True`` the transformation, and the
``Compiled`` returned from this object will not support keyword
arguments (an error will be raised if some are provided).
fun_name: the name of the lowered function, if available.
jaxpr: the Jaxpr of the lowered function, if available.
"""
return cls(
lowering,
make_args_info(in_tree, in_avals, donate_argnums),
out_tree,
no_kwargs=no_kwargs)
no_kwargs=no_kwargs, fun_name=fun_name, jaxpr=jaxpr)
def compile(
self, compiler_options: CompilerOptions | None = None) -> Compiled:
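
A hypothetical consumer of the new `Lowered` fields might look like the sketch below (the helper name `describe` is made up for illustration); note the guard on `_jaxpr`, which, as documented above, can be None when the object is constructed outside JAX core:

from jax._src import core, stages

def describe(lowered: stages.Lowered) -> str:
  name = lowered._fun_name  # "unknown" when the caller did not provide a name
  if lowered._jaxpr is None:
    return f"{name}: no Jaxpr recorded (Lowered built outside JAX core)"
  jaxpr: core.ClosedJaxpr = lowered._jaxpr
  return f"{name}: {len(jaxpr.jaxpr.eqns)} equations, {len(jaxpr.consts)} consts"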

View File

@ -47,6 +47,7 @@ from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src import pjit
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
@ -374,14 +375,6 @@ def export(fun_jax: Callable,
def f_jax(*args, **kwargs): ...
exported = jax_export.export(f_jax)(*args, **kwargs)
"""
fun_name = getattr(fun_jax, "__name__", "unknown")
version = config.jax_serialization_version.value
if (version < minimum_supported_serialization_version or
version > maximum_supported_serialization_version):
raise ValueError(
f"The requested jax_serialization version {version} is outside the "
f"range of supported versions [{minimum_supported_serialization_version}"
f"..{maximum_supported_serialization_version}]")
def do_export(*args_specs, **kwargs_specs) -> Exported:
if not hasattr(fun_jax, "lower"):
@ -402,7 +395,7 @@ def export(fun_jax: Callable,
# TODO: move to `lower`
symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None
for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
# Static args may has no `shape` attribute.
# Static args may have no `shape` attribute.
if not hasattr(aval, "shape"):
continue
for d in aval.shape:
@ -411,7 +404,7 @@ def export(fun_jax: Callable,
symbolic_scope = (d.scope, k_path)
continue
symbolic_scope[0]._check_same_scope(
d, when=f"when exporting {fun_name}",
d, when=f"when exporting {getattr(wrapped_fun_jax, '__name__')}",
self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
other_descr=_shape_poly.args_kwargs_path_to_str(k_path))
@ -420,6 +413,20 @@ def export(fun_jax: Callable,
_experimental_lowering_parameters=mlir.LoweringParameters(
platforms=actual_lowering_platforms,
))
return _export_lowered(lowered, disabled_checks=disabled_checks)
return do_export
def _export_lowered(
lowered: stages.Lowered,
disabled_checks: Sequence[DisabledSafetyCheck] = (),
) -> Exported:
version = config.jax_serialization_version.value
if (version < minimum_supported_serialization_version or
version > maximum_supported_serialization_version):
raise ValueError(
f"The requested jax_serialization version {version} is outside the "
f"range of supported versions [{minimum_supported_serialization_version}"
f"..{maximum_supported_serialization_version}]")
lowering = lowered._lowering
_check_lowering(lowering)
@ -461,7 +468,7 @@ def export(fun_jax: Callable,
# Log and then check the module.
if logging.vlog_is_on(3):
logmsg = (f"version={version} "
f"lowering_platforms={actual_lowering_platforms} "
f"lowering_platforms={lowering.compile_args['platforms']} "
f"disabled_checks={disabled_checks}")
logging.info("Lowered JAX module: %s\n", logmsg)
if dumped_to := mlir.dump_module_to_file(mlir_module, "export"):
@ -489,8 +496,29 @@ def export(fun_jax: Callable,
out_shardings = tuple(
export_sharding(s, aval)
for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat))
device_assignment = lowering.compile_args["device_assignment"]
def _get_exported_vjp(exp_primal: Exported) -> Exported:
# Turn the primal jaxpr into a function, in preparation for exporting
# the VJP. Note that jaxpr_as_fun produces a function with flat arguments
assert(lowered._jaxpr is not None) # None only when the lowered was created outside JAX
fun_jax = core.jaxpr_as_fun(lowered._jaxpr)
fun_vjp_jax, vjp_in_avals = _get_vjp_fun(fun_jax,
in_tree=exp_primal.in_tree,
in_avals=exp_primal.in_avals,
in_shardings=exp_primal.in_shardings,
out_avals=exp_primal.out_avals,
out_shardings=exp_primal.out_shardings,
device_assignment=device_assignment,
apply_jit=True,
flat_primal_fun=True)
return export(fun_vjp_jax,
lowering_platforms=exp_primal.lowering_platforms,
disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)
return Exported(
fun_name=fun_name,
fun_name=lowered._fun_name,
in_tree=lowered.in_tree,
out_tree=lowered.out_tree,
in_avals=tuple(args_avals_flat),
@ -498,7 +526,7 @@ def export(fun_jax: Callable,
in_shardings=in_shardings,
out_shardings=out_shardings,
nr_devices=nr_devices,
lowering_platforms=actual_lowering_platforms,
lowering_platforms=lowering._platforms,
ordered_effects=ordered_effects,
unordered_effects=unordered_effects,
disabled_safety_checks=tuple(disabled_checks),
@ -506,11 +534,7 @@ def export(fun_jax: Callable,
module_kept_var_idx=module_kept_var_idx,
uses_shape_polymorphism=shape_poly_state.uses_dim_vars,
mlir_module_serialization_version=version,
_get_vjp=lambda exported: _export_native_vjp(fun_jax, exported,
lowering.compile_args["device_assignment"]))
return do_export
_get_vjp=_get_exported_vjp)
def _module_to_bytecode(module: ir.Module) -> bytes:
mlir_str = mlir.module_to_bytecode(module)
@ -713,7 +737,7 @@ def _check_lowering(lowering) -> None:
# safe to add it to the allowed_compile_args if it does not change the semantics
# or the calling convention of the lowered module.
allowed_compile_args = [
"backend", "mesh", "global_in_avals",
"backend", "platforms", "mesh", "global_in_avals",
"global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
"mut", "spmd_lowering", "auto_spmd_lowering",
"tuple_args", "ordered_effects", "unordered_effects",
@ -918,12 +942,15 @@ def _get_vjp_fun(primal_fun: Callable, *,
in_shardings: tuple[Sharding, ...],
out_shardings: tuple[Sharding, ...],
device_assignment: Sequence[sharding_impls.Device] | None,
apply_jit: bool
apply_jit: bool,
flat_primal_fun: bool = False,
) -> tuple[Callable, Sequence[core.AbstractValue]]:
# Since jax.vjp does not handle kwargs, it is easier to do all the work
# here with flattened functions.
# apply_jit=False is only used for backwards compatibility with the graph
# serialization. When apply_jit=True, we must pass a device assignment.
# flat_primal_fun=False is used only from jax2tf, and it means that the
# `primal_fun` takes PyTree `*args` and `**kwargs`.
def fun_vjp_jax(*args_and_out_cts_flat_jax):
# Takes a flat list of primals and output cotangents
def flattened_primal_fun_jax(*args_flat):
@ -934,7 +961,8 @@ def _get_vjp_fun(primal_fun: Callable, *,
args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax,
[len(in_avals)])
_, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax)
_, pullback_jax = jax.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax,
*args_flat_jax)
return pullback_jax(out_cts_flat_jax)
vjp_in_avals = list(
@ -953,22 +981,6 @@ def _get_vjp_fun(primal_fun: Callable, *,
else:
return fun_vjp_jax, vjp_in_avals
def _export_native_vjp(primal_fun,
primal: Exported,
device_assignment: Sequence[sharding_impls.Device]) -> Exported:
# Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp
fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun,
in_tree=primal.in_tree,
in_avals=primal.in_avals,
in_shardings=primal.in_shardings,
out_avals=primal.out_avals,
out_shardings=primal.out_shardings,
device_assignment=device_assignment,
apply_jit=True)
return export(fun_vjp_jax,
lowering_platforms=primal.lowering_platforms,
disabled_checks=primal.disabled_safety_checks)(*vjp_in_avals)
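
With the Jaxpr now carried on `Lowered`, the VJP export goes through `_get_exported_vjp` above: `core.jaxpr_as_fun(lowered._jaxpr)` yields a flat-argument callable (hence the new `flat_primal_fun=True`), which is then exported like any other function. A usage sketch, assuming the `jax.experimental.export` module layout of this era and its `Exported.vjp()` method, which invokes the stored `_get_vjp`:

import jax
import jax.numpy as jnp
from jax.experimental import export as export_lib

def f(x):
  return jnp.sin(x) * 2.0

exp = export_lib.export(f)(jax.ShapeDtypeStruct((8,), jnp.float32))

# The VJP is itself an Exported whose flat inputs are the primal inputs
# followed by the output cotangents.
exp_vjp = exp.vjp()
print(exp_vjp.in_avals)   # two f32[8] avals: primal input and output cotangent
print(exp_vjp.out_avals)  # one f32[8] aval: the input cotangent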
### Calling the exported function
def call(exported: Exported) -> Callable[..., jax.Array]: