From bb4c073574cddef2b84c4a8af4f0201e9af816ed Mon Sep 17 00:00:00 2001
From: George Necula
Date: Tue, 14 May 2024 17:09:04 +0300
Subject: [PATCH] Add the function name, the Jaxpr, and lowering platforms to
 Lowered.

These changes ensure that `Lowered` carries all the information needed
for export and serialization. They are in preparation for a cleanup of
the exporting and serialization APIs to integrate them with the AOT
APIs. In particular, exporting will start from a `Lowered` object and
will no longer include its own lowering code.

We add the lowered function name and the Jaxpr (as the attributes
`_fun_name` and `_jaxpr`) to `Lowered`, and we add the tuple of lowering
platforms (as `Lowered._lowering._platforms`).

The function name is useful for better error messages when exporting and
serializing. The Jaxpr is also useful for exporting the VJP of the
function and obtaining an `Exported` that can be differentiated.
---
 jax/_src/api.py                    |   5 +-
 jax/_src/interpreters/pxla.py      |  29 +++-
 jax/_src/maps.py                   |  13 +-
 jax/_src/pjit.py                   |   2 +-
 jax/_src/stages.py                 |  22 ++-
 jax/experimental/export/_export.py | 248 +++++++++++++++--------------
 6 files changed, 178 insertions(+), 141 deletions(-)

diff --git a/jax/_src/api.py b/jax/_src/api.py
index ec08c5c12..3c4cfcea4 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -1846,7 +1846,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
         fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
         devices, backend, axis_size, args, kwargs)
     abstract_args = list(map(shaped_abstractify, p.flat_args))
-    computation = pxla.lower_parallel_callable(
+    computation, closed_jaxpr = pxla.lower_parallel_callable(
         p.flat_fun, backend, axis_name,
         axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
         devices=p.devices,
@@ -1858,7 +1858,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
         avals=abstract_args,
         lowering_parameters=lowering_parameters)
     return stages.Lowered.from_flat_info(
-        computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())
+        computation, p.in_tree, abstract_args, donate_tuple, p.out_tree(),
+        fun_name=p.flat_fun.__name__, jaxpr=closed_jaxpr)

   return lower

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 3fb6e41eb..2ccafbaa7 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -556,7 +556,7 @@ def parallel_callable(fun: lu.WrappedFun,
                       donated_invars: Sequence[bool],
                       is_explicit_global_axis_size: bool,
                       *avals):
-  pmap_computation = lower_parallel_callable(
+  pmap_computation, _ = lower_parallel_callable(
       fun, backend_name, axis_name, axis_size, global_axis_size, devices,
       name, in_axes, out_axes_thunk, donated_invars,
       is_explicit_global_axis_size, avals,
@@ -679,7 +679,7 @@ def lower_parallel_callable(
     is_explicit_global_axis_size: bool,
     avals: Sequence[core.AbstractValue],
     *,
-    lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
+    lowering_parameters: mlir.LoweringParameters) -> tuple[PmapComputation, core.ClosedJaxpr]:
   # Determine global_axis_size for use in AxisEnv.
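  # --- Editor's note: illustrative sketch, not part of the patch. ---
  # With the return annotation changed above, `lower_parallel_callable` now
  # returns a (PmapComputation, ClosedJaxpr) pair. Callers unpack it; the
  # api.py hunk earlier in this patch does, in effect:
  #
  #   computation, closed_jaxpr = pxla.lower_parallel_callable(...)
  #   stages.Lowered.from_flat_info(
  #       computation, in_tree, abstract_args, donate_tuple, out_tree,
  #       fun_name=flat_fun.__name__, jaxpr=closed_jaxpr)
  #
  # while callers that do not need the Jaxpr discard it, as
  # `parallel_callable` does with `pmap_computation, _ = ...`.
  # --- end note ---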
# TODO(mattjj,skyewm): revive this check (inner_pmap always False now) # if xb.process_count() > 1 and global_axis_size is None and inner_pmap: @@ -761,6 +761,7 @@ def lower_parallel_callable( tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals), backend.platform) module_name = f"pmap_{fun.__name__}" + platforms = lowering_parameters.platforms or (backend.platform,) with maybe_extend_axis_env(axis_name, global_axis_size, None): ordered_effects = list( effects.ordered_effects.filter_in(closed_jaxpr.effects)) @@ -776,7 +777,7 @@ def lower_parallel_callable( closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - platforms=lowering_parameters.platforms or (backend.platform,), + platforms=platforms, axis_context=sharding_impls.ReplicaAxisContext(axis_env), name_stack=name_stack, donated_args=donated_invars, @@ -787,14 +788,16 @@ def lower_parallel_callable( result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths, num_replicas=replicas.num_global_replicas, lowering_parameters=lowering_parameters) - return PmapComputation(lowering_result.module, pci=pci, replicas=replicas, + return PmapComputation(lowering_result.module, + platforms=platforms, + pci=pci, replicas=replicas, shards=shards, tuple_args=tuple_args, unordered_effects=unordered_effects, ordered_effects=ordered_effects, keepalive=lowering_result.keepalive, host_callbacks=lowering_result.host_callbacks, jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info, - shape_poly_state=lowering_result.shape_poly_state) + shape_poly_state=lowering_result.shape_poly_state), closed_jaxpr def _pmap_unmap_shaped_array( @@ -907,10 +910,13 @@ class UnloadedPmapExecutable: host_callbacks: list[Any], keepalive: Any, jaxpr_debug_info: core.JaxprDebugInfo, + platforms: Sequence[str], shape_poly_state: mlir.ShapePolyLoweringState | None = None, compiler_options=None): + del platforms if shape_poly_state is not None and shape_poly_state.uses_dim_vars: hlo = mlir.refine_polymorphic_shapes(hlo) + devices = pci.devices if devices is None: if shards.num_global_shards > xb.device_count(pci.backend): @@ -1941,7 +1947,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, "The following ordered effects are not supported for " f"more than 1 device: {unsupported_effects}") ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) - with dispatch.log_elapsed_time( "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec", fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): @@ -2141,6 +2146,7 @@ def lower_sharding_computation( for js, source_info in util.stable_unique(jaxpr_sharding))), devices_from_context) + platforms = lowering_parameters.platforms or (backend.platform,) # TODO(yashkatariya): Enable this when offload APIs are stable. 
# transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr)) @@ -2204,6 +2210,7 @@ def lower_sharding_computation( kept_var_idx=kept_var_idx, mut=mut, backend=backend, + platforms=platforms, device_assignment=da_object, committed=committed, in_layouts=in_layouts, @@ -2244,6 +2251,7 @@ def lower_mesh_computation( lowering_parameters: mlir.LoweringParameters) -> MeshComputation: assert not mesh.empty backend = xb.get_device_backend(mesh.devices.flat[0]) + platforms = lowering_parameters.platforms or (backend.platform,) name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) global_axis_sizes = mesh.shape @@ -2352,7 +2360,7 @@ def lower_mesh_computation( closed_jaxpr, ordered_effects=ordered_effects, backend_or_name=backend, - platforms=lowering_parameters.platforms or (backend.platform,), + platforms=platforms, axis_context=axis_ctx, name_stack=name_stack, donated_args=donated_invars, @@ -2382,6 +2390,7 @@ def lower_mesh_computation( keepalive=lowering_result.keepalive, kept_var_idx=set(range(len(global_in_avals))), backend=backend, + platforms=platforms, device_assignment=_create_da_object(tuple(mesh.devices.flat)), committed=True, in_layouts=(None,) * len(global_in_avals), @@ -2394,10 +2403,14 @@ class MeshComputation(stages.XlaLowering): _executable: MeshExecutable | None def __init__(self, name: str, hlo: ir.Module, - donated_invars: Sequence[bool], **compile_args): + donated_invars: Sequence[bool], + platforms: Sequence[str] | None = None, # None only for backwards + # compatibility with PartIR + **compile_args): self._name = name self._hlo = hlo self._donated_invars = donated_invars + self._platforms = platforms self.compile_args = compile_args self._executable = None diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 595d86b58..2351c39a4 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -617,7 +617,7 @@ def xmap(fun: Callable, '_experimental_lowering_platform', mlir.LoweringParameters()) fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args) avals_flat = [shaped_abstractify(arg) for arg in args_flat] - computation = make_xmap_callable( + computation, jaxpr = make_xmap_callable( fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'], params['donated_invars'], params['global_axis_sizes'], params['axis_resources'], params['resource_env'], params['backend'], params['spmd_in_axes'], @@ -628,7 +628,7 @@ def xmap(fun: Callable, in_avals = in_tree.unflatten(avals_flat) return stages.Lowered.from_flat_info( computation, in_tree, in_avals, donate_argnums, out_tree(), - no_kwargs=True) + no_kwargs=True, fun_name=params['name'], jaxpr=jaxpr) fun_mapped.lower = lower return type_cast(stages.Wrapped, fun_mapped) @@ -637,11 +637,12 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_ global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk): in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args] - xmap_callable = make_xmap_callable( + computation, _ = make_xmap_callable( fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk, - mlir.LoweringParameters(), *in_avals).compile().unsafe_call + mlir.LoweringParameters(), *in_avals) + xmap_callable = computation.compile().unsafe_call distributed_debug_log(("Running xmapped function", name), ("python function", fun.f), ("mesh", resource_env.physical_mesh), @@ -708,7 +709,7 @@ def make_xmap_callable(fun: lu.WrappedFun, 
in_shardings, out_shardings, donated_invars, use_spmd_lowering, in_avals, tiling_method=tiling_method, - lowering_parameters=lowering_parameters) + lowering_parameters=lowering_parameters), jaxpr else: jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals) return pxla.lower_sharding_computation( @@ -716,7 +717,7 @@ def make_xmap_callable(fun: lu.WrappedFun, (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals), (None,) * len(in_avals), (None,) * len(out_avals), donated_invars, keep_unused=True, inline=False, - devices_from_context=None, lowering_parameters=lowering_parameters) + devices_from_context=None, lowering_parameters=lowering_parameters), jaxpr class EvaluationPlan(NamedTuple): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 43bc77836..88f288168 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -469,7 +469,7 @@ def _make_jit_wrapper(jit_info: PjitInfo): donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d) return stages.Lowered.from_flat_info( lowering, in_tree, flat_global_in_avals, donate_argnums, - out_tree) + out_tree, fun_name=params["name"], jaxpr=params["jaxpr"]) @api_boundary def eval_shape(*args, **kwargs): diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 683f2fcdc..c9eb8e64a 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -601,23 +601,29 @@ class Lowered(Stage): querying properties of lowered computations across JAX's various lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.). """ - __slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"] - + __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs", "_fun_name", "_jaxpr"] + _lowering: XlaLowering args_info: Any # PyTree of ArgInfo out_tree: tree_util.PyTreeDef - _lowering: XlaLowering _no_kwargs: bool + _fun_name: str + _jaxpr: core.ClosedJaxpr | None # Can be None when this class is constructed + # outside of JAX core. def __init__( self, lowering: XlaLowering, args_info, # PyTree of ArgInfo out_tree: tree_util.PyTreeDef, - no_kwargs: bool = False): + no_kwargs: bool = False, + fun_name: str = "unknown", + jaxpr: core.ClosedJaxpr | None = None): self._lowering = lowering self._no_kwargs = no_kwargs self.args_info = args_info self.out_tree = out_tree + self._fun_name = fun_name + self._jaxpr = jaxpr @classmethod def from_flat_info(cls, @@ -626,7 +632,9 @@ class Lowered(Stage): in_avals, donate_argnums: tuple[int, ...], out_tree: tree_util.PyTreeDef, - no_kwargs: bool = False): + no_kwargs: bool = False, + fun_name: str = "unknown", + jaxpr: core.ClosedJaxpr | None = None): """Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef. Args: @@ -635,12 +643,14 @@ class Lowered(Stage): no_kwargs: If ``True`` the transformation, and the ``Compiled`` returned from this object will not support keyword arguments (an error will be raised if some are provided). + fun_name: the name of the lowered function, if available. + jaxpr: the Jaxpr of the lowered function, if available. 
""" return cls( lowering, make_args_info(in_tree, in_avals, donate_argnums), out_tree, - no_kwargs=no_kwargs) + no_kwargs=no_kwargs, fun_name=fun_name, jaxpr=jaxpr) def compile( self, compiler_options: CompilerOptions | None = None) -> Compiled: diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py index 1689bdc87..660967a98 100644 --- a/jax/experimental/export/_export.py +++ b/jax/experimental/export/_export.py @@ -47,6 +47,7 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src import pjit from jax._src import sharding_impls from jax._src import source_info_util +from jax._src import stages from jax._src import tree_util from jax._src import util from jax._src import xla_bridge as xb @@ -374,14 +375,6 @@ def export(fun_jax: Callable, def f_jax(*args, **kwargs): ... exported = jax_export.export(f_jax)(*args, **kwargs) """ - fun_name = getattr(fun_jax, "__name__", "unknown") - version = config.jax_serialization_version.value - if (version < minimum_supported_serialization_version or - version > maximum_supported_serialization_version): - raise ValueError( - f"The requested jax_serialization version {version} is outside the " - f"range of supported versions [{minimum_supported_serialization_version}" - f"..{maximum_supported_serialization_version}]") def do_export(*args_specs, **kwargs_specs) -> Exported: if not hasattr(fun_jax, "lower"): @@ -402,7 +395,7 @@ def export(fun_jax: Callable, # TODO: move to `lower` symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]: - # Static args may has no `shape` attribute. + # Static args may have no `shape` attribute. if not hasattr(aval, "shape"): continue for d in aval.shape: @@ -411,7 +404,7 @@ def export(fun_jax: Callable, symbolic_scope = (d.scope, k_path) continue symbolic_scope[0]._check_same_scope( - d, when=f"when exporting {fun_name}", + d, when=f"when exporting {getattr(wrapped_fun_jax, '__name__')}", self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ", other_descr=_shape_poly.args_kwargs_path_to_str(k_path)) @@ -420,97 +413,128 @@ def export(fun_jax: Callable, _experimental_lowering_parameters=mlir.LoweringParameters( platforms=actual_lowering_platforms, )) - - lowering = lowered._lowering - _check_lowering(lowering) - mlir_module = lowering.stablehlo() - - args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) - if "mut" in lowering.compile_args: - if lowering.compile_args["mut"]: raise NotImplementedError - if "kept_var_idx" in lowering.compile_args: - module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) - else: - # For pmap - module_kept_var_idx = tuple(range(len(args_avals_flat))) - shape_poly_state = lowering.compile_args["shape_poly_state"] - if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) - or lowering.compile_args.get("ordered_effects", [])): - mlir_module = _wrap_main_func( - mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree, - has_platform_index_argument=shape_poly_state.has_platform_index_argument, - module_kept_var_idx=module_kept_var_idx, - serialization_version=version) - - with mlir_module.context: - mlir_module_attrs = mlir_module.operation.attributes - mlir_module_attrs["jax.uses_shape_polymorphism"] = ( - mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars)) - - mlir_module_serialized = _module_to_bytecode(mlir_module) - - # Figure out the result types and shapes 
- if "global_out_avals" in lowering.compile_args: - # This is currently the case for pjit - out_avals_flat = lowering.compile_args["global_out_avals"] - elif "shards" in lowering.compile_args: # for PmapComputation - out_avals_flat = lowering.compile_args["shards"].out_sharded_avals - else: - out_avals_flat = lowered.compile_args["out_avals"] - - # Log and then check the module. - if logging.vlog_is_on(3): - logmsg = (f"version={version} " - f"lowering_platforms={actual_lowering_platforms} " - f"disabled_checks={disabled_checks}") - logging.info("Lowered JAX module: %s\n", logmsg) - if dumped_to := mlir.dump_module_to_file(mlir_module, "export"): - logging.info("Dumped the exported MLIR module to %s", dumped_to) - - _check_module(mlir_module, - disabled_checks=disabled_checks) - - ordered_effects = tuple(lowering.compile_args["ordered_effects"]) - unordered_effects = tuple(lowering.compile_args["unordered_effects"]) - - nr_devices = len(lowering.compile_args["device_assignment"]) - def export_sharding(s: LoweringSharding, - aval: core.ShapedArray) -> Sharding: - if sharding_impls.is_unspecified(s): - return None - return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] - - all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], - module_kept_var_idx, - len(args_avals_flat)) - in_shardings = tuple( - export_sharding(s, aval) - for s, aval in zip(all_in_shardings, args_avals_flat)) - out_shardings = tuple( - export_sharding(s, aval) - for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat)) - return Exported( - fun_name=fun_name, - in_tree=lowered.in_tree, - out_tree=lowered.out_tree, - in_avals=tuple(args_avals_flat), - out_avals=tuple(out_avals_flat), - in_shardings=in_shardings, - out_shardings=out_shardings, - nr_devices=nr_devices, - lowering_platforms=actual_lowering_platforms, - ordered_effects=ordered_effects, - unordered_effects=unordered_effects, - disabled_safety_checks=tuple(disabled_checks), - mlir_module_serialized=mlir_module_serialized, - module_kept_var_idx=module_kept_var_idx, - uses_shape_polymorphism=shape_poly_state.uses_dim_vars, - mlir_module_serialization_version=version, - _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported, - lowering.compile_args["device_assignment"])) - + return _export_lowered(lowered, disabled_checks=disabled_checks) return do_export +def _export_lowered( + lowered: stages.Lowered, + disabled_checks: Sequence[DisabledSafetyCheck] = (), + ) -> Exported: + version = config.jax_serialization_version.value + if (version < minimum_supported_serialization_version or + version > maximum_supported_serialization_version): + raise ValueError( + f"The requested jax_serialization version {version} is outside the " + f"range of supported versions [{minimum_supported_serialization_version}" + f"..{maximum_supported_serialization_version}]") + + lowering = lowered._lowering + _check_lowering(lowering) + mlir_module = lowering.stablehlo() + + args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals) + if "mut" in lowering.compile_args: + if lowering.compile_args["mut"]: raise NotImplementedError + if "kept_var_idx" in lowering.compile_args: + module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"])) + else: + # For pmap + module_kept_var_idx = tuple(range(len(args_avals_flat))) + shape_poly_state = lowering.compile_args["shape_poly_state"] + if (not all(core.is_constant_shape(a.shape) for a in args_avals_flat) + or lowering.compile_args.get("ordered_effects", [])): + 
mlir_module = _wrap_main_func( + mlir_module, args_avals_flat, args_kwargs_tree=lowered.in_tree, + has_platform_index_argument=shape_poly_state.has_platform_index_argument, + module_kept_var_idx=module_kept_var_idx, + serialization_version=version) + + with mlir_module.context: + mlir_module_attrs = mlir_module.operation.attributes + mlir_module_attrs["jax.uses_shape_polymorphism"] = ( + mlir.ir.BoolAttr.get(shape_poly_state.uses_dim_vars)) + + mlir_module_serialized = _module_to_bytecode(mlir_module) + + # Figure out the result types and shapes + if "global_out_avals" in lowering.compile_args: + # This is currently the case for pjit + out_avals_flat = lowering.compile_args["global_out_avals"] + elif "shards" in lowering.compile_args: # for PmapComputation + out_avals_flat = lowering.compile_args["shards"].out_sharded_avals + else: + out_avals_flat = lowered.compile_args["out_avals"] + + # Log and then check the module. + if logging.vlog_is_on(3): + logmsg = (f"version={version} " + f"lowering_platforms={lowering.compile_args['platforms']} " + f"disabled_checks={disabled_checks}") + logging.info("Lowered JAX module: %s\n", logmsg) + if dumped_to := mlir.dump_module_to_file(mlir_module, "export"): + logging.info("Dumped the exported MLIR module to %s", dumped_to) + + _check_module(mlir_module, + disabled_checks=disabled_checks) + + ordered_effects = tuple(lowering.compile_args["ordered_effects"]) + unordered_effects = tuple(lowering.compile_args["unordered_effects"]) + + nr_devices = len(lowering.compile_args["device_assignment"]) + def export_sharding(s: LoweringSharding, + aval: core.ShapedArray) -> Sharding: + if sharding_impls.is_unspecified(s): + return None + return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr] + + all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"], + module_kept_var_idx, + len(args_avals_flat)) + in_shardings = tuple( + export_sharding(s, aval) + for s, aval in zip(all_in_shardings, args_avals_flat)) + out_shardings = tuple( + export_sharding(s, aval) + for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat)) + + device_assignment = lowering.compile_args["device_assignment"] + def _get_exported_vjp(exp_primal: Exported) -> Exported: + # Turn the primal jaxpr into a function, in preparation for exporting + # the VJP. 
Note that jaxpr_as_fun produces a function with flat arguments + assert(lowered._jaxpr is not None) # None only when the lowered was created outside JAX + fun_jax = core.jaxpr_as_fun(lowered._jaxpr) + + fun_vjp_jax, vjp_in_avals = _get_vjp_fun(fun_jax, + in_tree=exp_primal.in_tree, + in_avals=exp_primal.in_avals, + in_shardings=exp_primal.in_shardings, + out_avals=exp_primal.out_avals, + out_shardings=exp_primal.out_shardings, + device_assignment=device_assignment, + apply_jit=True, + flat_primal_fun=True) + return export(fun_vjp_jax, + lowering_platforms=exp_primal.lowering_platforms, + disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals) + + return Exported( + fun_name=lowered._fun_name, + in_tree=lowered.in_tree, + out_tree=lowered.out_tree, + in_avals=tuple(args_avals_flat), + out_avals=tuple(out_avals_flat), + in_shardings=in_shardings, + out_shardings=out_shardings, + nr_devices=nr_devices, + lowering_platforms=lowering._platforms, + ordered_effects=ordered_effects, + unordered_effects=unordered_effects, + disabled_safety_checks=tuple(disabled_checks), + mlir_module_serialized=mlir_module_serialized, + module_kept_var_idx=module_kept_var_idx, + uses_shape_polymorphism=shape_poly_state.uses_dim_vars, + mlir_module_serialization_version=version, + _get_vjp=_get_exported_vjp) def _module_to_bytecode(module: ir.Module) -> bytes: mlir_str = mlir.module_to_bytecode(module) @@ -713,7 +737,7 @@ def _check_lowering(lowering) -> None: # safe to add it to the allowed_compile_args if it does not change the semantics # or the calling convention of the lowered module. allowed_compile_args = [ - "backend", "mesh", "global_in_avals", + "backend", "platforms", "mesh", "global_in_avals", "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx", "mut", "spmd_lowering", "auto_spmd_lowering", "tuple_args", "ordered_effects", "unordered_effects", @@ -918,12 +942,15 @@ def _get_vjp_fun(primal_fun: Callable, *, in_shardings: tuple[Sharding, ...], out_shardings: tuple[Sharding, ...], device_assignment: Sequence[sharding_impls.Device] | None, - apply_jit: bool + apply_jit: bool, + flat_primal_fun: bool = False, ) -> tuple[Callable, Sequence[core.AbstractValue]]: # Since jax.vjp does not handle kwargs, it is easier to do all the work # here with flattened functions. # apply_jit=False is only used for backwards compatibility with the graph # graph serialization. When apply_jit=True, we must pass a device assignment. + # flat_primal_fun=False is used only from jax2tf, and it means that the + # `primal_fun` takes PyTree `*args` and `**kwargs`. def fun_vjp_jax(*args_and_out_cts_flat_jax): # Takes a flat list of primals and output cotangents def flattened_primal_fun_jax(*args_flat): @@ -934,7 +961,8 @@ def _get_vjp_fun(primal_fun: Callable, *, args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax, [len(in_avals)]) - _, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax) + _, pullback_jax = jax.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax, + *args_flat_jax) return pullback_jax(out_cts_flat_jax) vjp_in_avals = list( @@ -953,22 +981,6 @@ def _get_vjp_fun(primal_fun: Callable, *, else: return fun_vjp_jax, vjp_in_avals -def _export_native_vjp(primal_fun, - primal: Exported, - device_assignment: Sequence[sharding_impls.Device]) -> Exported: - # Export the VJP of `primal_fun_jax`. 
See documentation for Exported.vjp - fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun, - in_tree=primal.in_tree, - in_avals=primal.in_avals, - in_shardings=primal.in_shardings, - out_avals=primal.out_avals, - out_shardings=primal.out_shardings, - device_assignment=device_assignment, - apply_jit=True) - return export(fun_vjp_jax, - lowering_platforms=primal.lowering_platforms, - disabled_checks=primal.disabled_safety_checks)(*vjp_in_avals) - ### Calling the exported function def call(exported: Exported) -> Callable[..., jax.Array]:
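# --- Editor's note: illustrative end-to-end sketch, not part of the patch. ---
# With this change, export starts from a `Lowered` object: the public `export`
# wrapper lowers the function and hands the result to `_export_lowered` above,
# which reads the function name, the Jaxpr, and the lowering platforms from
# `lowered._fun_name`, `lowered._jaxpr`, and `lowered._lowering._platforms`.
# The import alias and the concrete example function below are assumptions for
# illustration; the exact public import path varies across JAX versions.

import jax
import jax.numpy as jnp
from jax.experimental import export as jax_export  # assumed import path

def f(x):
  return jnp.sin(x) * 2.0

x = jnp.ones((4,), dtype=jnp.float32)
exported = jax_export.export(f)(x)   # lowers f, then builds the Exported
print(exported.fun_name, exported.lowering_platforms)

# The Exported object can be called back from JAX and, thanks to the saved
# Jaxpr, also differentiated: its VJP is exported lazily via _get_exported_vjp.
y = jax_export.call(exported)(x)
dy = jax.grad(lambda v: jax_export.call(exported)(v).sum())(x)
# --- end note ---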