Let caller switch implementation of reduction after import

Thank you to gnecula@ for adding the `jax2tf_associative_scan_reductions` flag and context manager: 5bfe1852a4
On GPU, the specific implementation of `cumsum` can make the whopping difference between microsecond and millisecond latency!
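
For a rough sense of the effect, one could time the two lowerings with a micro-benchmark along these lines (a sketch only; the input size, `tf.function` options, and iteration count are illustrative, not taken from this change):

```py
import timeit

import jax
import tensorflow as tf
from jax.experimental import jax2tf

x = tf.range(1_000_000, dtype=tf.float32)

def cumsum(v):
  return jax.numpy.cumsum(v)

for associative in (False, True):
  with jax.jax2tf_associative_scan_reductions(associative):
    fn = tf.function(jax2tf.convert(cumsum), autograph=False, jit_compile=True)
    fn(x)  # Warm-up: tracing/compilation happens under this scope.
    secs = timeit.timeit(lambda: fn(x), number=100) / 100
    print(f"associative_scan={associative}: {secs * 1e6:.1f} us/call")
```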

Before this change, adjusting the method of lowering `cumsum` via this scope had no effect:

```py
with jax.jax2tf_associative_scan_reductions(True):
  ...
```

... because `cumsum` (along with the other cumulative-reduction methods) had its implementation fixed when the `jax2tf` library was imported, i.e., when this line ran:

```py
from jax.experimental import jax2tf
```

Thus, any later switch of the implementation (to, say, associative scanning) had no effect, even if it happened before `jax2tf.convert` executed, because methods such as `cumsum` had already been curried at import time.
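
The pitfall is ordinary Python currying at module scope. Here is a minimal standalone sketch of the old behavior, with toy names rather than the real jax2tf internals:

```py
from functools import partial

flags = {"associative": False}  # Stand-in for the jax2tf config flag.

def associative_scan(x):
  return f"scan({x})"

def reduce_window(x):
  return f"window({x})"

# Import-time binding: the flag is read exactly once, right here, and the
# resulting partial freezes the choice.
cumsum = partial(associative_scan if flags["associative"] else reduce_window)

flags["associative"] = True  # Flipping the flag later changes nothing...
print(cumsum("x"))           # ...this still prints "window(x)".
```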

This change fixes that by selecting the implementation at call time, based on the current value of `config.jax2tf_associative_scan_reductions`.
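
Concretely, the documented usage now takes effect as expected. A sketch (the converted function and input are arbitrary):

```py
import numpy as np

import jax
import jax.numpy as jnp
from jax.experimental import jax2tf  # Implementations are no longer frozen here.

def f(x):
  return jnp.cumsum(x)

# The flag is now read when the converted function runs, so this scope works
# even though jax2tf was imported above.
with jax.jax2tf_associative_scan_reductions(True):
  tf_f = jax2tf.convert(f)
  tf_f(np.ones(8, dtype=np.float32))
```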

We rely on existing tests to verify that this latency-affecting change remains correct. We also add TPU to the devices to which some numeric limitations apply: one TPU unit test started failing because the scope now works. Even though TPUs lower through a different path by default, the context above explicitly switches to associative scanning.

PiperOrigin-RevId: 624264567
commit 9a89a0cee8 (parent 8b691d15a8)
Chi Zeng, 2024-04-12 12:47:07 -07:00, committed by jax authors
2 changed files with 25 additions and 21 deletions

jax/experimental/jax2tf/jax2tf.py
```diff
@@ -2621,16 +2621,21 @@ tf_impl_with_avals[lax.reduce_p] = _reduce
 def _cumred(lax_reduce_fn: Callable,
             lax_reduce_window_fn: Callable,
             extra_name_stack: str):
-  if config.jax2tf_associative_scan_reductions.value:
-    return _convert_jax_impl(partial(lax_control_flow.associative_scan,
-                                     lax_reduce_fn),
-                             multiple_results=False,
-                             extra_name_stack=extra_name_stack)
-  else:
-    return _convert_jax_impl(partial(lax_control_flow.cumred_reduce_window_impl,
-                                     lax_reduce_window_fn),
-                             multiple_results=False,
-                             extra_name_stack=extra_name_stack)
+  associative_scan = partial(lax_control_flow.associative_scan, lax_reduce_fn)
+  reduce_window = partial(
+      lax_control_flow.cumred_reduce_window_impl, lax_reduce_window_fn
+  )
+
+  def _call_impl(*args, **kwargs):
+    # Vary which implementation to use when cumulation is called. This cannot
+    # be done during import time because the caller may later use a python
+    # context to switch the implementation to use.
+    associative = config.jax2tf_associative_scan_reductions.value
+    return (associative_scan if associative else reduce_window)(*args, **kwargs)
+
+  return _convert_jax_impl(
+      _call_impl, multiple_results=False, extra_name_stack=extra_name_stack
+  )
 
 
 tf_impl_with_avals[lax.cummax_p] = _cumred(
```

jax/experimental/jax2tf/tests/jax2tf_limitations.py
```diff
@@ -389,35 +389,34 @@ class Jax2TfLimitation(test_harnesses.Limitation):
   @classmethod
   def cumlogsumexp(cls, harness):
     return [
-        # JAX uses a different lowering for CPU and GPU.
         custom_numeric(
-            dtypes=(np.float16, jnp.bfloat16),
-            devices=("cpu", "gpu"),
+            dtypes=(np.float16, jnp.bfloat16, np.float32),
+            devices=("cpu", "gpu", "tpu"),
             modes=("eager", "graph", "compiled"),
-            tol=5e-1)
+            tol=5e-1,
+        )
     ]
 
   @classmethod
   def cumprod(cls, harness):
     return [
-        # JAX uses a different lowering for CPU and GPU.
         custom_numeric(
             dtypes=(np.float16, jnp.bfloat16),
-            devices=("cpu", "gpu"),
+            devices=("cpu", "gpu", "tpu"),
             modes=("eager", "graph", "compiled"),
-            tol=5e-1)
+            tol=5e-1,
+        )
     ]
 
   @classmethod
   def cumsum(cls, harness):
     return [
-        # JAX uses a different lowering for CPU and GPU.
         custom_numeric(
             dtypes=(np.float16, jnp.bfloat16),
-            devices=("cpu", "gpu"),
+            devices=("cpu", "gpu", "tpu"),
             modes=("eager", "graph", "compiled"),
-            tol=5e-1)
+            tol=5e-1,
+        )
     ]
 
   @classmethod
```