mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Merge pull request #24913 from hawkinsp:threefry
PiperOrigin-RevId: 696915844
This commit is contained in:
commit
d8085008b7
@ -891,9 +891,10 @@ def _threefry2x32_lowering(key1, key2, x1, x2, use_rolled_loops=True):
|
||||
return tuple(x)
|
||||
|
||||
|
||||
_threefry2x32_lowering_rule = mlir.lower_fun(
|
||||
# Since the unrolled lowering is large, emit it as an out-of-line function.
|
||||
_threefry2x32_lowering_rule = mlir.cache_lowering(mlir.lower_fun(
|
||||
partial(_threefry2x32_lowering, use_rolled_loops=False),
|
||||
multiple_results=True)
|
||||
multiple_results=True))
|
||||
|
||||
_threefry2x32_cpu_lowering_rule = mlir.lower_fun(
|
||||
partial(_threefry2x32_lowering, use_rolled_loops=True),
|
||||
|
Loading…
x
Reference in New Issue
Block a user