5 Commits

Author SHA1 Message Date
Peter Hawkins
64eae324ee Migrate JAX MLIR Python dialect extensions to nanobind.
Now that https://github.com/llvm/llvm-project/pull/117922 has landed upstream, we can work towards removing our last uses of pybind11.

PiperOrigin-RevId: 705872751
2024-12-13 07:08:28 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Aden Grue
c368969955 Use the new "custom call status" facility to report errors in jaxlib
PiperOrigin-RevId: 389734200
2021-08-09 15:06:39 -07:00
Skye Wanderman-Milne
7a154f71bc
Fix jaxlib build by not exposing nvcc to pybind11. (#1819) 2019-12-05 18:59:29 -08:00
Peter Hawkins
534d812b57
Add a handwritten ThreeFry2x32 CUDA kernel. (#1756)
In principle, JAX should not need a hand-written CUDA kernel for the ThreeFry2x32 algorithm. In practice XLA aggresively inlines, which causes compilation times on GPU blow up when compiling potentially many copies of the PRNG kernel in a program. As a workaround, we add a hand-written CUDA kernel mostly to reduce compilation time.

When XLA becomes smarter about compiling this particular hash function, we should be able to remove the hand-written kernel once again.
2019-11-24 13:06:23 -05:00