[jax2tf] Add jax2tf_associative_scan_reductions flag
This flag allows users to match the JAX performance for associative reductions on CPU. See README.md for details.
This commit is contained in:
parent f0e4f0472d
commit 5bfe1852a4
@@ -33,6 +33,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
  If set, JAX dumps the MHLO/HLO IR it generates for each computation to a
  file under the given path.
* Added `jax.ensure_compile_time_eval` to the public api ({jax-issue}`#7987`).
* jax2tf now supports a flag `jax2tf_associative_scan_reductions` to change
  the lowering for associative reductions, e.g., `jnp.cumsum`, to behave
  like JAX on CPU and GPU (to use an associative scan). See the jax2tf README
  for more details ({jax-issue}`#9189`).

## jaxlib 0.1.76 (Unreleased)
* New features
@@ -49,6 +49,7 @@ from jax._src.config import (
  default_matmul_precision as default_matmul_precision,
  default_prng_impl as default_prng_impl,
  numpy_rank_promotion as numpy_rank_promotion,
  jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions
)
from .core import eval_context as ensure_compile_time_eval
from jax._src.api import (
@@ -473,6 +473,22 @@ flags.DEFINE_bool(
    )
)

# TODO(b/214340779): remove flag when XLA:CPU is improved.
jax2tf_associative_scan_reductions = config.define_bool_state(
    name='jax2tf_associative_scan_reductions',
    default=False,
    help=(
        'JAX has two separate lowering rules for the cumulative reduction '
        'primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses '
        'a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. '
        'The latter has a slow implementation on CPUs and GPUs. '
        'By default, jax2tf uses the TPU lowering. Set this flag to True to '
        'use the associative scan lowering, but only if it makes a difference '
        'for your application. '
        'See the jax2tf README.md for more details.'
    )
)

enable_checks = config.define_bool_state(
    name='jax_enable_checks',
    default=False,
@@ -759,6 +759,27 @@ jax2tf.convert(jax_fun)(3.14)
jax2tf.convert(jax_fun)(tf.Variable(3.14, dtype=jax2tf.dtype_of_val(3.14)))
```

### Slow implementation of associative reductions for CPU

Operations like ``jax.numpy.cumsum`` are compiled by JAX differently based
on the platform. For TPU, the compilation uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)
operation, which has an efficient implementation for the cases when the
reduction function is associative. For CPU and GPU, JAX uses an alternative
implementation using [associative scans](https://github.com/google/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801).
jax2tf uses the TPU lowering (because it does not support backend-specific lowering)
and hence it can be slow in some cases on CPU and GPU.
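
For intuition, a cumulative sum is just an associative scan with addition. The
snippet below is a minimal, illustrative sketch (not the converter's actual code
path) of the equivalence that the associative-scan lowering relies on:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(1., 6.)                    # [1. 2. 3. 4. 5.]
print(jnp.cumsum(x))                      # [ 1.  3.  6. 10. 15.]
# The same cumulative sum, expressed as a parallel (associative) scan.
print(lax.associative_scan(lax.add, x))   # [ 1.  3.  6. 10. 15.]
```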

We have filed a bug with the XLA:CPU compiler to improve ReduceWindow.
Meanwhile, if you run into this problem you can use the
``--jax2tf_associative_scan_reductions`` flag to get the special
associative scan lowering.
Alternatively, you can use ``jax.jax2tf_associative_scan_reductions(True)`` as a
context manager around the code that invokes the function returned by
``jax2tf.convert``.
Use this only if it improves the performance for your application.
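
For example, here is a minimal sketch of the context-manager form. The function
`f` and the input values are hypothetical, and TensorFlow must be installed; the
scope wraps both the conversion and the call, mirroring how the converter tests
exercise the flag.

```python
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x):
  return jnp.cumsum(x)   # an associative reduction

x = tf.constant([1., 2., 3., 4.])

# Default lowering (ReduceWindow-based); can be slow on CPU and GPU.
print(jax2tf.convert(f)(x))

# Associative-scan lowering, enabled only within this scope.
with jax.jax2tf_associative_scan_reductions(True):
  print(jax2tf.convert(f)(x))
```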

Note that this lowering may not work as well as the default one in the presence
of shape polymorphism.

### Unchecked assumption that the dimension variables take strictly positive values

The shape polymorphic conversion is sound with the assumption that the dimension
@@ -1950,31 +1950,33 @@ tf_impl_with_avals[lax.reduce_p] = _reduce
# cummin, cumsum and cumprod. This is efficient on TPU, but the complexity is
# O(n^2) on other backends. This may be implemented using associative_scan
# instead to favor different backends.
tf_impl_with_avals[lax.cummin_p] = _convert_jax_impl(
    partial(lax_control_flow._cumred_tpu_translation_rule,
            lax._reduce_window_min),
    multiple_results=False,
    extra_name_stack="cummin")
tf_impl_with_avals[lax.cummax_p] = _convert_jax_impl(
    partial(lax_control_flow._cumred_tpu_translation_rule,
            lax._reduce_window_max),
    multiple_results=False,
    extra_name_stack="cummin")
# TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for
# certain dtypes: bfloat16, float16, float32, float64, and int32. Other dtypes
# will fail when running in compiled mode, but are otherwise compatible with
# the operation. A non-XLA path can thus be defined for all dtypes, though the
# tests will crash.
tf_impl_with_avals[lax.cumsum_p] = _convert_jax_impl(
    partial(lax_control_flow._cumred_tpu_translation_rule,
            lax._reduce_window_sum),
    multiple_results=False,
    extra_name_stack="cumsum")
tf_impl_with_avals[lax.cumprod_p] = _convert_jax_impl(
    partial(lax_control_flow._cumred_tpu_translation_rule,
            lax._reduce_window_prod),
    multiple_results=False,
    extra_name_stack="cumprod")
def _cumred(lax_reduce_fn: Callable,
            lax_reduce_window_fn: Callable,
            extra_name_stack: str):
  if config.jax2tf_associative_scan_reductions:
    return _convert_jax_impl(partial(lax_control_flow.associative_scan,
                                     lax_reduce_fn),
                             multiple_results=False,
                             extra_name_stack=extra_name_stack)
  else:
    return _convert_jax_impl(partial(lax_control_flow._cumred_tpu_translation_rule,
                                     lax_reduce_window_fn),
                             multiple_results=False,
                             extra_name_stack=extra_name_stack)


tf_impl_with_avals[lax.cummax_p] = _cumred(lax_reduce_window_fn=lax._reduce_window_max,
                                           lax_reduce_fn=lax.max,
                                           extra_name_stack="cummax")
tf_impl_with_avals[lax.cummin_p] = _cumred(lax_reduce_window_fn=lax._reduce_window_min,
                                           lax_reduce_fn=lax.min,
                                           extra_name_stack="cummin")
tf_impl_with_avals[lax.cumsum_p] = _cumred(lax_reduce_window_fn=lax._reduce_window_sum,
                                           lax_reduce_fn=lax.add,
                                           extra_name_stack="cumsum")
tf_impl_with_avals[lax.cumprod_p] = _cumred(lax_reduce_window_fn=lax._reduce_window_prod,
                                            lax_reduce_fn=lax.mul,
                                            extra_name_stack="cumprod")


def _select_and_scatter(operand, source, init_value, select_jaxpr,
@@ -300,28 +300,23 @@ class Jax2TfLimitation(primitive_harness.Limitation):
  @classmethod
  def cumprod(cls, harness):
    return [
        # TODO: very high tolerance
        # JAX uses a different lowering for CPU and GPU.
        custom_numeric(
            dtypes=np.float16,
            dtypes=(np.float16, jnp.bfloat16),
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled"),
            tol=1e-1)
            tol=5e-1)
    ]

  @classmethod
  def cumsum(cls, harness):
    return [
        # TODO: very high tolerance
        # JAX uses a different lowering for CPU and GPU.
        custom_numeric(
            dtypes=np.float16,
            tol=0.1,
            dtypes=(np.float16, jnp.bfloat16),
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
        custom_numeric(
            dtypes=dtypes.bfloat16,
            tol=0.5,
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
            modes=("eager", "graph", "compiled"),
            tol=5e-1)
    ]

  @classmethod
@@ -1411,8 +1411,23 @@ def _make_cumreduce_harness(name,
      shape=shape,
      dtype=dtype,
      axis=axis,
      reverse=reverse)

      reverse=reverse,
      associative_scan_reductions=False
  )
  define(
      f_jax.__name__,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_associative_scan_reductions_axis={axis}_reverse={reverse}",
      f_jax, [RandArg(shape, dtype),
              StaticArg(axis),
              StaticArg(reverse)],
      jax_unimplemented=limitations,
      f_jax=f_jax,
      shape=shape,
      dtype=dtype,
      axis=axis,
      reverse=reverse,
      associative_scan_reductions=True
  )

# Validate dtypes for each function
for f_jax in [
@@ -100,7 +100,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
  @primitive_harness.parameterized(
      primitive_harness.all_harnesses,
      include_jax_unimpl=False,
      #one_containing="mode=GatherScatterMode.CLIP"
      #one_containing="cumprod_dtype_by_fun_shape=float16[8,9]_axis=0_reverse=False"
  )
  @jtu.ignore_warning(
      category=UserWarning, message="Using reduced precision for gradient.*")
@@ -112,8 +112,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
    func_jax = harness.dyn_fun
    args = harness.dyn_args_maker(self.rng())
    enable_xla = harness.params.get("enable_xla", True)
    self.ConvertAndCompare(func_jax, *args, limitations=limitations,
                           enable_xla=enable_xla)
    associative_scan_reductions = harness.params.get("associative_scan_reductions", False)
    with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
      self.ConvertAndCompare(func_jax, *args, limitations=limitations,
                             enable_xla=enable_xla)

  def test_primitive_coverage(self):
    """Fail if there are JAX primitives that are not implemented."""