1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Merge pull request from hawkinsp:threefry

PiperOrigin-RevId: 696915844
This commit is contained in:
jax authors 2024-11-15 09:42:54 -08:00
commit d8085008b7

@ -891,9 +891,10 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
return tuple(x)
_threefry2x32_lowering_rule = mlir.lower_fun(
# Since the unrolled lowering is large, emit it as an out-of-line function.
_threefry2x32_lowering_rule = mlir.cache_lowering(mlir.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=False),
multiple_results=True)
multiple_results=True))
_threefry2x32_cpu_lowering_rule = mlir.lower_fun(
partial(_threefry2x32_lowering, use_rolled_loops=True),