Add a direct HLO lowering of remat_p that doesn't call eval_jaxpr.

This turns out to be faster, not least because we don't need to use the tracing machinery again.

PiperOrigin-RevId: 647462045
This commit is contained in:
Peter Hawkins 2024-06-27 15:14:20 -07:00 committed by jax authors
parent af8bdd1582
commit 61703690ee
2 changed files with 32 additions and 11 deletions

View File

@@ -40,6 +40,7 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
@@ -695,9 +696,9 @@ def _has_effects(effects) -> bool:
return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)})
def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
differentiated: bool, is_gpu_platform: bool = False,
**_):
def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
differentiated: bool, is_gpu_platform: bool = False,
**_):
assert not jaxpr.constvars
if differentiated and prevent_cse:
@@ -766,13 +767,33 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=())
return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)
mlir.register_lowering(
remat_p, mlir.lower_fun(remat_lowering, multiple_results=True))
mlir.register_lowering(
remat_p,
mlir.lower_fun(partial(remat_lowering, is_gpu_platform=True),
multiple_results=True),
platform="gpu")
def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
                    differentiated: bool, policy, is_gpu_platform=False):
  """Lowers remat_p directly to HLO, without re-tracing via eval_jaxpr.

  When CSE must be prevented after differentiation, the jaxpr's operands are
  routed through an hlo.OptimizationBarrierOp (if enabled by
  ``config.remat_opt_barrier``); otherwise we fall back to the slower
  ``lower_fun``-based expansion, which uses loop/cond tricks instead.
  """
  needs_cse_barrier = differentiated and prevent_cse

  if needs_cse_barrier and not config.remat_opt_barrier.value:
    # Loop/cond-based CSE prevention: take the slower lower_fun path that
    # re-traces the expansion.
    fallback = mlir.lower_fun(remat_expansion, multiple_results=True)
    return fallback(ctx, *args, jaxpr=jaxpr, prevent_cse=prevent_cse,
                    differentiated=differentiated, policy=policy,
                    is_gpu_platform=is_gpu_platform)

  operands: Sequence[Sequence[ir.Value]]
  if needs_cse_barrier:
    # Thread every flattened argument through one optimization barrier, then
    # regroup the barrier's results to match the per-aval structure.
    ir_types = map(mlir.aval_to_ir_types, ctx.avals_in)
    barrier = hlo.OptimizationBarrierOp(mlir.flatten_lowering_ir_args(args))
    operands = util.unflatten(barrier.results, map(len, ir_types))
  else:
    operands = map(mlir.wrap_singleton_ir_values, args)

  outs, tokens = mlir.jaxpr_subcomp(
      ctx.module_context, jaxpr, ctx.name_stack.extend('checkpoint'),
      ctx.tokens_in, (), *operands, dim_var_values=ctx.dim_var_values)
  ctx.set_tokens_out(tokens)
  return outs
# Register the direct HLO lowering everywhere; the GPU registration threads
# is_gpu_platform=True through to the remat_expansion fallback path.
mlir.register_lowering(remat_p, _remat_lowering)
mlir.register_lowering(
    remat_p,
    partial(_remat_lowering, is_gpu_platform=True),
    platform="gpu")
def _optimization_barrier_abstract_eval(*args):
return args

View File

@@ -3132,7 +3132,7 @@ tf_impl_with_avals[lax.scan_p] = _convert_jax_impl(
extra_name_stack="scan")
tf_impl_with_avals[ad_checkpoint.remat_p] = \
_convert_jax_impl(partial(ad_checkpoint.remat_lowering,
_convert_jax_impl(partial(ad_checkpoint.remat_expansion,
# TODO: jax2tf cannot discriminate by platform
is_gpu_platform=False),
multiple_results=True,