mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Simplify JAX lowering rules for cumulative sum
Upstream fix has landed => removing CPU workaround. PiperOrigin-RevId: 635505632
This commit is contained in:
parent
cc3a380f76
commit
e0a6453a39
@@ -2449,16 +2449,11 @@ def _cumulative_reduction_primitive(name, reduce_fn, reduce_window_fn):
|
||||
mlir.cache_lowering(mlir.lower_fun(fn, multiple_results=False)),
|
||||
platform=platform)
|
||||
|
||||
if xla_extension_version >= 263:
|
||||
if xla_extension_version >= 266:
|
||||
# In XLA, there's a rewriter for an O(N^2) reduce-window implementation.
|
||||
# TODO(https://github.com/llvm/llvm-project/issues/91883): enable rewrite
|
||||
# for CPU once the vectorizer crash is fixed..
|
||||
for platform in ['cuda', 'rocm', 'tpu']:
|
||||
register_lowering(
|
||||
partial(cumred_reduce_window_impl, reduce_window_fn), platform
|
||||
)
|
||||
|
||||
register_lowering(partial(associative_scan, reduce_fn))
|
||||
register_lowering(
|
||||
partial(cumred_reduce_window_impl, reduce_window_fn)
|
||||
)
|
||||
else:
|
||||
# Older XLA versions only have this rewrite for TPU.
|
||||
register_lowering(
|
||||
|
Loading…
x
Reference in New Issue
Block a user