Add the function name, the Jaxpr, and lowering platforms to Lowered.
These changes ensure that `Lowered` carries all the information needed for export and serialization. They are in preparation for a cleanup of the export and serialization APIs, integrating them with the AOT APIs. In particular, exporting will start from a `Lowered` object and will no longer include its own lowering code.

We add the lowered function name and the Jaxpr (as the attributes `_fun_name` and `_jaxpr`) to `Lowered`, and we add the tuple of lowering platforms (as `Lowered._lowering._platforms`). The function name enables better error messages when exporting and serializing. The Jaxpr makes it possible to also export the VJP of the function and obtain an `Exported` that can be differentiated.
This commit is contained in:
parent 4fae9aa160
commit bb4c073574
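For orientation, a minimal sketch (not part of this commit) of how the new metadata could be inspected after AOT lowering. It assumes the private attributes named in the commit message (`Lowered._fun_name`, `Lowered._jaxpr`, `Lowered._lowering._platforms`); the function `f` is hypothetical and these internals may differ in other JAX versions.

# Hedged sketch: relies on private attributes added by this commit; internal
# details, not a stable API.
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * 2.0

lowered = jax.jit(f).lower(jnp.ones((4,), jnp.float32))
print(lowered._fun_name)             # lowered function name, used for error messages
print(type(lowered._jaxpr))          # core.ClosedJaxpr, usable to export the VJP
print(lowered._lowering._platforms)  # tuple of lowering platforms, e.g. ('cpu',)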
@@ -1846,7 +1846,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
        fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple,
        devices, backend, axis_size, args, kwargs)
    abstract_args = list(map(shaped_abstractify, p.flat_args))
-   computation = pxla.lower_parallel_callable(
+   computation, closed_jaxpr = pxla.lower_parallel_callable(
        p.flat_fun, backend, axis_name,
        axis_size=p.local_axis_size, global_axis_size=p.global_axis_size,
        devices=p.devices,
@@ -1858,7 +1858,8 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
        avals=abstract_args,
        lowering_parameters=lowering_parameters)
    return stages.Lowered.from_flat_info(
-       computation, p.in_tree, abstract_args, donate_tuple, p.out_tree())
+       computation, p.in_tree, abstract_args, donate_tuple, p.out_tree(),
+       fun_name=p.flat_fun.__name__, jaxpr=closed_jaxpr)

  return lower
@@ -556,7 +556,7 @@ def parallel_callable(fun: lu.WrappedFun,
                      donated_invars: Sequence[bool],
                      is_explicit_global_axis_size: bool,
                      *avals):
-  pmap_computation = lower_parallel_callable(
+  pmap_computation, _ = lower_parallel_callable(
      fun, backend_name, axis_name, axis_size, global_axis_size, devices, name,
      in_axes, out_axes_thunk, donated_invars,
      is_explicit_global_axis_size, avals,
@@ -679,7 +679,7 @@ def lower_parallel_callable(
    is_explicit_global_axis_size: bool,
    avals: Sequence[core.AbstractValue],
    *,
-   lowering_parameters: mlir.LoweringParameters) -> PmapComputation:
+   lowering_parameters: mlir.LoweringParameters) -> tuple[PmapComputation, core.ClosedJaxpr]:
  # Determine global_axis_size for use in AxisEnv.
  # TODO(mattjj,skyewm): revive this check (inner_pmap always False now)
  # if xb.process_count() > 1 and global_axis_size is None and inner_pmap:
@@ -761,6 +761,7 @@ def lower_parallel_callable(
  tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
                                          backend.platform)
  module_name = f"pmap_{fun.__name__}"
+ platforms = lowering_parameters.platforms or (backend.platform,)
  with maybe_extend_axis_env(axis_name, global_axis_size, None):
    ordered_effects = list(
        effects.ordered_effects.filter_in(closed_jaxpr.effects))
@@ -776,7 +777,7 @@ def lower_parallel_callable(
        closed_jaxpr,
        ordered_effects=ordered_effects,
        backend_or_name=backend,
-       platforms=lowering_parameters.platforms or (backend.platform,),
+       platforms=platforms,
        axis_context=sharding_impls.ReplicaAxisContext(axis_env),
        name_stack=name_stack,
        donated_args=donated_invars,
@@ -787,14 +788,16 @@ def lower_parallel_callable(
        result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
        num_replicas=replicas.num_global_replicas,
        lowering_parameters=lowering_parameters)
- return PmapComputation(lowering_result.module, pci=pci, replicas=replicas,
+ return PmapComputation(lowering_result.module,
+                        platforms=platforms,
+                        pci=pci, replicas=replicas,
                         shards=shards, tuple_args=tuple_args,
                         unordered_effects=unordered_effects,
                         ordered_effects=ordered_effects,
                         keepalive=lowering_result.keepalive,
                         host_callbacks=lowering_result.host_callbacks,
                         jaxpr_debug_info=closed_jaxpr.jaxpr.debug_info,
-                        shape_poly_state=lowering_result.shape_poly_state)
+                        shape_poly_state=lowering_result.shape_poly_state), closed_jaxpr


def _pmap_unmap_shaped_array(
@@ -907,10 +910,13 @@ class UnloadedPmapExecutable:
              host_callbacks: list[Any],
              keepalive: Any,
              jaxpr_debug_info: core.JaxprDebugInfo,
+             platforms: Sequence[str],
              shape_poly_state: mlir.ShapePolyLoweringState | None = None,
              compiler_options=None):
+   del platforms
    if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
      hlo = mlir.refine_polymorphic_shapes(hlo)

    devices = pci.devices
    if devices is None:
      if shards.num_global_shards > xb.device_count(pci.backend):
@@ -1941,7 +1947,6 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
          "The following ordered effects are not supported for "
          f"more than 1 device: {unsupported_effects}")
  ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
-
  with dispatch.log_elapsed_time(
      "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
      fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
@@ -2141,6 +2146,7 @@ def lower_sharding_computation(
                   for js, source_info in util.stable_unique(jaxpr_sharding))),
      devices_from_context)

+ platforms = lowering_parameters.platforms or (backend.platform,)
  # TODO(yashkatariya): Enable this when offload APIs are stable.
  # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr))

@@ -2204,6 +2210,7 @@ def lower_sharding_computation(
      kept_var_idx=kept_var_idx,
      mut=mut,
      backend=backend,
+     platforms=platforms,
      device_assignment=da_object,
      committed=committed,
      in_layouts=in_layouts,
@@ -2244,6 +2251,7 @@ def lower_mesh_computation(
    lowering_parameters: mlir.LoweringParameters) -> MeshComputation:
  assert not mesh.empty
  backend = xb.get_device_backend(mesh.devices.flat[0])
+ platforms = lowering_parameters.platforms or (backend.platform,)
  name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))

  global_axis_sizes = mesh.shape
@@ -2352,7 +2360,7 @@ def lower_mesh_computation(
        closed_jaxpr,
        ordered_effects=ordered_effects,
        backend_or_name=backend,
-       platforms=lowering_parameters.platforms or (backend.platform,),
+       platforms=platforms,
        axis_context=axis_ctx,
        name_stack=name_stack,
        donated_args=donated_invars,
@@ -2382,6 +2390,7 @@ def lower_mesh_computation(
      keepalive=lowering_result.keepalive,
      kept_var_idx=set(range(len(global_in_avals))),
      backend=backend,
+     platforms=platforms,
      device_assignment=_create_da_object(tuple(mesh.devices.flat)),
      committed=True,
      in_layouts=(None,) * len(global_in_avals),
@@ -2394,10 +2403,14 @@ class MeshComputation(stages.XlaLowering):
  _executable: MeshExecutable | None

  def __init__(self, name: str, hlo: ir.Module,
-              donated_invars: Sequence[bool], **compile_args):
+              donated_invars: Sequence[bool],
+              platforms: Sequence[str] | None = None,  # None only for backwards
+                                                       # compatibility with PartIR
+              **compile_args):
    self._name = name
    self._hlo = hlo
    self._donated_invars = donated_invars
+   self._platforms = platforms
    self.compile_args = compile_args
    self._executable = None
@@ -617,7 +617,7 @@ def xmap(fun: Callable,
        '_experimental_lowering_platform', mlir.LoweringParameters())
    fun_flat, args_flat, params, in_tree, out_tree = infer_params(*args)
    avals_flat = [shaped_abstractify(arg) for arg in args_flat]
-   computation = make_xmap_callable(
+   computation, jaxpr = make_xmap_callable(
        fun_flat, params['name'], params['in_axes'], params['out_axes_thunk'],
        params['donated_invars'], params['global_axis_sizes'], params['axis_resources'],
        params['resource_env'], params['backend'], params['spmd_in_axes'],
@@ -628,7 +628,7 @@ def xmap(fun: Callable,
    in_avals = in_tree.unflatten(avals_flat)
    return stages.Lowered.from_flat_info(
        computation, in_tree, in_avals, donate_argnums, out_tree(),
-       no_kwargs=True)
+       no_kwargs=True, fun_name=params['name'], jaxpr=jaxpr)

  fun_mapped.lower = lower
  return type_cast(stages.Wrapped, fun_mapped)
@@ -637,11 +637,12 @@ def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, donated_
              global_axis_sizes, axis_resources, resource_env, backend,
              spmd_in_axes, spmd_out_axes_thunk):
  in_avals = [core.raise_to_shaped(core.get_aval(arg)) for arg in args]
- xmap_callable = make_xmap_callable(
+ computation, _ = make_xmap_callable(
      fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes,
      axis_resources, resource_env, backend,
      spmd_in_axes, spmd_out_axes_thunk,
-     mlir.LoweringParameters(), *in_avals).compile().unsafe_call
+     mlir.LoweringParameters(), *in_avals)
+ xmap_callable = computation.compile().unsafe_call
  distributed_debug_log(("Running xmapped function", name),
                        ("python function", fun.f),
                        ("mesh", resource_env.physical_mesh),
@@ -708,7 +709,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
        in_shardings, out_shardings, donated_invars,
        use_spmd_lowering, in_avals,
        tiling_method=tiling_method,
-       lowering_parameters=lowering_parameters)
+       lowering_parameters=lowering_parameters), jaxpr
  else:
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals)
    return pxla.lower_sharding_computation(
@@ -716,7 +717,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
        (UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
        (None,) * len(in_avals), (None,) * len(out_avals),
        donated_invars, keep_unused=True, inline=False,
-       devices_from_context=None, lowering_parameters=lowering_parameters)
+       devices_from_context=None, lowering_parameters=lowering_parameters), jaxpr


class EvaluationPlan(NamedTuple):
@@ -469,7 +469,7 @@ def _make_jit_wrapper(jit_info: PjitInfo):
    donate_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
    return stages.Lowered.from_flat_info(
        lowering, in_tree, flat_global_in_avals, donate_argnums,
-       out_tree)
+       out_tree, fun_name=params["name"], jaxpr=params["jaxpr"])

  @api_boundary
  def eval_shape(*args, **kwargs):
@@ -601,23 +601,29 @@ class Lowered(Stage):
  querying properties of lowered computations across JAX's various
  lowering paths (:func:`~jax.jit`, :func:`~jax.pmap`, etc.).
  """
- __slots__ = ["args_info", "out_tree", "_lowering", "_no_kwargs"]
-
+ __slots__ = ["_lowering", "args_info", "out_tree", "_no_kwargs", "_fun_name", "_jaxpr"]
+ _lowering: XlaLowering
  args_info: Any  # PyTree of ArgInfo
  out_tree: tree_util.PyTreeDef
- _lowering: XlaLowering
  _no_kwargs: bool
+ _fun_name: str
+ _jaxpr: core.ClosedJaxpr | None  # Can be None when this class is constructed
+                                  # outside of JAX core.

  def __init__(
      self,
      lowering: XlaLowering,
      args_info,  # PyTree of ArgInfo
      out_tree: tree_util.PyTreeDef,
-     no_kwargs: bool = False):
+     no_kwargs: bool = False,
+     fun_name: str = "unknown",
+     jaxpr: core.ClosedJaxpr | None = None):
    self._lowering = lowering
    self._no_kwargs = no_kwargs
    self.args_info = args_info
    self.out_tree = out_tree
+   self._fun_name = fun_name
+   self._jaxpr = jaxpr

  @classmethod
  def from_flat_info(cls,
@@ -626,7 +632,9 @@ class Lowered(Stage):
                     in_avals,
                     donate_argnums: tuple[int, ...],
                     out_tree: tree_util.PyTreeDef,
-                    no_kwargs: bool = False):
+                    no_kwargs: bool = False,
+                    fun_name: str = "unknown",
+                    jaxpr: core.ClosedJaxpr | None = None):
    """Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.

    Args:
@@ -635,12 +643,14 @@ class Lowered(Stage):
      no_kwargs: If ``True`` the transformation, and the
        ``Compiled`` returned from this object will not support keyword
        arguments (an error will be raised if some are provided).
+     fun_name: the name of the lowered function, if available.
+     jaxpr: the Jaxpr of the lowered function, if available.
    """
    return cls(
        lowering,
        make_args_info(in_tree, in_avals, donate_argnums),
        out_tree,
-       no_kwargs=no_kwargs)
+       no_kwargs=no_kwargs, fun_name=fun_name, jaxpr=jaxpr)

  def compile(
      self, compiler_options: CompilerOptions | None = None) -> Compiled:
@@ -47,6 +47,7 @@ from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src import pjit
from jax._src import sharding_impls
from jax._src import source_info_util
+from jax._src import stages
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
@@ -374,14 +375,6 @@ def export(fun_jax: Callable,
        def f_jax(*args, **kwargs): ...
        exported = jax_export.export(f_jax)(*args, **kwargs)
  """
- fun_name = getattr(fun_jax, "__name__", "unknown")
- version = config.jax_serialization_version.value
- if (version < minimum_supported_serialization_version or
-     version > maximum_supported_serialization_version):
-   raise ValueError(
-     f"The requested jax_serialization version {version} is outside the "
-     f"range of supported versions [{minimum_supported_serialization_version}"
-     f"..{maximum_supported_serialization_version}]")

  def do_export(*args_specs, **kwargs_specs) -> Exported:
    if not hasattr(fun_jax, "lower"):
@@ -402,7 +395,7 @@ def export(fun_jax: Callable,
    # TODO: move to `lower`
    symbolic_scope: tuple[_shape_poly.SymbolicScope, tree_util.KeyPath] | None = None
    for k_path, aval in tree_util.tree_flatten_with_path((args_specs, kwargs_specs))[0]:
-     # Static args may has no `shape` attribute.
+     # Static args may have no `shape` attribute.
      if not hasattr(aval, "shape"):
        continue
      for d in aval.shape:
@@ -411,7 +404,7 @@ def export(fun_jax: Callable,
          symbolic_scope = (d.scope, k_path)
          continue
        symbolic_scope[0]._check_same_scope(
-           d, when=f"when exporting {fun_name}",
+           d, when=f"when exporting {getattr(wrapped_fun_jax, '__name__')}",
            self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
            other_descr=_shape_poly.args_kwargs_path_to_str(k_path))
@@ -420,6 +413,20 @@ def export(fun_jax: Callable,
        _experimental_lowering_parameters=mlir.LoweringParameters(
            platforms=actual_lowering_platforms,
        ))
+   return _export_lowered(lowered, disabled_checks=disabled_checks)
+ return do_export
+
+def _export_lowered(
+    lowered: stages.Lowered,
+    disabled_checks: Sequence[DisabledSafetyCheck] = (),
+) -> Exported:
+ version = config.jax_serialization_version.value
+ if (version < minimum_supported_serialization_version or
+     version > maximum_supported_serialization_version):
+   raise ValueError(
+     f"The requested jax_serialization version {version} is outside the "
+     f"range of supported versions [{minimum_supported_serialization_version}"
+     f"..{maximum_supported_serialization_version}]")

  lowering = lowered._lowering
  _check_lowering(lowering)
@@ -461,7 +468,7 @@ def export(fun_jax: Callable,
  # Log and then check the module.
  if logging.vlog_is_on(3):
    logmsg = (f"version={version} "
-             f"lowering_platforms={actual_lowering_platforms} "
+             f"lowering_platforms={lowering.compile_args['platforms']} "
              f"disabled_checks={disabled_checks}")
    logging.info("Lowered JAX module: %s\n", logmsg)
  if dumped_to := mlir.dump_module_to_file(mlir_module, "export"):
@@ -489,8 +496,29 @@ def export(fun_jax: Callable,
  out_shardings = tuple(
      export_sharding(s, aval)
      for s, aval in zip(lowering.compile_args["out_shardings"], out_avals_flat))

+ device_assignment = lowering.compile_args["device_assignment"]
+ def _get_exported_vjp(exp_primal: Exported) -> Exported:
+   # Turn the primal jaxpr into a function, in preparation for exporting
+   # the VJP. Note that jaxpr_as_fun produces a function with flat arguments
+   assert(lowered._jaxpr is not None)  # None only when the lowered was created outside JAX
+   fun_jax = core.jaxpr_as_fun(lowered._jaxpr)
+
+   fun_vjp_jax, vjp_in_avals = _get_vjp_fun(fun_jax,
+                                            in_tree=exp_primal.in_tree,
+                                            in_avals=exp_primal.in_avals,
+                                            in_shardings=exp_primal.in_shardings,
+                                            out_avals=exp_primal.out_avals,
+                                            out_shardings=exp_primal.out_shardings,
+                                            device_assignment=device_assignment,
+                                            apply_jit=True,
+                                            flat_primal_fun=True)
+   return export(fun_vjp_jax,
+                 lowering_platforms=exp_primal.lowering_platforms,
+                 disabled_checks=exp_primal.disabled_safety_checks)(*vjp_in_avals)
+
  return Exported(
-     fun_name=fun_name,
+     fun_name=lowered._fun_name,
      in_tree=lowered.in_tree,
      out_tree=lowered.out_tree,
      in_avals=tuple(args_avals_flat),
@@ -498,7 +526,7 @@ def export(fun_jax: Callable,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      nr_devices=nr_devices,
-     lowering_platforms=actual_lowering_platforms,
+     lowering_platforms=lowering._platforms,
      ordered_effects=ordered_effects,
      unordered_effects=unordered_effects,
      disabled_safety_checks=tuple(disabled_checks),
@@ -506,11 +534,7 @@ def export(fun_jax: Callable,
      module_kept_var_idx=module_kept_var_idx,
      uses_shape_polymorphism=shape_poly_state.uses_dim_vars,
      mlir_module_serialization_version=version,
-     _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported,
-                                                  lowering.compile_args["device_assignment"]))
-
- return do_export
+     _get_vjp=_get_exported_vjp)


def _module_to_bytecode(module: ir.Module) -> bytes:
  mlir_str = mlir.module_to_bytecode(module)
@@ -713,7 +737,7 @@ def _check_lowering(lowering) -> None:
  # safe to add it to the allowed_compile_args if it does not change the semantics
  # or the calling convention of the lowered module.
  allowed_compile_args = [
-     "backend", "mesh", "global_in_avals",
+     "backend", "platforms", "mesh", "global_in_avals",
      "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
      "mut", "spmd_lowering", "auto_spmd_lowering",
      "tuple_args", "ordered_effects", "unordered_effects",
@@ -918,12 +942,15 @@ def _get_vjp_fun(primal_fun: Callable, *,
                 in_shardings: tuple[Sharding, ...],
                 out_shardings: tuple[Sharding, ...],
                 device_assignment: Sequence[sharding_impls.Device] | None,
-                apply_jit: bool
+                apply_jit: bool,
+                flat_primal_fun: bool = False,
                 ) -> tuple[Callable, Sequence[core.AbstractValue]]:
  # Since jax.vjp does not handle kwargs, it is easier to do all the work
  # here with flattened functions.
  # apply_jit=False is only used for backwards compatibility with the graph
  # graph serialization. When apply_jit=True, we must pass a device assignment.
+ # flat_primal_fun=False is used only from jax2tf, and it means that the
+ # `primal_fun` takes PyTree `*args` and `**kwargs`.
  def fun_vjp_jax(*args_and_out_cts_flat_jax):
    # Takes a flat list of primals and output cotangents
    def flattened_primal_fun_jax(*args_flat):
@@ -934,7 +961,8 @@ def _get_vjp_fun(primal_fun: Callable, *,

    args_flat_jax, out_cts_flat_jax = util.split_list(args_and_out_cts_flat_jax,
                                                      [len(in_avals)])
-   _, pullback_jax = jax.vjp(flattened_primal_fun_jax, *args_flat_jax)
+   _, pullback_jax = jax.vjp(primal_fun if flat_primal_fun else flattened_primal_fun_jax,
+                             *args_flat_jax)
    return pullback_jax(out_cts_flat_jax)

  vjp_in_avals = list(
@@ -953,22 +981,6 @@ def _get_vjp_fun(primal_fun: Callable, *,
  else:
    return fun_vjp_jax, vjp_in_avals

-def _export_native_vjp(primal_fun,
-                       primal: Exported,
-                       device_assignment: Sequence[sharding_impls.Device]) -> Exported:
-  # Export the VJP of `primal_fun_jax`. See documentation for Exported.vjp
-  fun_vjp_jax, vjp_in_avals = _get_vjp_fun(primal_fun,
-                                           in_tree=primal.in_tree,
-                                           in_avals=primal.in_avals,
-                                           in_shardings=primal.in_shardings,
-                                           out_avals=primal.out_avals,
-                                           out_shardings=primal.out_shardings,
-                                           device_assignment=device_assignment,
-                                           apply_jit=True)
-  return export(fun_vjp_jax,
-                lowering_platforms=primal.lowering_platforms,
-                disabled_checks=primal.disabled_safety_checks)(*vjp_in_avals)
-

### Calling the exported function

def call(exported: Exported) -> Callable[..., jax.Array]:
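For context, a rough usage sketch of the export flow this diff prepares for. It follows the call pattern shown in the `export` docstring above (`jax_export.export(f_jax)(*args)`); the import path and the `Exported.vjp()` call are assumptions based on the JAX experimental export API of this era and may differ in other versions.

# Hedged sketch: export a function, then export its VJP. After this commit the
# VJP export is driven by the primal's Jaxpr recorded on Lowered (_jaxpr).
import jax.numpy as jnp
from jax.experimental.export import export as jax_export  # assumed module path

def f_jax(x):
  return jnp.sin(x)

exported = jax_export.export(f_jax)(jnp.ones((4,), jnp.float32))
exported_vjp = exported.vjp()  # assumed method; internally uses _get_exported_vjp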