[jax2tf] Add jax2tf_associative_scan_reductions flag

This flag allows users to match the JAX performance for
associative reductions on CPU.
See README.md for details.
George Necula 2022-01-13 12:28:25 +02:00
parent f0e4f0472d
commit 5bfe1852a4
8 changed files with 98 additions and 42 deletions


@@ -33,6 +33,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
If set, JAX dumps the MHLO/HLO IR it generates for each computation to a
file under the given path.
* Added `jax.ensure_compile_time_eval` to the public api ({jax-issue}`#7987`).
* jax2tf now supports the flag `jax2tf_associative_scan_reductions`, which changes
  the lowering of associative reductions, e.g., `jnp.cumsum`, to behave
  as JAX does on CPU and GPU (using an associative scan). See the jax2tf README
  for more details ({jax-issue}`#9189`).

## jaxlib 0.1.76 (Unreleased)
* New features


@@ -49,6 +49,7 @@ from jax._src.config import (
  default_matmul_precision as default_matmul_precision,
  default_prng_impl as default_prng_impl,
  numpy_rank_promotion as numpy_rank_promotion,
  jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
)
from .core import eval_context as ensure_compile_time_eval
from jax._src.api import (


@@ -473,6 +473,22 @@ flags.DEFINE_bool(
)
)
# TODO(b/214340779): remove flag when XLA:CPU is improved.
jax2tf_associative_scan_reductions = config.define_bool_state(
    name='jax2tf_associative_scan_reductions',
    default=False,
    help=(
        'JAX has two separate lowering rules for the cumulative reduction '
        'primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses '
        'a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. '
        'The latter has a slow implementation on CPUs and GPUs. '
        'By default, jax2tf uses the TPU lowering. Set this flag to True to '
        'use the associative scan lowering instead, but only if it makes a '
        'difference for your application. '
        'See the jax2tf README.md for more details.'
    )
)

enable_checks = config.define_bool_state(
    name='jax_enable_checks',
    default=False,


@@ -759,6 +759,27 @@ jax2tf.convert(jax_fun)(3.14)
jax2tf.convert(jax_fun)(tf.Variable(3.14, dtype=jax2tf.dtype_of_val(3.14)))
```
### Slow implementation of associative reductions for CPU

Operations like ``jax.numpy.cumsum`` are compiled by JAX differently depending
on the platform. For TPU, the compilation uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)
operation, which has an efficient implementation for the cases when the
reduction function is associative. For CPU and GPU, JAX uses an alternative
implementation based on [associative scans](https://github.com/google/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801).
jax2tf uses the TPU lowering (because it does not support backend-specific lowering),
and hence the converted code can be slow in some cases on CPU and GPU.

We have filed a bug with the XLA:CPU compiler to improve ReduceWindow.
Meanwhile, if you run into this problem, you can use the
``--jax2tf_associative_scan_reductions`` flag to get the
associative scan lowering. Alternatively, you can use the
``with jax.jax2tf_associative_scan_reductions(True)`` context manager
around the code that invokes the function returned by ``jax2tf.convert``,
as shown in the example below.
Use this only if it improves the performance for your application.
Note that this lowering may not work as well as the default one in the presence
of shape polymorphism.
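For example, a minimal sketch (the function ``f`` and the sample input below are
just for illustration):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
  return jnp.cumsum(x, axis=0)  # an associative reduction

f_tf = jax2tf.convert(f)

# Use the associative-scan lowering (JAX's CPU/GPU behavior) for this invocation.
with jax.jax2tf_associative_scan_reductions(True):
  res = f_tf(np.arange(8, dtype=np.float32))
```
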
### Unchecked assumption that the dimension variables take strictly positive values
The shape polymorphic conversion is sound with the assumption that the dimension


@@ -1950,31 +1950,33 @@ tf_impl_with_avals[lax.reduce_p] = _reduce
# cummin, cumsum and cumprod. This is efficient on TPU, but the complexity is
# O(n^2) on other backends. This may be implemented using associative_scan
# instead to favor different backends.
tf_impl_with_avals[lax.cummin_p] = _convert_jax_impl(
    partial(lax_control_flow._cumred_tpu_translation_rule,
            lax._reduce_window_min),
    multiple_results=False,
    extra_name_stack="cummin")
tf_impl_with_avals[lax.cummax_p] = _convert_jax_impl(
    partial(lax_control_flow._cumred_tpu_translation_rule,
            lax._reduce_window_max),
    multiple_results=False,
    extra_name_stack="cummin")
# TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for
# certain dtypes: bfloat16, float16, float32, float64, and int32. Other dtypes
# will fail when running in compiled mode, but are otherwise compatible with
# the operation. A non-XLA path can thus be defined for all dtypes, though the
# tests will crash.
tf_impl_with_avals[lax.cumsum_p] = _convert_jax_impl(
    partial(lax_control_flow._cumred_tpu_translation_rule,
            lax._reduce_window_sum),
    multiple_results=False,
    extra_name_stack="cumsum")
tf_impl_with_avals[lax.cumprod_p] = _convert_jax_impl(
    partial(lax_control_flow._cumred_tpu_translation_rule,
            lax._reduce_window_prod),
    multiple_results=False,
    extra_name_stack="cumprod")


def _cumred(lax_reduce_fn: Callable,
            lax_reduce_window_fn: Callable,
            extra_name_stack: str):
  if config.jax2tf_associative_scan_reductions:
    return _convert_jax_impl(partial(lax_control_flow.associative_scan,
                                     lax_reduce_fn),
                             multiple_results=False,
                             extra_name_stack=extra_name_stack)
  else:
    return _convert_jax_impl(partial(lax_control_flow._cumred_tpu_translation_rule,
                                     lax_reduce_window_fn),
                             multiple_results=False,
                             extra_name_stack=extra_name_stack)


tf_impl_with_avals[lax.cummax_p] = _cumred(lax_reduce_window_fn=lax._reduce_window_max,
                                           lax_reduce_fn=lax.max,
                                           extra_name_stack="cummax")
tf_impl_with_avals[lax.cummin_p] = _cumred(lax_reduce_window_fn=lax._reduce_window_min,
                                           lax_reduce_fn=lax.min,
                                           extra_name_stack="cummin")
tf_impl_with_avals[lax.cumsum_p] = _cumred(lax_reduce_window_fn=lax._reduce_window_sum,
                                           lax_reduce_fn=lax.add,
                                           extra_name_stack="cumsum")
tf_impl_with_avals[lax.cumprod_p] = _cumred(lax_reduce_window_fn=lax._reduce_window_prod,
                                            lax_reduce_fn=lax.mul,
                                            extra_name_stack="cumprod")
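
For intuition, a minimal sketch (assuming only ``jax.numpy`` and ``jax.lax``) showing
that scanning with the binary reduction function reproduces the cumulative reduction
that the associative-scan lowering relies on:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(1.0, 6.0)
# Cumulative sum via an associative scan (the CPU/GPU-friendly lowering) ...
scan_result = lax.associative_scan(jnp.add, x)
# ... matches the cumulative sum computed directly.
assert jnp.allclose(scan_result, jnp.cumsum(x))
```
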
def _select_and_scatter(operand, source, init_value, select_jaxpr,


@@ -300,28 +300,23 @@ class Jax2TfLimitation(primitive_harness.Limitation):
  @classmethod
  def cumprod(cls, harness):
    return [
        # TODO: very high tolerance
        # JAX uses a different lowering for CPU and GPU.
        custom_numeric(
            dtypes=np.float16,
            dtypes=(np.float16, jnp.bfloat16),
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled"),
            tol=1e-1)
            tol=5e-1)
    ]

  @classmethod
  def cumsum(cls, harness):
    return [
        # TODO: very high tolerance
        # JAX uses a different lowering for CPU and GPU.
        custom_numeric(
            dtypes=np.float16,
            tol=0.1,
            dtypes=(np.float16, jnp.bfloat16),
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
        custom_numeric(
            dtypes=dtypes.bfloat16,
            tol=0.5,
            devices=("cpu", "gpu"),
            modes=("eager", "graph", "compiled")),
            modes=("eager", "graph", "compiled"),
            tol=5e-1)
    ]

  @classmethod


@@ -1411,8 +1411,23 @@ def _make_cumreduce_harness(name,
      shape=shape,
      dtype=dtype,
      axis=axis,
      reverse=reverse)
      reverse=reverse,
      associative_scan_reductions=False
  )
  define(
      f_jax.__name__,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_associative_scan_reductions_axis={axis}_reverse={reverse}",
      f_jax, [RandArg(shape, dtype),
              StaticArg(axis),
              StaticArg(reverse)],
      jax_unimplemented=limitations,
      f_jax=f_jax,
      shape=shape,
      dtype=dtype,
      axis=axis,
      reverse=reverse,
      associative_scan_reductions=True
  )

# Validate dtypes for each function
for f_jax in [


@@ -100,7 +100,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
  @primitive_harness.parameterized(
      primitive_harness.all_harnesses,
      include_jax_unimpl=False,
      #one_containing="mode=GatherScatterMode.CLIP"
      #one_containing="cumprod_dtype_by_fun_shape=float16[8,9]_axis=0_reverse=False"
  )
  @jtu.ignore_warning(
      category=UserWarning, message="Using reduced precision for gradient.*")
@@ -112,8 +112,10 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
    func_jax = harness.dyn_fun
    args = harness.dyn_args_maker(self.rng())
    enable_xla = harness.params.get("enable_xla", True)
    self.ConvertAndCompare(func_jax, *args, limitations=limitations,
                           enable_xla=enable_xla)
    associative_scan_reductions = harness.params.get("associative_scan_reductions", False)
    with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
      self.ConvertAndCompare(func_jax, *args, limitations=limitations,
                             enable_xla=enable_xla)

  def test_primitive_coverage(self):
    """Fail if there are JAX primitives that are not implemented."""