Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
e1410bd16b
Use lowering as impl rule for pure_callback
2022-09-01 15:29:31 -07:00
Sharad Vikram
311a9cb5d9
Throw error when 64-bit dtypes used incorrectly in jax.pure_callback
2022-08-31 12:31:04 -07:00
Yash Katariya
6340952e2a
Make jit == pjit. This means that the lowering and execution paths of jit and pjit are merged.
...
A fallback to `lower_xla_callable` is taken when pmap appears in the jaxpr during the jit lowering path.
Added support for `keep_unused`, `committed` and `core.Token` to pxla.py.
PiperOrigin-RevId: 470896270
2022-08-29 22:03:21 -07:00
Sharad Vikram
b0fdf10a63
Apply suggestions from code review
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-18 10:50:50 -07:00
Sharad Vikram
393bca122d
Expose pure callback and enable rank polymorphic callbacks
2022-08-17 10:56:42 -07:00
Sharad Vikram
88f2b5e86d
Add functionality for "pure" callbacks
...
Also avoids using CPP dispatch path when host callbacks are involved
PiperOrigin-RevId: 467270949
2022-08-12 12:39:53 -07:00