Lower threefry as an out-of-line MLIR function on TPU.

On TPU we're using an unrolled version of this function, and its expansion is large. It makes sense to emit it as few times as possible to reduce code size.
Peter Hawkins 2024-11-15 08:49:35 -08:00
parent 1471702adc
commit 23e9142d28

@@ -891,9 +891,10 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
   return tuple(x)
-_threefry2x32_lowering_rule = mlir.lower_fun(
+# Since the unrolled lowering is large, emit it as an out-of-line function.
+_threefry2x32_lowering_rule = mlir.cache_lowering(mlir.lower_fun(
     partial(_threefry2x32_lowering, use_rolled_loops=False),
-    multiple_results=True)
+    multiple_results=True))
 _threefry2x32_cpu_lowering_rule = mlir.lower_fun(
     partial(_threefry2x32_lowering, use_rolled_loops=True),
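
The code-size argument in the commit message can be illustrated with a toy model. This is a minimal sketch in plain Python, not JAX's implementation of `mlir.cache_lowering`: `ToyModule`, `expand_unrolled`, `lower_inline`, `cache_lowering_sketch`, and `module_size` are hypothetical names, and the 20-op expansion size is an arbitrary assumption. It only shows why emitting a large expansion once and calling it keeps the module smaller than inlining it at every use.

```python
# Toy model only: a "module" is a list of op names plus a table of
# out-of-line function bodies. None of these names are JAX APIs.

class ToyModule:
    def __init__(self):
        self.functions = {}  # name -> list of ops in an out-of-line function
        self.main = []       # ops emitted directly into the main body


def expand_unrolled(n_rounds=20):
    # Stand-in for a large unrolled expansion (size chosen arbitrarily).
    return [f"round_{i}" for i in range(n_rounds)]


def lower_inline(module):
    # Inline lowering: every use site gets a full copy of the expansion.
    module.main.extend(expand_unrolled())


def cache_lowering_sketch(name, expand):
    # Wrap an expansion so its body is emitted at most once per module;
    # each use site then emits a single call op instead.
    def rule(module):
        if name not in module.functions:
            module.functions[name] = expand()
        module.main.append(f"call @{name}")
    return rule


def module_size(module):
    # Total op count: main body plus all out-of-line function bodies.
    return len(module.main) + sum(len(b) for b in module.functions.values())


inline = ToyModule()
for _ in range(5):
    lower_inline(inline)

outlined = ToyModule()
rule = cache_lowering_sketch("threefry2x32", expand_unrolled)
for _ in range(5):
    rule(outlined)

print(module_size(inline))    # 100 ops: 5 copies of a 20-op expansion
print(module_size(outlined))  # 25 ops: one 20-op body plus 5 calls
```

In this model the inlined size grows with (uses × expansion size), while the out-of-line size grows with (expansion size + uses), which is the trade-off the commit message describes.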