Simplify JAX lowering rules for cumulative sum

The upstream fix has landed, so the CPU workaround is removed.

PiperOrigin-RevId: 635505632
This commit is contained in:
George Karpenkov 2024-05-20 10:51:27 -07:00 committed by jax authors
parent cc3a380f76
commit e0a6453a39

View File

@@ -2449,16 +2449,11 @@ def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn):
mlir.cache_lowering(mlir.lower_fun(fn, multiple_results=False)),
platform=platform)
if xla_extension_version >= 263:
if xla_extension_version >= 266:
# In XLA, there's a rewriter for an O(N^2) reduce-window implementation.
# TODO(https://github.com/llvm/llvm-project/issues/91883): enable rewrite
# for CPU once the vectorizer crash is fixed.
for platform in ['cuda', 'rocm', 'tpu']:
register_lowering(
partial(cumred_reduce_window_impl, reduce_window_fn), platform
)
register_lowering(partial(associative_scan, reduce_fn))
register_lowering(
partial(cumred_reduce_window_impl, reduce_window_fn)
)
else:
# Older XLA versions only have this rewrite for TPU.
register_lowering(