7 Commits

Author SHA1 Message Date
Peter Hawkins
e58f1ba86e Move some utilities out of dispatch.py next to their users, add more types.
Internal cleanups only, no user-visible changes intended.

PiperOrigin-RevId: 554876522
2023-08-08 10:52:11 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
9f083d11da Use jax.* APIs rather than api.* names in tests.
Tests should use our own public APIs where they exist.
2021-09-13 16:01:32 -04:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Qiao Zhang
850bd66242 [JAX] Prune unused inputs in jit.
- Python part based on: https://github.com/google/jax/pull/6567
- Added cpp_jit path to handle pruned args

PiperOrigin-RevId: 371743277
2021-05-03 11:41:29 -07:00