19 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
c8daaadd69 Avoid import-time dependency on jax.experimental 2022-08-19 11:30:25 -07:00
Yash Katariya
9244f3b1ba Add support for interoperability via dlpack for Array and also make pickle_tests and lax_numpy_test pass with Array.
PiperOrigin-RevId: 468568917
2022-08-18 16:04:22 -07:00
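
A minimal sketch of the round trip this enables, using the public `jax.dlpack` helpers (keyword arguments such as `take_ownership` have varied across releases):

```python
import jax.numpy as jnp
from jax import dlpack as jax_dlpack

x = jnp.arange(4.0)                  # a jax Array on the default device
capsule = jax_dlpack.to_dlpack(x)    # export as a DLPack capsule
y = jax_dlpack.from_dlpack(capsule)  # re-import, sharing the device buffer
assert bool(jnp.all(x == y))
```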
Peter Hawkins
931bf3674b [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kinds of rules to coexist.

[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.

PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00
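
After this split, the finer-grained platform names can be used anywhere a platform string is accepted, e.g. with `jax.devices` (a sketch; listing "cuda" devices assumes a CUDA build of jaxlib):

```python
import jax

print(jax.devices("gpu"))   # umbrella platform: every GPU device
print(jax.devices("cuda"))  # NVIDIA devices specifically
```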
Peter Hawkins
94efc90939 Drop dead code now that the minimum jaxlib version is 0.3.2. 2022-04-13 13:34:00 -04:00
Peter Hawkins
80aec7b25f Documentation improvements. 2022-03-08 09:37:33 -05:00
Peter Hawkins
8af0d8d033 Add complex number DLPack support to JAX and TensorFlow.
Fixes https://github.com/google/jax/issues/9497

PiperOrigin-RevId: 427579098
2022-02-09 14:58:00 -08:00
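
A short sketch of what this fixes: complex dtypes now survive the DLPack round trip (assuming a jaxlib that includes this change):

```python
import jax.numpy as jnp
from jax import dlpack as jax_dlpack

z = jnp.array([1 + 2j, 3 - 4j], dtype=jnp.complex64)
z2 = jax_dlpack.from_dlpack(jax_dlpack.to_dlpack(z))  # previously failed for complex
```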
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
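
A sketch of the resulting layout in import terms (these are internal modules, so the paths are not a stable API):

```python
from jax._src import device_array   # DeviceArray definition
from jax._src import dispatch       # primitive execution / jit dispatch logic
from jax.interpreters import xla    # jaxpr -> XLA lowering rules
```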
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
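
For illustration, one name that kept working through a shim (`xla_bridge` is an assumption here, chosen as a commonly imported member of jax.lib):

```python
from jax.lib import xla_bridge       # still importable via the shim
from jax._src.lib import xla_client  # new (private) canonical location
```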
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
f885366a8b [JAX] Improve support for DLPack tensors on CPU when a GPU is available.
https://github.com/google/jax/issues/5581

Previously the user had to provide the target backend explicitly. Now we supply both CPU and GPU backends to the C++ code so it can choose based on the metadata of the DLPack tensor.

PiperOrigin-RevId: 380795192
2021-06-22 06:38:53 -07:00
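
A before/after sketch of the change described above (the exact pre-change signature is an assumption based on the commit text):

```python
import jax.numpy as jnp
from jax import dlpack as jax_dlpack

capsule = jax_dlpack.to_dlpack(jnp.ones((2, 2)))

# Before: the target backend had to be passed explicitly, e.g.
#   y = jax_dlpack.from_dlpack(capsule, backend=some_cpu_or_gpu_backend)
# After: the backend is chosen from the DLPack tensor's metadata.
y = jax_dlpack.from_dlpack(capsule)
```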
Peter Hawkins
140c0acbbe Remove the JAX lazy sublanguage.
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside `jit` computations.
Omnistaging, which means that computations that are in the dynamic scope of a
`jit` are staged into the `jit` computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.

At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a `jit`). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) avoid materializing a lazy expression in its expanded form if it is
   lexically captured by a `jit` computation.

It is not clear that laziness has a sufficient power-to-weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again we would want to reimplement it in C++ anyway.
2021-03-09 21:40:46 -05:00
Peter Hawkins
2469ad1bb3 Cleanups for laziness. No functional changes intended.
Use None as a trivial lazy expression in more places. Simplify some code.
2021-03-07 11:33:04 -05:00
George Necula
fd2b9b4759 [call_tf] Add support for DLPack to avoid copying arrays in eager mode 2021-02-13 12:49:51 +02:00
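A minimal sketch of `call_tf` in eager mode, where DLPack lets arrays cross the boundary without a copy:

```python
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

f_tf = tf.function(lambda x: tf.math.sin(x))
f = jax2tf.call_tf(f_tf)   # wrap the TF function for use from JAX
print(f(jnp.arange(3.0)))  # eager call; inputs/outputs move via DLPack
```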
Jake VanderPlas
5e7be4a61f Cleanup: remove obsolete jaxlib version checks 2021-02-04 15:13:39 -08:00
Jean-Baptiste Lespiau
ad2de75546 Dynamically add methods to `_DeviceArray` and `PyBuffer`.
As you suggested during a review, calling `_DeviceArray` directly when a LazyExpr is present is clearer.
2020-12-03 14:18:11 +01:00
Jean-Baptiste Lespiau
e2fdceb3c8 Change `PyBuffer.shape` to `PyBuffer.xla_shape` in a backward-compatible way.
We need this because we will release a new jaxlib with `shape` returning a tuple, and since the submission process is in two steps, we need this change before updating xla.cc.
2020-11-04 01:39:07 +01:00
Jean-Baptiste Lespiau
3e5a0ff0c4 Add methods to interact with DeviceArray objects.
We are going to add a C++ implementation, and this is a useful refactoring to ease the transition. In short:

- `isinstance(x, DeviceArray)` will continue to work.
- `type(x) is DeviceArray` will be replaced by `type_is_device_array(x)`.
- The `DeviceArray(...)` constructor will be replaced by `get_device_array`.
2020-11-03 22:16:28 +01:00
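
A sketch of the first point in that compatibility list, the one most user code relies on (`DeviceArray` here is the Python base class exposed by jax.interpreters.xla at the time of this commit):

```python
import jax.numpy as jnp
from jax.interpreters.xla import DeviceArray

x = jnp.zeros(3)
assert isinstance(x, DeviceArray)  # keeps working after the C++ migration
```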
Peter Hawkins
d001ac6b8a [JAX] Add support for retaining ownership of DLPack tensors.
Move dlpack.py contents under jax/_src/dlpack.py.

Add array interoperability test between JAX and TensorFlow using DLPack.

Fixes: https://github.com/google/jax/issues/4636
PiperOrigin-RevId: 338120910
2020-10-20 13:07:07 -07:00
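
A sketch of the JAX <-> TensorFlow exchange exercised by that test, using `tf.experimental.dlpack` on the TensorFlow side:

```python
import tensorflow as tf
import jax.numpy as jnp
from jax import dlpack as jax_dlpack

x = jnp.arange(3.0)
t = tf.experimental.dlpack.from_dlpack(jax_dlpack.to_dlpack(x))  # JAX -> TF
y = jax_dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(t))  # TF -> JAX
```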