Make optimization_barrier a public lax API.

This commit is contained in:
Peter Hawkins 2024-09-05 19:49:12 +00:00
parent 65b1b0bd95
commit 9c86fdec02
7 changed files with 71 additions and 24 deletions

View File

@ -33,6 +33,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
{obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
{obj}`~jax.numpy.logical_or`, and {obj}`~jax.numpy.logical_xor`.
In future releases we plan to expand these to other ufuncs.
* Added {func}`jax.lax.optimization_barrier`, which allows users to prevent
compiler optimizations such as common-subexpression elimination and to
control scheduling.
* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the

View File

@ -119,6 +119,7 @@ Operators
ne
neg
nextafter
optimization_barrier
pad
platform_dependent
polygamma

View File

@ -27,7 +27,6 @@ from jax._src import ad_util
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import effects
@ -755,7 +754,7 @@ def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr)
def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
# NOTE(review): the two assignments below look like the before/after forms of
# a single line in this diff hunk (the hunk header reports 7 lines on each
# side). The first calls the module-private `_optimization_barrier`; the second
# calls `lax_internal.optimization_barrier`.
args = _optimization_barrier(args)
args = lax_internal.optimization_barrier(args)
# Evaluate `jaxpr` on the barrier's outputs, passing an empty tuple as the
# second argument to `core.eval_jaxpr`.
return core.eval_jaxpr(jaxpr, (), *args)
# TODO(mattjj): add core utility for 'create dummy value for this type'?
@ -837,27 +836,6 @@ mlir.register_lowering(remat_p, _remat_lowering)
mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True),
platform="gpu")
# NOTE(review): this run appears to be the deleted side of the hunk (its header
# reports 27 old lines shrinking to 6); matching definitions are added in the
# lax.py hunk of this same commit.
# Abstract evaluation: the outputs are exactly the inputs.
def _optimization_barrier_abstract_eval(*args):
return args
# Lowering: flatten the operand IR values, emit one hlo.OptimizationBarrierOp
# over them, and regroup its results against the input IR types.
def _optimization_barrier_lowering_rule(ctx, *args):
barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in)
flat_args = mlir.flatten_ir_values(args)
barrier_op = hlo.OptimizationBarrierOp(flat_args)
return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types)
# Pytree wrapper: flatten `arg`, bind the primitive on the leaves, and rebuild
# the original tree structure around the results.
def _optimization_barrier(arg):
flat_args, treedef = tree_flatten(arg)
return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))
# Primitive definition and rule registration.
optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
partial(dispatch.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
_optimization_barrier_lowering_rule)
def checkpoint_name(x, name):
  """Bind ``name_p`` to ``x``, passing ``name`` as the primitive's parameter.

  Args:
    x: the operand forwarded positionally to ``name_p.bind``.
    name: forwarded to the bind as the ``name`` keyword parameter.

  Returns:
    Whatever ``name_p.bind(x, name=name)`` returns.
  """
  return name_p.bind(x, name=name)
@ -936,3 +914,6 @@ def checkpoint_wrapper(
raise NotImplementedError(msg)
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)
# TODO(phawkins): update users to refer to the public name.
# Alias kept so references to the old private name still resolve; it points at
# the implementation exposed through `lax_internal`.
_optimization_barrier = lax_internal.optimization_barrier

View File

@ -33,4 +33,4 @@ from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
_initial_style_jaxprs_with_common_consts,
_check_tree_and_avals)
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
from jax._src.ad_checkpoint import optimization_barrier_p
from jax._src.lax.lax import optimization_barrier_p

View File

@ -5346,3 +5346,59 @@ class BIntRules:
core.bint._rules = BIntRules
def optimization_barrier(operand, /):
  """Prevents the compiler from moving operations across the barrier.

  Optimization barriers have a number of possible uses:

  * An optimization barrier ensures that all inputs are evaluated before any
    operators that depend on the barrier's outputs. This can be used to enforce
    a particular order of operations.
  * An optimization barrier prevents common subexpression elimination. This is
    used by JAX to implement rematerialization.
  * Optimization barriers prevent compiler fusions. That is, operations before
    the barrier may not be fused into the same kernel as operations after the
    barrier by the compiler.

  JAX does not define derivative or batching rules for an optimization barrier.

  Optimization barriers have no effect outside a compiled function.

  Args:
    operand: a pytree of JAX values.

  Returns:
    A pytree of JAX values, with the same structure and contents as ``operand``.

  Examples:
    Prevents common-subexpression elimination between the two calls to `sin`:

    >>> def f(x):
    ...   return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
    >>> jax.jit(f)(0.)
    Array(0., dtype=float32, weak_type=True)
  """
  # The primitive is bound on a flat sequence of leaves, so flatten the pytree
  # first and rebuild the same structure around the results afterwards.
  flat_args, treedef = tree_util.tree_flatten(operand)
  return tree_util.tree_unflatten(
      treedef, optimization_barrier_p.bind(*flat_args))
def _optimization_barrier_abstract_eval(*args):
  """Abstract-evaluation rule for the barrier: outputs mirror the inputs.

  Args:
    *args: the values supplied for each operand of the primitive.

  Returns:
    The ``args`` tuple itself, unchanged, so there is one output per input.
  """
  return args
def _optimization_barrier_lowering_rule(ctx, *args):
  """Lower the barrier primitive to a single ``hlo.OptimizationBarrierOp``.

  Args:
    ctx: the lowering context; only ``ctx.avals_in`` is read here.
    *args: the IR values for the primitive's operands.

  Returns:
    The op's results, regrouped by ``mlir.unflatten_ir_values_like_types``
    against the IR types derived from ``ctx.avals_in``.
  """
  # NOTE(review): if `map` is the builtin here this is a one-shot iterator; it
  # is consumed exactly once below. Confirm whether the module rebinds `map`.
  barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in)
  # The op takes a flat list of operands, so flatten before building it.
  flat_args = mlir.flatten_ir_values(args)
  barrier_op = hlo.OptimizationBarrierOp(flat_args)
  return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types)
# Primitive bound by `optimization_barrier`, followed by its rule registrations.
optimization_barrier_p = core.Primitive('optimization_barrier')
# The abstract-eval rule returns one output per operand, so binding yields a
# sequence of results rather than a single value.
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
    partial(dispatch.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(
    optimization_barrier_p, _optimization_barrier_lowering_rule)

View File

@ -142,6 +142,8 @@ from jax._src.lax.lax import (
nextafter as nextafter,
nextafter_p as nextafter_p,
not_p as not_p,
optimization_barrier as optimization_barrier,
optimization_barrier_p as optimization_barrier_p,
or_p as or_p,
outfeed as outfeed,
outfeed_p as outfeed_p,

View File

@ -3093,6 +3093,10 @@ class LaxTest(jtu.JaxTestCase):
with jax.transfer_guard('disallow'):
jax.jit(asarray_closure)()
def testOptimizationBarrier(self):
  """The barrier returns its operand pytree unchanged, eagerly and under jit."""
  x = lax.optimization_barrier((2, 3))
  self.assertEqual((2, 3), x)
  # The barrier only has an effect inside a compiled function, so exercise the
  # jitted path as well as the eager one.
  y = jax.jit(lax.optimization_barrier)((2, 3))
  self.assertEqual((2, 3), y)
class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):