47 Commits

Author SHA1 Message Date
Sergei Lebedev
d4ced960ab Pulled DLDeviceType to XLA backend mapping into a global
I also updated `to_dlpack` and `from_dlpack` to handle `KeyError` instead of `TypeError`, because I think `TypeError` was never actually raised.

PiperOrigin-RevId: 721052736
2025-01-29 11:38:50 -08:00
Sergei Lebedev
1289640f09 Deprecated calling `jax.dlpack.from_dlpack` with a DLPack tensor
PiperOrigin-RevId: 670723176
2024-09-03 15:16:02 -07:00
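A minimal sketch of the call style this deprecation points toward (NumPy as the producer is my choice for illustration; the exact deprecation mechanics are not shown):

```python
import numpy as np
import jax.dlpack

# Deprecated style: passing a raw DLPack capsule, e.g. x.__dlpack__().
# Preferred style: pass the producer object itself; JAX invokes its
# __dlpack__/__dlpack_device__ protocol methods internally.
x = np.arange(4, dtype=np.float32)
y = jax.dlpack.from_dlpack(x)
```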
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
rajasekharporeddy
aaddba0c20 Fix doc Typos 2024-04-22 10:32:51 +05:30
Meekail Zain
a2feff2e54 Add support for max_version, dl_device, copy kwargs in __dlpack__ 2024-04-11 16:44:19 +00:00
Meekail Zain
2b1c3deee2 Update from_dlpack to match array API 2023 2024-04-04 22:51:25 +00:00
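As I understand the post-change signature, `from_dlpack` gains the `device` and `copy` keywords from the 2023.12 array API revision; a hedged sketch (the keyword behavior described in the comments is an assumption based on that spec):

```python
import numpy as np
import jax
import jax.dlpack

x = np.arange(3, dtype=np.float32)
cpu = jax.devices("cpu")[0]
# Per the 2023.12 array API: device selects the target device, and
# copy=True forces a copy so the result does not alias x's buffer.
y = jax.dlpack.from_dlpack(x, device=cpu, copy=True)
```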
Seunghoon Park
e00149c39f Fix unnecessary memory copies between GPU and CPU when jax2tf.call_tf() is used.
- The root cause of the bug is that dtype lookups are incorrect because hashes behave differently between dtype instances and their types. Added comments to `jax.dlpack.SUPPORTED_DTYPES` about this.
- Added unit test coverage.
- Fixing this bug revealed a limitation that causes a host-to-device copy in the following two situations (see details in the unit test comments):
  - When the dtype is 'int32'.
  - When using the PJRT C API runtime.


PiperOrigin-RevId: 610799558
2024-02-27 10:35:50 -08:00
Peter Hawkins
c4368351d2 Add support for bool dlpack values.
PiperOrigin-RevId: 599199196
2024-01-17 09:30:42 -08:00
Jake VanderPlas
df4e9c0d41 DOC: add warning about dlpack and buffer mutation 2024-01-03 13:31:57 -08:00
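A sketch of the hazard the new documentation warns about (whether the import aliases or copies the buffer is backend-dependent, so the aliasing here is a possibility, not a guarantee):

```python
import numpy as np
import jax.dlpack

x = np.arange(3, dtype=np.float32)
y = jax.dlpack.from_dlpack(x)  # may alias x's memory rather than copy it

# Mutating the exporting buffer can silently change the JAX array,
# which JAX assumes is immutable.
x[0] = 99.0
```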
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Jieying Luo
269d7ce5c1 Remove take_ownership support in DLPack.
When take_ownership is true, the original buffer is marked as deleted, and JAX is prevented from reading or writing it. This provides better error checking, but at the cost of one more C++ API and two more C APIs. The same semantics can be achieved without take_ownership by being careful, so we decided to remove take_ownership support from DLPack.

PiperOrigin-RevId: 572278488
2023-10-10 09:43:02 -07:00
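A minimal sketch of the export path after this removal (the comment restates the "being careful" contract from the commit message; it is my paraphrase, not API documentation):

```python
import jax.numpy as jnp
import jax.dlpack

x = jnp.arange(4)
# With take_ownership gone, exporting leaves x alive; the caller must
# simply avoid deleting or donating x while a consumer still reads
# through the capsule.
capsule = jax.dlpack.to_dlpack(x)
```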
Peter Hawkins
3a4b60b48c Fix dlpack type signatures to match Array API spec.
Fixes https://github.com/google/jax/issues/17510
2023-09-08 10:12:32 -04:00
Skye Wanderman-Milne
ecee8f9116 [JAX] Implement importing external dlpack-aware Python arrays.
See https://dmlc.github.io/dlpack/latest/python_spec.html.

This is the import path. The export path was implemented in
0b3cbfe4bc.

This allows for creating jax.Arrays from external GPU arrays
asynchronously.

PiperOrigin-RevId: 561172624
2023-08-29 16:39:31 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Skye Wanderman-Milne
a80cbc5626 [JAX] Implement the stream argument to jax.Array.__dlpack__ for CUDA GPU
Also implements jax.Array.__dlpack_device__. See
https://dmlc.github.io/dlpack/latest/python_spec.html

This requires plumbing the raw CUDA stream pointer through PJRT and
StreamExecutor (since the GPU PJRT implementation is still based on
SE). This is done via the new PJRT method
ExternalReference::WaitUntilBufferReadyOnStream.

I haven't plumbed this through the PJRT C API yet, because I'm still
debating whether this should be part of the main API or a GPU-specific
extension (plus either way it should probably be its own change).

PiperOrigin-RevId: 558245360
2023-08-18 14:20:38 -07:00
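The newly implemented `jax.Array.__dlpack_device__` can be sketched as follows (the enum values in the comment come from the DLPack spec; which one you see depends on the backend you run on):

```python
import jax.numpy as jnp

x = jnp.ones(3)
# __dlpack_device__ returns a (device_type, device_id) pair per the
# DLPack Python spec; device_type is a DLDeviceType value such as
# kDLCPU (1), kDLCUDA (2), or kDLROCM (10).
device_type, device_id = x.__dlpack_device__()
```

The `stream` argument to `__dlpack__` itself is CUDA-specific and not exercised here.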
Peter Hawkins
3135fbcd7f [JAX] Delete _DeviceArray and DeviceArray.
PiperOrigin-RevId: 520453090
2023-03-29 15:07:14 -07:00
Skye Wanderman-Milne
00acf459c6 Bump minimum jaxlib version from 0.4.6 to 0.4.7.
Also removes a bunch of dead version guards (0.4.7 has
xla_extension_version 144 and mlir_api_version 47)
2023-03-28 13:43:01 -07:00
Yash Katariya
6d0189e810 Remove dispatch.result_handlers since they are not used.
PiperOrigin-RevId: 517456171
2023-03-17 11:02:22 -07:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Parker Schuh
f888e4814c [Rollforward] Convert _arrays to return PyArray instead of PyBuffer.
This change also converts all callsites that construct buffers to
return PyArrays.

PiperOrigin-RevId: 510486273
2023-02-17 11:52:43 -08:00
Roy Frostig
d927a5dbf3 migrate internal dependencies from jax.core to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
2022-12-16 21:00:14 -08:00
Yash Katariya
4443b861a5 Remove local imports of array.py. The remaining local imports are in pxla.py but I will chip away at them when we delete SDA and move some more APIs out of experimental.
PiperOrigin-RevId: 492033543
2022-11-30 15:26:03 -08:00
Parker Schuh
fb4db5b60f Delete trailing whitespace that is blocking presubmits.
PiperOrigin-RevId: 489005596
2022-11-16 12:15:40 -08:00
Rahul Batra
4370b3385f [ROCm] Add dlpack backend support
Depends on the TensorFlow commit included in PR https://github.com/tensorflow/tensorflow/pull/57640
2022-10-28 19:19:23 +00:00
Roy Frostig
4af3509819 canonicalize dtypes when loading arrays via dlpack 2022-10-20 15:20:34 -07:00
Yash Katariya
9e4114f0f1 Move array.py and sharding.py from experimental/ to _src/.
PiperOrigin-RevId: 477201711
2022-09-27 10:06:52 -07:00
Yash Katariya
cbf34cb609 Rename the concrete class Array to ArrayImpl
PiperOrigin-RevId: 477017236
2022-09-26 16:18:30 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
c8daaadd69 Avoid import-time dependency on jax.experimental 2022-08-19 11:30:25 -07:00
Yash Katariya
9244f3b1ba Add support for interoperability via dlpack for Array and also make pickle_tests and lax_numpy_test pass with Array.
PiperOrigin-RevId: 468568917
2022-08-18 16:04:22 -07:00
Peter Hawkins
931bf3674b [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kinds of rules to coexist.

[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.

PiperOrigin-RevId: 446737518
2022-05-05 09:33:06 -07:00
Peter Hawkins
94efc90939 Drop dead code now that the minimum jaxlib version is 0.3.2. 2022-04-13 13:34:00 -04:00
Peter Hawkins
80aec7b25f Documentation improvements. 2022-03-08 09:37:33 -05:00
Peter Hawkins
8af0d8d033 Add complex number DLPack support to JAX and TensorFlow.
Fixes https://github.com/google/jax/issues/9497

PiperOrigin-RevId: 427579098
2022-02-09 14:58:00 -08:00
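A small sketch of the complex-dtype round trip this change enables (values are illustrative; complex64 avoids any dependence on JAX's x64 mode):

```python
import numpy as np
import jax.dlpack

# Complex dtypes can now cross the DLPack boundary.
x = np.array([1 + 2j, 3 - 4j], dtype=np.complex64)
y = jax.dlpack.from_dlpack(x)
```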
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
f885366a8b [JAX] Improve support for DLPack tensors on CPU when a GPU is available.
https://github.com/google/jax/issues/5581

Previously the user had to provide the target backend explicitly. Now we supply both CPU and GPU backends to the C++ code so it can choose based on the metadata of the DLPack tensor.

PiperOrigin-RevId: 380795192
2021-06-22 06:38:53 -07:00
Peter Hawkins
140c0acbbe Remove the JAX lazy sublanguage.
Back in the mists of time, before omnistaging landed in JAX, we used lazy
expressions to avoid materializing large constants inside `jit` computations.
Omnistaging, which means that computations that are in the dynamic scope of a
`jit` are staged into the `jit` computation, has subsumed most of the reasons
for laziness to exist, and this PR removes the laziness support for simplicity.

At the time of this PR, laziness is used only for broadcasts and transposes in
eager mode (i.e., outside a `jit`). This allows us to:
a) fuse together multiple broadcasts and transposes, and
b) if a lazy expression is lexically captured by a `jit` computation, we can
   avoid materializing it in its expanded form.

It is not clear that laziness has a sufficient power-to-weight ratio to continue
to exist, and it is making other work on improving JAX dispatch times more
difficult. As a result, this PR removes laziness to unblock that work; if we
want laziness again, we would want to reimplement it in C++ anyway.
2021-03-09 21:40:46 -05:00
Peter Hawkins
2469ad1bb3 Cleanups for laziness. No functional changes intended.
Use None as a trivial lazy expression in more places. Simplify some code.
2021-03-07 11:33:04 -05:00
George Necula
fd2b9b4759 [call_tf] Add support for DLPack to avoid copying arrays in eager mode 2021-02-13 12:49:51 +02:00
Jake VanderPlas
5e7be4a61f Cleanup: remove obsolete jaxlib version checks 2021-02-04 15:13:39 -08:00
Jean-Baptiste Lespiau
ad2de75546 Add methods dynamically to _DeviceArray and PyBuffer.
As suggested during review, calling `_DeviceArray` directly when a LazyExpr is present is clearer.
2020-12-03 14:18:11 +01:00
Jean-Baptiste Lespiau
e2fdceb3c8 Change `PyBuffer.shape` to `PyBuffer.xla_shape` in a backward-compatible way.
We need this because we will release a new jaxlib with `shape` returning a tuple, and since the submission process happens in two steps, we need this before updating xla.cc.
2020-11-04 01:39:07 +01:00
Jean-Baptiste Lespiau
3e5a0ff0c4 Add methods to interact with DeviceArray objects.
We are going to add a C++ implementation; this is a useful refactoring to ease the transition. In short,

- `isinstance(x, DeviceArray)` will continue to work
- `type(x) is DeviceArray` will be replaced by `type_is_device_array(x)`
- the `DeviceArray(...)` constructor will be replaced by `get_device_array`.
2020-11-03 22:16:28 +01:00
Peter Hawkins
d001ac6b8a [JAX] Add support for retaining ownership of DLPack tensors.
Move dlpack.py contents under jax/_src/dlpack.py.

Add array interoperability test between JAX and TensorFlow using DLPack.

Fixes: https://github.com/google/jax/issues/4636
PiperOrigin-RevId: 338120910
2020-10-20 13:07:07 -07:00