mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make optimization_barrier a public lax API.
This commit is contained in:
parent
65b1b0bd95
commit
9c86fdec02
@ -33,6 +33,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
{obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
|
||||
{obj}`~jax.numpy.logical_and`, and {obj}`~jax.numpy.logical_and`.
|
||||
In future releases we plan to expand these to other ufuncs.
|
||||
* Added {func}`jax.lax.optimization_barrier`, which allows users to prevent
|
||||
compiler optimizations such as common-subexpression elimination and to
|
||||
control scheduling.
|
||||
|
||||
* Breaking changes
|
||||
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
|
||||
|
@ -119,6 +119,7 @@ Operators
|
||||
ne
|
||||
neg
|
||||
nextafter
|
||||
optimization_barrier
|
||||
pad
|
||||
platform_dependent
|
||||
polygamma
|
||||
|
@ -27,7 +27,6 @@ from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import effects
|
||||
@ -755,7 +754,7 @@ def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
|
||||
return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr)
|
||||
|
||||
def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
|
||||
args = _optimization_barrier(args)
|
||||
args = lax_internal.optimization_barrier(args)
|
||||
return core.eval_jaxpr(jaxpr, (), *args)
|
||||
|
||||
# TODO(mattjj): add core utility for 'create dummy value for this type'?
|
||||
@ -837,27 +836,6 @@ mlir.register_lowering(remat_p, _remat_lowering)
|
||||
mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True),
|
||||
platform="gpu")
|
||||
|
||||
def _optimization_barrier_abstract_eval(*args):
|
||||
return args
|
||||
|
||||
def _optimization_barrier_lowering_rule(ctx, *args):
|
||||
barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in)
|
||||
flat_args = mlir.flatten_ir_values(args)
|
||||
barrier_op = hlo.OptimizationBarrierOp(flat_args)
|
||||
return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types)
|
||||
|
||||
def _optimization_barrier(arg):
|
||||
flat_args, treedef = tree_flatten(arg)
|
||||
return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))
|
||||
|
||||
optimization_barrier_p = core.Primitive('optimization_barrier')
|
||||
optimization_barrier_p.multiple_results = True
|
||||
optimization_barrier_p.def_impl(
|
||||
partial(dispatch.apply_primitive, optimization_barrier_p))
|
||||
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
|
||||
mlir.register_lowering(optimization_barrier_p,
|
||||
_optimization_barrier_lowering_rule)
|
||||
|
||||
|
||||
def checkpoint_name(x, name):
|
||||
return name_p.bind(x, name=name)
|
||||
@ -936,3 +914,6 @@ def checkpoint_wrapper(
|
||||
raise NotImplementedError(msg)
|
||||
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
|
||||
static_argnums=static_argnums)
|
||||
|
||||
# TODO(phawkins): update users to refer to the public name.
|
||||
_optimization_barrier = lax_internal.optimization_barrier
|
||||
|
@ -33,4 +33,4 @@ from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
|
||||
_initial_style_jaxprs_with_common_consts,
|
||||
_check_tree_and_avals)
|
||||
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
|
||||
from jax._src.ad_checkpoint import optimization_barrier_p
|
||||
from jax._src.lax.lax import optimization_barrier_p
|
||||
|
@ -5346,3 +5346,59 @@ class BIntRules:
|
||||
|
||||
|
||||
core.bint._rules = BIntRules
|
||||
|
||||
|
||||
def optimization_barrier(operand, /):
|
||||
"""Prevents the compiler from moving operations across the barrier.
|
||||
|
||||
Optimization barriers have a number of possible uses:
|
||||
|
||||
* An optimization barrier ensures that all inputs are evaluated before any
|
||||
operators that depend on the barrier's outputs. This can be used to enforce
|
||||
a particular order of operations.
|
||||
* An optimization barrier prevents common subexpression elimination. This is
|
||||
used by JAX to implement rematerialization.
|
||||
* Optimization barriers prevent compiler fusions. That is, operations before
|
||||
the barrier may not be fused into the same kernel as operations after the
|
||||
barrier by the compiler.
|
||||
|
||||
JAX does not define derivative or batching rules for an optimization barrier.
|
||||
|
||||
Optimization barriers have no effect outside a compiled function.
|
||||
|
||||
Args:
|
||||
operand: a pytree of JAX values.
|
||||
|
||||
Returns:
|
||||
A pytree of JAX values, with the same structure and contents as ``operand``.
|
||||
|
||||
Examples:
|
||||
Prevents common-subexpression elimination between the two calls to `sin`:
|
||||
|
||||
>>> def f(x):
|
||||
... return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
|
||||
>>> jax.jit(f)(0.)
|
||||
Array(0., dtype=float32, weak_type=True)
|
||||
"""
|
||||
flat_args, treedef = tree_util.tree_flatten(operand)
|
||||
return tree_util.tree_unflatten(
|
||||
treedef, optimization_barrier_p.bind(*flat_args))
|
||||
|
||||
|
||||
def _optimization_barrier_abstract_eval(*args):
|
||||
return args
|
||||
|
||||
def _optimization_barrier_lowering_rule(ctx, *args):
|
||||
barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in)
|
||||
flat_args = mlir.flatten_ir_values(args)
|
||||
barrier_op = hlo.OptimizationBarrierOp(flat_args)
|
||||
return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types)
|
||||
|
||||
|
||||
optimization_barrier_p = core.Primitive('optimization_barrier')
|
||||
optimization_barrier_p.multiple_results = True
|
||||
optimization_barrier_p.def_impl(
|
||||
partial(dispatch.apply_primitive, optimization_barrier_p))
|
||||
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
|
||||
mlir.register_lowering(optimization_barrier_p,
|
||||
_optimization_barrier_lowering_rule)
|
||||
|
@ -142,6 +142,8 @@ from jax._src.lax.lax import (
|
||||
nextafter as nextafter,
|
||||
nextafter_p as nextafter_p,
|
||||
not_p as not_p,
|
||||
optimization_barrier as optimization_barrier,
|
||||
optimization_barrier_p as optimization_barrier_p,
|
||||
or_p as or_p,
|
||||
outfeed as outfeed,
|
||||
outfeed_p as outfeed_p,
|
||||
|
@ -3093,6 +3093,10 @@ class LaxTest(jtu.JaxTestCase):
|
||||
with jax.transfer_guard('disallow'):
|
||||
jax.jit(asarray_closure)()
|
||||
|
||||
def testOptimizationBarrier(self):
|
||||
x = lax.optimization_barrier((2, 3))
|
||||
self.assertEqual((2, 3), x)
|
||||
|
||||
|
||||
class LazyConstantTest(jtu.JaxTestCase):
|
||||
def _Check(self, make_const, expected):
|
||||
|
Loading…
x
Reference in New Issue
Block a user