7 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
e1410bd16b Use lowering as impl rule for pure_callback 2022-09-01 15:29:31 -07:00
Sharad Vikram
311a9cb5d9 Throw error when 64-bit dtypes used incorrectly in jax.pure_callback 2022-08-31 12:31:04 -07:00
Yash Katariya
6340952e2a Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.
A fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr during the jit lowering path.

Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.

PiperOrigin-RevId: 470896270
2022-08-29 22:03:21 -07:00
Sharad Vikram
b0fdf10a63 Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-18 10:50:50 -07:00
Sharad Vikram
393bca122d Expose pure callback and enable rank polymorphic callbacks 2022-08-17 10:56:42 -07:00
Sharad Vikram
88f2b5e86d Add functionality for "pure" callbacks
Also avoids using CPP dispatch path when host callbacks are involved

PiperOrigin-RevId: 467270949
2022-08-12 12:39:53 -07:00