Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Let caller switch implementation of reduction after import
Thank you to gnecula@ for adding the `jax2tf_associative_scan_reductions` flag and context manager: 5bfe1852a4
On GPU, the specific implementation of `cumsum` can make the difference between microsecond and millisecond latency!
Before this change, adjusting how `cumsum` is lowered via this scope had no effect:
```py
with jax.jax2tf_associative_scan_reductions(True):
...
```
... because the implementations of `cumsum` and the other cumulative-reduction methods were fixed when the `jax2tf` library was imported, i.e. when this line was executed:
```py
from jax.experimental import jax2tf
```
Thus, any later switch of the implementation (to, say, associative scanning), even if it happened before `jax2tf.convert` was called, had no effect, because methods such as `cumsum` had already been curried at import time.
This change fixes that by varying the implementation based on the current value of `config.jax2tf_associative_scan_reductions`.
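As a minimal, self-contained sketch of the difference (using a toy `_Flag` class and stand-in `scan_impl`/`window_impl` functions, not the real jax2tf internals):
```py
class _Flag:
    """Toy stand-in for a config flag whose value can change after import."""
    def __init__(self, value):
        self.value = value

associative_scan_reductions = _Flag(False)

def scan_impl(x):
    return f"associative_scan({x})"

def window_impl(x):
    return f"reduce_window({x})"

# Old behavior: the implementation is picked once, when this module-level code
# runs (i.e. at import time), so later flag changes are never observed.
cumsum_frozen = scan_impl if associative_scan_reductions.value else window_impl

# New behavior: the flag is read on every call, so flipping it later (e.g. via
# a context manager) still changes which lowering is used.
def cumsum_deferred(x):
    impl = scan_impl if associative_scan_reductions.value else window_impl
    return impl(x)

associative_scan_reductions.value = True   # caller switches after "import"
print(cumsum_frozen(3))    # reduce_window(3)     -- stale choice from import time
print(cumsum_deferred(3))  # associative_scan(3)  -- respects the later switch
```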
We rely on existing tests to verify that this latency-affecting CL remains correct. We add TPU to the list of devices to which some numeric limitations apply: one TPU unit test started failing because the scope now takes effect. Even though TPUs lower through a different path by default, the context above explicitly switches to associative scanning.
PiperOrigin-RevId: 624264567
This commit is contained in:
parent: 8b691d15a8
commit: 9a89a0cee8
```diff
@@ -2621,16 +2621,21 @@ tf_impl_with_avals[lax.reduce_p] = _reduce
 def _cumred(lax_reduce_fn: Callable,
             lax_reduce_window_fn: Callable,
             extra_name_stack: str):
-  if config.jax2tf_associative_scan_reductions.value:
-    return _convert_jax_impl(partial(lax_control_flow.associative_scan,
-                                     lax_reduce_fn),
-                             multiple_results=False,
-                             extra_name_stack=extra_name_stack)
-  else:
-    return _convert_jax_impl(partial(lax_control_flow.cumred_reduce_window_impl,
-                                     lax_reduce_window_fn),
-                             multiple_results=False,
-                             extra_name_stack=extra_name_stack)
+  associative_scan = partial(lax_control_flow.associative_scan, lax_reduce_fn)
+  reduce_window = partial(
+      lax_control_flow.cumred_reduce_window_impl, lax_reduce_window_fn
+  )
+
+  def _call_impl(*args, **kwargs):
+    # Vary which implementation to use when cumulation is called. This cannot be
+    # done during import time because the caller may later use a python context
+    # to switch the implementation to use.
+    associative = config.jax2tf_associative_scan_reductions.value
+    return (associative_scan if associative else reduce_window)(*args, **kwargs)
+
+  return _convert_jax_impl(
+      _call_impl, multiple_results=False, extra_name_stack=extra_name_stack
+  )


 tf_impl_with_avals[lax.cummax_p] = _cumred(
```
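With `_cumred` now dispatching at call time, a caller can switch the lowering after importing `jax2tf`. A rough usage sketch (assumes TensorFlow is installed; tracing happens when the converted function is called, so the context is kept open around the call):
```py
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf  # no longer freezes the cumsum lowering

x = np.arange(8, dtype=np.float32)

# Entering the scope after the import now takes effect: cumsum is lowered via
# an associative scan instead of a reduce-window when traced inside the scope.
with jax.jax2tf_associative_scan_reductions(True):
    tf_cumsum = jax2tf.convert(jnp.cumsum)
    print(tf_cumsum(x))
```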
```diff
@@ -389,35 +389,34 @@ class Jax2TfLimitation(test_harnesses.Limitation):
   @classmethod
   def cumlogsumexp(cls, harness):
     return [
         # JAX uses a different lowering for CPU and GPU.
         custom_numeric(
-            dtypes=(np.float16, jnp.bfloat16),
-            devices=("cpu", "gpu"),
+            dtypes=(np.float16, jnp.bfloat16, np.float32),
+            devices=("cpu", "gpu", "tpu"),
             modes=("eager", "graph", "compiled"),
-            tol=5e-1)
+            tol=5e-1,
+        )
     ]

   @classmethod
   def cumprod(cls, harness):
     return [
         # JAX uses a different lowering for CPU and GPU.
         custom_numeric(
             dtypes=(np.float16, jnp.bfloat16),
-            devices=("cpu", "gpu"),
+            devices=("cpu", "gpu", "tpu"),
             modes=("eager", "graph", "compiled"),
-            tol=5e-1)
+            tol=5e-1,
+        )
     ]

   @classmethod
   def cumsum(cls, harness):
     return [
         # JAX uses a different lowering for CPU and GPU.
         custom_numeric(
             dtypes=(np.float16, jnp.bfloat16),
-            devices=("cpu", "gpu"),
+            devices=("cpu", "gpu", "tpu"),
             modes=("eager", "graph", "compiled"),
-            tol=5e-1)
+            tol=5e-1,
+        )
     ]

   @classmethod
```
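The widened limitations reflect that the two lowerings accumulate in different orders, which matters in low precision. As an illustration only (not the harness code), this is the kind of comparison such a `tol=5e-1` entry relaxes:
```py
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.asarray(np.random.RandomState(0).randn(100).astype(np.float16))

# Associative scan combines partial sums in a tree; the reduce-window lowering
# accumulates sequentially. In float16/bfloat16 the two orderings can differ
# noticeably, hence the loosened tolerance for these dtypes and devices.
tree_cumsum = lax.associative_scan(lax.add, x)
reference = jnp.cumsum(x.astype(np.float32)).astype(np.float16)

np.testing.assert_allclose(np.asarray(tree_cumsum), np.asarray(reference),
                           atol=5e-1)
```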