mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Make optimization_barrier a public lax API.
This commit is contained in:
parent
65b1b0bd95
commit
9c86fdec02
@ -33,6 +33,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
|||||||
{obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
|
{obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`,
|
||||||
{obj}`~jax.numpy.logical_and`, and {obj}`~jax.numpy.logical_and`.
|
{obj}`~jax.numpy.logical_and`, and {obj}`~jax.numpy.logical_and`.
|
||||||
In future releases we plan to expand these to other ufuncs.
|
In future releases we plan to expand these to other ufuncs.
|
||||||
|
* Added {func}`jax.lax.optimization_barrier`, which allows users to prevent
|
||||||
|
compiler optimizations such as common-subexpression elimination and to
|
||||||
|
control scheduling.
|
||||||
|
|
||||||
* Breaking changes
|
* Breaking changes
|
||||||
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
|
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
|
||||||
|
@ -119,6 +119,7 @@ Operators
|
|||||||
ne
|
ne
|
||||||
neg
|
neg
|
||||||
nextafter
|
nextafter
|
||||||
|
optimization_barrier
|
||||||
pad
|
pad
|
||||||
platform_dependent
|
platform_dependent
|
||||||
polygamma
|
polygamma
|
||||||
|
@ -27,7 +27,6 @@ from jax._src import ad_util
|
|||||||
from jax._src import api
|
from jax._src import api
|
||||||
from jax._src import config
|
from jax._src import config
|
||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax._src import dispatch
|
|
||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src import linear_util as lu
|
from jax._src import linear_util as lu
|
||||||
from jax._src import effects
|
from jax._src import effects
|
||||||
@ -755,7 +754,7 @@ def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
|
|||||||
return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr)
|
return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr)
|
||||||
|
|
||||||
def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
|
def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
|
||||||
args = _optimization_barrier(args)
|
args = lax_internal.optimization_barrier(args)
|
||||||
return core.eval_jaxpr(jaxpr, (), *args)
|
return core.eval_jaxpr(jaxpr, (), *args)
|
||||||
|
|
||||||
# TODO(mattjj): add core utility for 'create dummy value for this type'?
|
# TODO(mattjj): add core utility for 'create dummy value for this type'?
|
||||||
@ -837,27 +836,6 @@ mlir.register_lowering(remat_p, _remat_lowering)
|
|||||||
mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True),
|
mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True),
|
||||||
platform="gpu")
|
platform="gpu")
|
||||||
|
|
||||||
def _optimization_barrier_abstract_eval(*args):
|
|
||||||
return args
|
|
||||||
|
|
||||||
def _optimization_barrier_lowering_rule(ctx, *args):
|
|
||||||
barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in)
|
|
||||||
flat_args = mlir.flatten_ir_values(args)
|
|
||||||
barrier_op = hlo.OptimizationBarrierOp(flat_args)
|
|
||||||
return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types)
|
|
||||||
|
|
||||||
def _optimization_barrier(arg):
|
|
||||||
flat_args, treedef = tree_flatten(arg)
|
|
||||||
return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))
|
|
||||||
|
|
||||||
optimization_barrier_p = core.Primitive('optimization_barrier')
|
|
||||||
optimization_barrier_p.multiple_results = True
|
|
||||||
optimization_barrier_p.def_impl(
|
|
||||||
partial(dispatch.apply_primitive, optimization_barrier_p))
|
|
||||||
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
|
|
||||||
mlir.register_lowering(optimization_barrier_p,
|
|
||||||
_optimization_barrier_lowering_rule)
|
|
||||||
|
|
||||||
|
|
||||||
def checkpoint_name(x, name):
|
def checkpoint_name(x, name):
|
||||||
return name_p.bind(x, name=name)
|
return name_p.bind(x, name=name)
|
||||||
@ -936,3 +914,6 @@ def checkpoint_wrapper(
|
|||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
|
return checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
|
||||||
static_argnums=static_argnums)
|
static_argnums=static_argnums)
|
||||||
|
|
||||||
|
# TODO(phawkins): update users to refer to the public name.
|
||||||
|
_optimization_barrier = lax_internal.optimization_barrier
|
||||||
|
@ -33,4 +33,4 @@ from jax._src.lax.control_flow.common import (_initial_style_open_jaxpr,
|
|||||||
_initial_style_jaxprs_with_common_consts,
|
_initial_style_jaxprs_with_common_consts,
|
||||||
_check_tree_and_avals)
|
_check_tree_and_avals)
|
||||||
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
|
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
|
||||||
from jax._src.ad_checkpoint import optimization_barrier_p
|
from jax._src.lax.lax import optimization_barrier_p
|
||||||
|
@ -5346,3 +5346,59 @@ class BIntRules:
|
|||||||
|
|
||||||
|
|
||||||
core.bint._rules = BIntRules
|
core.bint._rules = BIntRules
|
||||||
|
|
||||||
|
|
||||||
|
def optimization_barrier(operand, /):
|
||||||
|
"""Prevents the compiler from moving operations across the barrier.
|
||||||
|
|
||||||
|
Optimization barriers have a number of possible uses:
|
||||||
|
|
||||||
|
* An optimization barrier ensures that all inputs are evaluated before any
|
||||||
|
operators that depend on the barrier's outputs. This can be used to enforce
|
||||||
|
a particular order of operations.
|
||||||
|
* An optimization barrier prevents common subexpression elimination. This is
|
||||||
|
used by JAX to implement rematerialization.
|
||||||
|
* Optimization barriers prevent compiler fusions. That is, operations before
|
||||||
|
the barrier may not be fused into the same kernel as operations after the
|
||||||
|
barrier by the compiler.
|
||||||
|
|
||||||
|
JAX does not define derivative or batching rules for an optimization barrier.
|
||||||
|
|
||||||
|
Optimization barriers have no effect outside a compiled function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operand: a pytree of JAX values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A pytree of JAX values, with the same structure and contents as ``operand``.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Prevents common-subexpression elimination between the two calls to `sin`:
|
||||||
|
|
||||||
|
>>> def f(x):
|
||||||
|
... return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x)
|
||||||
|
>>> jax.jit(f)(0.)
|
||||||
|
Array(0., dtype=float32, weak_type=True)
|
||||||
|
"""
|
||||||
|
flat_args, treedef = tree_util.tree_flatten(operand)
|
||||||
|
return tree_util.tree_unflatten(
|
||||||
|
treedef, optimization_barrier_p.bind(*flat_args))
|
||||||
|
|
||||||
|
|
||||||
|
def _optimization_barrier_abstract_eval(*args):
|
||||||
|
return args
|
||||||
|
|
||||||
|
def _optimization_barrier_lowering_rule(ctx, *args):
|
||||||
|
barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in)
|
||||||
|
flat_args = mlir.flatten_ir_values(args)
|
||||||
|
barrier_op = hlo.OptimizationBarrierOp(flat_args)
|
||||||
|
return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types)
|
||||||
|
|
||||||
|
|
||||||
|
optimization_barrier_p = core.Primitive('optimization_barrier')
|
||||||
|
optimization_barrier_p.multiple_results = True
|
||||||
|
optimization_barrier_p.def_impl(
|
||||||
|
partial(dispatch.apply_primitive, optimization_barrier_p))
|
||||||
|
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
|
||||||
|
mlir.register_lowering(optimization_barrier_p,
|
||||||
|
_optimization_barrier_lowering_rule)
|
||||||
|
@ -142,6 +142,8 @@ from jax._src.lax.lax import (
|
|||||||
nextafter as nextafter,
|
nextafter as nextafter,
|
||||||
nextafter_p as nextafter_p,
|
nextafter_p as nextafter_p,
|
||||||
not_p as not_p,
|
not_p as not_p,
|
||||||
|
optimization_barrier as optimization_barrier,
|
||||||
|
optimization_barrier_p as optimization_barrier_p,
|
||||||
or_p as or_p,
|
or_p as or_p,
|
||||||
outfeed as outfeed,
|
outfeed as outfeed,
|
||||||
outfeed_p as outfeed_p,
|
outfeed_p as outfeed_p,
|
||||||
|
@ -3093,6 +3093,10 @@ class LaxTest(jtu.JaxTestCase):
|
|||||||
with jax.transfer_guard('disallow'):
|
with jax.transfer_guard('disallow'):
|
||||||
jax.jit(asarray_closure)()
|
jax.jit(asarray_closure)()
|
||||||
|
|
||||||
|
def testOptimizationBarrier(self):
|
||||||
|
x = lax.optimization_barrier((2, 3))
|
||||||
|
self.assertEqual((2, 3), x)
|
||||||
|
|
||||||
|
|
||||||
class LazyConstantTest(jtu.JaxTestCase):
|
class LazyConstantTest(jtu.JaxTestCase):
|
||||||
def _Check(self, make_const, expected):
|
def _Check(self, make_const, expected):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user