rocm_jax/jax/experimental
Roy Frostig 3f9540761e reintroduce the Threefry GPU kernel lowering, under a flag
On GPU, the Threefry PRNG implementation no longer lowers to a kernel call by default. This choice can improve runtime memory usage at a compile-time cost. Prior behavior, which produces a kernel call, can be recovered with:

   `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`
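A minimal sketch of opting back in, assuming a JAX release that includes this flag. It sets the flag before any random-number code is traced and then draws samples from a jitted function. According to the commit message, the flag only changes how Threefry lowers on GPU. On other backends this should run unchanged, and the sampled values should not depend on the flag. The function name `sample` is purely illustrative.

```python
import jax
import jax.numpy as jnp

# Restore the prior GPU behavior: lower the Threefry PRNG to a kernel call.
# Set this before tracing/compiling any code that uses jax.random.
jax.config.update('jax_threefry_gpu_kernel_lowering', True)

@jax.jit
def sample(key):
    # Threefry is JAX's default PRNG implementation, so jax.random.uniform
    # goes through the lowering the flag controls (on GPU).
    return jax.random.uniform(key, (4,))

key = jax.random.PRNGKey(0)
x = sample(key)
y = sample(key)  # same key, so the same values

# Lowered program text for the jitted function. On GPU with the flag enabled,
# one would expect to find a kernel call here instead of inlined Threefry ops.
hlo_text = sample.lower(key).as_text()
print(x.shape)
```

The flag changes only how the computation is compiled, not the random stream it defines. Toggling it should therefore affect memory use and compile time, not which numbers a given key produces.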

PiperOrigin-RevId: 629763763
2024-05-01 10:33:31 -07:00