rocm_jax/jax/experimental
Sergei Lebedev a373e37be2 Fixed mgpu.FragmentedArray.reduce_sum for integer types
The implementation previously assumed the element type was floating-point and always used addf, which is wrong for integer element types.

PiperOrigin-RevId: 678718871
2024-09-25 08:50:24 -07:00