Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00.
Add a direct HLO lowering of remat_p that doesn't call eval_jaxpr.

This turns out to be faster, not least because we don't need to use the tracing machinery again.

PiperOrigin-RevId: 647462045
This commit is contained in:
commit 61703690ee (parent af8bdd1582)
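As context for the diff below, here is a minimal sketch (not part of this commit) showing where the remat_p lowering fires. The remat_p primitive appears in the jaxpr once a jax.checkpoint'd function is differentiated, so lowering its gradient exercises the rule changed here:

import jax
import jax.numpy as jnp

@jax.checkpoint
def f(x):
  return jnp.sin(jnp.sin(x))

# Lowering the gradient hits the remat_p lowering rule. With the default
# remat_opt_barrier setting, this commit makes it emit an optimization
# barrier and inline the jaxpr directly, instead of re-tracing the jaxpr
# through mlir.lower_fun / eval_jaxpr.
print(jax.jit(jax.grad(f)).lower(1.0).as_text())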
jax/_src/ad_checkpoint.py
@@ -40,6 +40,7 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import convolution as lax_convolution
+from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.traceback_util import api_boundary
 from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
@@ -695,9 +696,9 @@ def _has_effects(effects) -> bool:
   return bool({e for e in effects if not isinstance(e, core.NamedAxisEffect)})
 
 
-def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
-                   differentiated: bool, is_gpu_platform: bool = False,
-                   **_):
+def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
+                    differentiated: bool, is_gpu_platform: bool = False,
+                    **_):
   assert not jaxpr.constvars
 
   if differentiated and prevent_cse:
@@ -766,13 +767,33 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
   unif = lax_internal.rng_uniform(np.float32(0), np.float32(1), shape=())
   return lax_control_flow.cond(unif < np.float32(2), remat_comp, dummy_comp, *args)
 
-mlir.register_lowering(
-    remat_p, mlir.lower_fun(remat_lowering, multiple_results=True))
-mlir.register_lowering(
-    remat_p,
-    mlir.lower_fun(partial(remat_lowering, is_gpu_platform=True),
-                   multiple_results=True),
-    platform="gpu")
+def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
+                    differentiated: bool, policy, is_gpu_platform=False):
+  jaxpr_args: Sequence[Sequence[ir.Value]]
+  if differentiated and prevent_cse:
+    # If we're using the loop or cond lowerings, use the slower lower_fun
+    # based path.
+    if not config.remat_opt_barrier.value:
+      return mlir.lower_fun(remat_expansion, multiple_results=True)(
+          ctx, *args, jaxpr=jaxpr, prevent_cse=prevent_cse,
+          differentiated=differentiated, policy=policy,
+          is_gpu_platform=is_gpu_platform)
+
+    arg_types = map(mlir.aval_to_ir_types, ctx.avals_in)
+    flat_args = mlir.flatten_lowering_ir_args(args)
+    barrier_op = hlo.OptimizationBarrierOp(flat_args)
+    jaxpr_args = util.unflatten(barrier_op.results, map(len, arg_types))
+  else:
+    jaxpr_args = map(mlir.wrap_singleton_ir_values, args)
+  outs, tokens_out = mlir.jaxpr_subcomp(
+      ctx.module_context, jaxpr, ctx.name_stack.extend('checkpoint'),
+      ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values)
+  ctx.set_tokens_out(tokens_out)
+  return outs
+
+mlir.register_lowering(remat_p, _remat_lowering)
+mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True),
+                       platform="gpu")
 
 def _optimization_barrier_abstract_eval(*args):
   return args
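To make the two registration styles above concrete, here is a hedged, self-contained sketch, not from this commit; the primitive `double_p`, its rules, and the use of the private `hlo` dialect bindings are purely illustrative. `mlir.lower_fun` traces a Python implementation to a jaxpr at lowering time, whereas a direct rule like `_remat_lowering` receives MLIR values and returns MLIR values with no tracing involved:

import jax
from jax import core
from jax.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo  # private API, illustration only

double_p = core.Primitive("double")  # hypothetical primitive
double_p.def_impl(lambda x: x + x)
double_p.def_abstract_eval(lambda x: x)

# Style 1: lower_fun-based lowering (what remat_p used before this commit).
# It traces the Python function to a jaxpr every time the op is lowered.
mlir.register_lowering(
    double_p, mlir.lower_fun(lambda x: x + x, multiple_results=False))

# Style 2: a direct lowering rule (the approach this commit adopts for
# remat_p). It emits StableHLO ops itself, skipping the tracing machinery.
def _double_lowering(ctx, x):
  return [hlo.AddOp(x, x).result]

mlir.register_lowering(double_p, _double_lowering)  # overrides style 1

print(jax.jit(double_p.bind)(2.0))  # 4.0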
jax/experimental/jax2tf/jax2tf.py
@@ -3132,7 +3132,7 @@ tf_impl_with_avals[lax.scan_p] = _convert_jax_impl(
     extra_name_stack="scan")
 
 tf_impl_with_avals[ad_checkpoint.remat_p] = \
-  _convert_jax_impl(partial(ad_checkpoint.remat_lowering,
+  _convert_jax_impl(partial(ad_checkpoint.remat_expansion,
                             # TODO: jax2tf cannot discriminate by platform
                             is_gpu_platform=False),
                     multiple_results=True,