16 Commits

Author SHA1 Message Date
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Jieying Luo
1b1c6e7c0f Enable some more C API tests.
PiperOrigin-RevId: 627065492
2024-04-22 09:38:59 -07:00
Samuel Agyakwa
21a874b0bc [PJRT C API] Enable GPU Plugin tests internally
PiperOrigin-RevId: 592360226
2023-12-19 15:26:35 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
e6a62fcd11 [PJRT] Split the GpuId() platform constants into CudaId()/RocmId().
Similarly for the GpuName() constant.

While most of the time we treat CUDA and ROCm GPUs identically, we sometimes want to distinguish between CUDA and ROCm (e.g., for DLPack exports) and it's helpful if this is encoded in the platform ID.

PiperOrigin-RevId: 569513495
2023-09-29 09:35:16 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Skye Wanderman-Milne
ecee8f9116 [JAX] Implement importing external dlpack-aware Python arrays.
See https://dmlc.github.io/dlpack/latest/python_spec.html.

This is the import path. The export path was implemented in
0b3cbfe4bc.

This allows for creating jax.Arrays from external GPU arrays
asynchronously.

PiperOrigin-RevId: 561172624
2023-08-29 16:39:31 -07:00
Matthew Johnson
78199bbc32 [custom-jvp/vjp] for symbolic zeros, ensure rules can be run more than once
Co-authored-by: Roy Frostig <frostig@google.com>
2023-08-21 15:28:43 -07:00
Skye Wanderman-Milne
a80cbc5626 [JAX] Implement the stream argument to jax.Array.__dlpack__ for CUDA GPU
Also implements jax.Array.__dlpack_device__. See
https://dmlc.github.io/dlpack/latest/python_spec.html

This requires plumbing the raw CUDA stream pointer through PJRT and
StreamExecutor (since the GPU PJRT implementation is still based on
SE). This is done via the new PJRT method
ExternalReference::WaitUntilBufferReadyOnStream.

I haven't plumbed this through the PJRT C API yet, because I'm still
debating whether this should be part of the main API or a GPU-specific
extension (plus either way it should probably be its own change).

PiperOrigin-RevId: 558245360
2023-08-18 14:20:38 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Yash Katariya
aa5e229027 Bump minimum jaxlib version to 0.4.4 which means xla_extension_version >= 127
PiperOrigin-RevId: 512173011
2023-02-24 15:05:44 -08:00
Peter Hawkins
2f80e46f64 [XLA:Python] Fix overly pessimistic handling of singleton dimensions in dlpack code.
Requires an accompanying jaxlib change.

Fixes https://github.com/google/jax/issues/14399

PiperOrigin-RevId: 508757315
2023-02-10 14:44:22 -08:00
Peter Hawkins
6ee67639e2 Split PyTorch interoperability tests into their own test.
PiperOrigin-RevId: 508722180
2023-02-10 12:17:11 -08:00