diff --git a/CHANGELOG.md b/CHANGELOG.md
index e310b296b..93745d9d3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -33,6 +33,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     {obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
     {obj}`~jax.numpy.logical_or`, and {obj}`~jax.numpy.logical_xor`.
     In future releases we plan to expand these to other ufuncs.
+  * Added {func}`jax.lax.optimization_barrier`, which allows users to prevent
+    compiler optimizations such as common-subexpression elimination and to
+    control scheduling.
 
 * Breaking changes
   * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst
index 7b19955d3..e0fc5ad46 100644
--- a/docs/jax.lax.rst
+++ b/docs/jax.lax.rst
@@ -119,6 +119,7 @@ Operators
     ne
     neg
     nextafter
+    optimization_barrier
     pad
     platform_dependent
     polygamma
diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py
index 6bf481ef0..fd3011988 100644
--- a/jax/_src/ad_checkpoint.py
+++ b/jax/_src/ad_checkpoint.py
@@ -27,7 +27,6 @@ from jax._src import ad_util
 from jax._src import api
 from jax._src import config
 from jax._src import core
-from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import linear_util as lu
 from jax._src import effects
@@ -755,7 +754,7 @@ def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
   return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr)
 
 def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
-  args = _optimization_barrier(args)
+  args = lax_internal.optimization_barrier(args)
   return core.eval_jaxpr(jaxpr, (), *args)
 
 # TODO(mattjj): add core utility for 'create dummy value for this type'?
@@ -837,27 +836,6 @@ mlir.register_lowering(remat_p, _remat_lowering)
 mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True),
                        platform="gpu")
 
-def _optimization_barrier_abstract_eval(*args):
-  return args
-
-def _optimization_barrier_lowering_rule(ctx, *args):
-  barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in)
-  flat_args = mlir.flatten_ir_values(args)
-  barrier_op = hlo.OptimizationBarrierOp(flat_args)
-  return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types)
-
-def _optimization_barrier(arg):
-  flat_args, treedef = tree_flatten(arg)
-  return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))
-
-optimization_barrier_p = core.Primitive('optimization_barrier')
-optimization_barrier_p.multiple_results = True
-optimization_barrier_p.def_impl(
-    partial(dispatch.apply_primitive, optimization_barrier_p))
-optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
-mlir.register_lowering(optimization_barrier_p,
-                       _optimization_barrier_lowering_rule)
-
 def checkpoint_name(x, name):
   return name_p.bind(x, name=name)
 
@@ -936,3 +914,6 @@ def checkpoint_wrapper(
     raise NotImplementedError(msg)
   return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
                     static_argnums=static_argnums)
+
+# TODO(phawkins): update users to refer to the public name.
+_optimization_barrier = lax_internal.optimization_barrier
diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py
index 05dcade84..5e6fa86f7 100644
--- a/jax/_src/lax/control_flow/__init__.py
+++ b/jax/_src/lax/control_flow/__init__.py
@@ -33,4 +33,4 @@ from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
                                               _initial_style_jaxprs_with_common_consts,
                                               _check_tree_and_avals)
 # TODO(mattjj): fix dependent library which expects optimization_barrier_p here
-from jax._src.ad_checkpoint import optimization_barrier_p
+from jax._src.lax.lax import optimization_barrier_p
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 2186e767a..8d2c24d6e 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -5346,3 +5346,59 @@ class BIntRules:
 
 
 core.bint._rules = BIntRules
+
+
+def optimization_barrier(operand, /):
+  """Prevents the compiler from moving operations across the barrier.
+
+  Optimization barriers have a number of possible uses:
+
+  * An optimization barrier ensures that all inputs are evaluated before any
+    operators that depend on the barrier's outputs. This can be used to enforce
+    a particular order of operations.
+  * An optimization barrier prevents common subexpression elimination. This is
+    used by JAX to implement rematerialization.
+  * Optimization barriers prevent compiler fusions. That is, operations before
+    the barrier may not be fused into the same kernel as operations after the
+    barrier by the compiler.
+
+  JAX does not define derivative or batching rules for an optimization barrier.
+
+  Optimization barriers have no effect outside a compiled function.
+
+  Args:
+    operand: a pytree of JAX values.
+
+  Returns:
+    A pytree of JAX values, with the same structure and contents as ``operand``.
+
+  Examples:
+    Prevents common-subexpression elimination between the two calls to `sin`:
+
+    >>> def f(x):
+    ...   return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
+    >>> jax.jit(f)(0.)
+    Array(0., dtype=float32, weak_type=True)
+  """
+  flat_args, treedef = tree_util.tree_flatten(operand)
+  return tree_util.tree_unflatten(
+      treedef, optimization_barrier_p.bind(*flat_args))
+
+
+def _optimization_barrier_abstract_eval(*args):
+  return args
+
+def _optimization_barrier_lowering_rule(ctx, *args):
+  barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in)
+  flat_args = mlir.flatten_ir_values(args)
+  barrier_op = hlo.OptimizationBarrierOp(flat_args)
+  return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types)
+
+
+optimization_barrier_p = core.Primitive('optimization_barrier')
+optimization_barrier_p.multiple_results = True
+optimization_barrier_p.def_impl(
+    partial(dispatch.apply_primitive, optimization_barrier_p))
+optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
+mlir.register_lowering(optimization_barrier_p,
+                       _optimization_barrier_lowering_rule)
diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py
index bac005b81..e2bcd5de9 100644
--- a/jax/lax/__init__.py
+++ b/jax/lax/__init__.py
@@ -142,6 +142,8 @@ from jax._src.lax.lax import (
   nextafter as nextafter,
   nextafter_p as nextafter_p,
   not_p as not_p,
+  optimization_barrier as optimization_barrier,
+  optimization_barrier_p as optimization_barrier_p,
   or_p as or_p,
   outfeed as outfeed,
   outfeed_p as outfeed_p,
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 7ed17adf4..ce3013195 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -3093,6 +3093,10 @@ class LaxTest(jtu.JaxTestCase):
     with jax.transfer_guard('disallow'):
       jax.jit(asarray_closure)()
 
+  def testOptimizationBarrier(self):
+    x = lax.optimization_barrier((2, 3))
+    self.assertEqual((2, 3), x)
+
 class LazyConstantTest(jtu.JaxTestCase):
 
   def _Check(self, make_const, expected):
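Usage sketch (illustrative only, not part of the patch): the snippet below exercises the new public `jax.lax.optimization_barrier` entry point the same way the docstring example and the new `testOptimizationBarrier` case do, assuming a JAX build that includes this change.

```python
import jax

def f(x):
  # Wrapping the first sin(x) in an optimization barrier keeps the compiler
  # from merging the two sin(x) computations via common-subexpression
  # elimination, and from moving other work across the barrier.
  y = jax.lax.optimization_barrier(jax.lax.sin(x))
  return y + jax.lax.sin(x)

print(jax.jit(f)(0.))  # Array(0., dtype=float32, weak_type=True)

# The operand may be any pytree; the result has the same structure and
# contents, mirroring the assertion in the new testOptimizationBarrier case.
a, b = jax.lax.optimization_barrier((2, 3))
print(a, b)  # 2 3
```

Note that the barrier is the same primitive `remat` already used internally; this change only relocates it to `jax.lax`, documents it, and leaves a `_optimization_barrier` alias in `ad_checkpoint.py` until callers are migrated to the public name.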