mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Lower threefry as an out-of-line MLIR function on TPU.
On TPU we're using an unrolled version of this function, and its expansion is large. It makes sense to emit it as few times as possible to reduce code size.
This commit is contained in:
parent
1471702adc
commit
23e9142d28
@ -891,9 +891,10 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
|
||||
return tuple(x)
|
||||
|
||||
|
||||
_threefry2x32_lowering_rule = mlir.lower_fun(
|
||||
# Since the unrolled lowering is large, emit it as an out-of-line function.
|
||||
_threefry2x32_lowering_rule = mlir.cache_lowering(mlir.lower_fun(
|
||||
partial(_threefry2x32_lowering, use_rolled_loops=False),
|
||||
multiple_results=True)
|
||||
multiple_results=True))
|
||||
|
||||
_threefry2x32_cpu_lowering_rule = mlir.lower_fun(
|
||||
partial(_threefry2x32_lowering, use_rolled_loops=True),
|
||||
|
Loading…
x
Reference in New Issue
Block a user