17656 Commits

Author SHA1 Message Date
Jake VanderPlas
9247a62b2b Add CHANGELOG entry for the jnp annotation change 2023-10-02 11:31:28 -07:00
Samuel Agyakwa
228e1925b2 [XLA: Python] Update PJRT plugin configuration mapping with bool type
PiperOrigin-RevId: 570100527
2023-10-02 10:08:34 -07:00
Jake VanderPlas
adba2f0859 Add type stubs for jax.numpy.
This allows mypy/pytype to obtain accurate types for the public jax.numpy APIs, which is helpful to downstream users of JAX, if not JAX itself.

PiperOrigin-RevId: 570058363
2023-10-02 07:20:20 -07:00
jax authors
0fe420e1ef Merge pull request #17882 from jakevdp:fix-nightly
PiperOrigin-RevId: 570049158
2023-10-02 06:37:02 -07:00
Adam Paszke
c9851ac7f3 [Mosaic] Allow vector.shape_cast to (un)fold the sublane dim, for as long as it remains a multiple of sublane tiling
The old guards were overly restrictive, and we can actually treat a much larger class of reshapes as no-ops.

PiperOrigin-RevId: 570049016
2023-10-02 06:36:45 -07:00
Adam Paszke
77d11e4dfd [Mosaic] Reimplement new relayout routines in C++
PiperOrigin-RevId: 570046852
2023-10-02 06:24:56 -07:00
Peter Hawkins
ae65a7cde5 Disable tests for int4 support on non-TPU platforms.
An upcoming XLA change will reject programs containing int4 on CPU and GPU, because the XLA support is buggy and incomplete. When the XLA supports this we can reenable these tests.

Issue https://github.com/google/jax/issues/17672

PiperOrigin-RevId: 570042917
2023-10-02 06:04:22 -07:00
Jake VanderPlas
568fa7927b Update test to skip np.bitwise_count ufunc
Fixes https://github.com/google/jax/issues/17878
2023-10-02 05:50:07 -07:00
jax authors
4471abe1c6 Merge pull request #17423 from gnecula:export_multi2
PiperOrigin-RevId: 569993648
2023-10-02 02:28:52 -07:00
jax authors
90286f2fe1 Update XLA dependency to use revision
2233d81842.

PiperOrigin-RevId: 569975863
2023-10-02 00:58:05 -07:00
George Necula
4b5ed34cbf [jax_export] Add the next part of multi-platform lowering support.
We change the lowering rule selection code to work when
`ModuleContext.lowering_parameters.platforms` contains multiple
string, and emit conditional
code to select the lowering based on the platform index argument.

These changes will not affect the normal JAX lowering paths (when
`ModuleContext.lowering_parameters.platforms` is `None`). It will
also not affect the JAX native serialization paths for single
platform lowering.

These changes should work for most primitives, with the exception
of the few ones that actually access `ModuleContext.platform` inside
the lowering rules (most primitives just register different
rules for different platforms, which is taken into account by
these changes).

Previous PR in this series: #17316.
2023-10-02 10:57:35 +03:00
jax authors
004e67e1e8 Update XLA dependency to use revision
9938bdbbf3.

PiperOrigin-RevId: 569829166
2023-10-01 01:30:44 -07:00
Enrique Piqueras
0ae2a63426 Add temporary flag for forcing arg tuplization of lowered functions.
PiperOrigin-RevId: 569814851
2023-09-30 23:14:38 -07:00
jax authors
bf46b7427f Update XLA dependency to use revision
f62cc8f27a.

PiperOrigin-RevId: 569686162
2023-09-30 01:01:55 -07:00
Hyeontaek Lim
5379d3ddfb [JAX] Fix a regression in cost_analysis API access for an alternative JAX backend
PiperOrigin-RevId: 569664445
2023-09-29 21:42:33 -07:00
jax authors
095c367c01 Merge pull request #17864 from hawkinsp:buildwheel
PiperOrigin-RevId: 569576169
2023-09-29 13:34:42 -07:00
Peter Hawkins
fa8159681d Clean up build_wheel.py and build_gpu_plugin_wheel.py.
* Use pathlib.Path object-oriented paths.
* Change copy_files() helper to copy many files in one call.
* Make copy_files() also make the output directory, if needed.
* Format file with pyink --pyink-indentation=2
2023-09-29 20:08:42 +00:00
jax authors
4d933ebc58 Merge pull request #17860 from google:test_fix
PiperOrigin-RevId: 569549919
2023-09-29 11:53:02 -07:00
Skye Wanderman-Milne
72b1eb3205 Bump NumpyLinalgTest.testEighRankDeficient tolerance
Otherwise it sometimes fails on Cloud TPU v5e.
2023-09-29 18:43:33 +00:00
Peter Hawkins
ef6fd2ebb6 Bump test tolerance for sqrtm test.
This test fails on ARM with a LAPACK built with gfortran 11.

PiperOrigin-RevId: 569540626
2023-09-29 11:15:48 -07:00
jax authors
59d86b233e Correct typo in dtype ValueError() call.
PiperOrigin-RevId: 569527985
2023-09-29 10:31:58 -07:00
Yash Katariya
a32ed7e002 Bump shard_count for shard_map_test to fix the asan failures
PiperOrigin-RevId: 569520202
2023-09-29 10:02:38 -07:00
Peter Hawkins
e6a62fcd11 [PJRT] Split the GpuId() platform constants into CudaId()/RocmId().
Similarly for the GpuName() constant.

While most of the time we treat CUDA and ROCm GPUs identically, we sometimes want to distinguish between CUDA and ROCm (e.g., for DLPack exports) and it's helpful if this is encoded in the platform ID.

PiperOrigin-RevId: 569513495
2023-09-29 09:35:16 -07:00
Adam Paszke
9f963d2f11 [Pallas:TPU] Use ExtUI to widen booleans to signed integer types.
Otherwise `true` gets converted to `-1`, which is confusing.

PiperOrigin-RevId: 569509184
2023-09-29 09:16:12 -07:00
jax authors
d314cf0954 Merge pull request #17841 from google:tpu_ci_disable_tcmalloc
PiperOrigin-RevId: 569469475
2023-09-29 05:55:48 -07:00
jax authors
88fac56fd0 Update XLA dependency to use revision
99a225c1b9.

PiperOrigin-RevId: 569418541
2023-09-29 01:14:27 -07:00
jax authors
f94bbc18ac Merge pull request #17827 from gnecula:lowering_params
PiperOrigin-RevId: 569392664
2023-09-28 23:08:28 -07:00
George Necula
552fef6fcd Introduce a LoweringParameters dataclass for easier plumbing
There are currently two parameters that are used to configure
lowering: lowering_platform (for cross-platform lowering), and
override_lowering_rules. Each of them are passed as separate arguments
through several layers of lowering internal functions. This is tedious,
and error prone. In fact, override_lowering_rules was not plumbed
in all places, and due to using default arguments in all places,
this leads to silent errors.

We foresee introducing other parameters for lowering: for multi-platform
lowering, for controlling the lowering of effects.

Here is pack all such parameters into a `mlir.LoweringParameters`
dataclass and we plumb that through.
2023-09-29 08:23:05 +03:00
Roy Frostig
3247db774e add tests for host offloading (plus operations) under a custom VJP
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 569333314
2023-09-28 17:19:21 -07:00
Skye Wanderman-Milne
ef241e506e Cloud TPU CI: don't use tcmalloc (temporary workaround for tcmalloc deadlock) 2023-09-28 16:55:11 -07:00
jax authors
e6f8477efe Merge pull request #17826 from superbobry:fix-sort-overloads
PiperOrigin-RevId: 569306390
2023-09-28 15:26:57 -07:00
Peter Hawkins
2eca5b34b3 Add a compile-time version test that verifies CUDA is version 11.8 or newer.
Issue https://github.com/google/jax/issues/17829

PiperOrigin-RevId: 569302585
2023-09-28 15:14:04 -07:00
Chansoo Lee
528b035ee5 Fix typing for kwargs.
PiperOrigin-RevId: 569300602
2023-09-28 15:03:20 -07:00
Junwhan Ahn
8bfe3b92bc Roll back f92a70a41e
Reverts bb4382f0bce074ab081e1e02871e32ba331d1d46

PiperOrigin-RevId: 569292433
2023-09-28 14:32:23 -07:00
Tomás Longeri
4b107f8f7d [Mosaic] apply_vector_layout C++ rewrite (12): func.return
PiperOrigin-RevId: 569268028
2023-09-28 13:07:59 -07:00
Emily Fertig
ac27d287a7 [Mosaic] Add sqrt lowering rule.
PiperOrigin-RevId: 569260464
2023-09-28 12:39:52 -07:00
Tomás Longeri
fc569b44f3 [Mosaic] apply_vector_layout C++ rewrite (11): vector.broadcast
PiperOrigin-RevId: 569246375
2023-09-28 11:47:30 -07:00
jax authors
c490a063c8 Merge pull request #17828 from hawkinsp:tpu
PiperOrigin-RevId: 569210363
2023-09-28 09:47:23 -07:00
Peter Hawkins
d0baa1d11b Fix incorrect backend allowlist in array_interoperability_test.
We intended to only enable this test on CPU and GPU, but we were missing a critical "not".
2023-09-28 10:30:22 -04:00
Adam Paszke
173a270179 [Mosaic] Add retiling swizzles required for int8 matmuls
Ideally we would skip the swizzle entirely, but it is not always possible at the moment.

PiperOrigin-RevId: 569149358
2023-09-28 05:30:20 -07:00
Sergei Lebedev
a8b8267f48 MAINT Reorder the overloads for lax.sort
`Array` is structurally a `Sequence[Array]`, so the first overload always
matches under pytype, which defines `collections.abc.Sequence` as a
`Protocol`.

See
b8f91a37e5/pytype/stubs/builtins/typing.pytd (L149).
2023-09-28 12:51:36 +01:00
Chansoo Lee
79d0a83069 Allow event listners to take extra keyword arguments.
PiperOrigin-RevId: 569138957
2023-09-28 04:38:43 -07:00
jax authors
2d068a1caa Update XLA dependency to use revision
0c71a63a86.

PiperOrigin-RevId: 569108876
2023-09-28 02:12:36 -07:00
Tomás Longeri
a37c292d02 [Mosaic] apply_vector_layout C++ rewrite (10): vector.extract_strided_slice
PiperOrigin-RevId: 569081032
2023-09-27 23:48:26 -07:00
Tomás Longeri
fb90d3ee31 [Mosaic] apply_vector_layout C++ rewrite (9): tpu.repeat
PiperOrigin-RevId: 569078893
2023-09-27 23:36:22 -07:00
Tomás Longeri
b1b81ecc60 [Mosaic] apply_vector_layout C++ rewrite (8): tpu.gather, tpu.iota, tpu.trace
PiperOrigin-RevId: 569069717
2023-09-27 22:52:02 -07:00
Junwhan Ahn
bb4382f0bc Destruct objects owned by WeakRefLRUCache::CacheEntry out of band using GlobalPyRefManager()
This assumes less about whether the thread that destructs `CacheEntry` has GIL or not, which is difficult to reason about due to the `xla::LRUCache`'s use of `std::shared_ptr<CacheEntry>`.

The following changes have been made in JAX to accommodate the behavior differences from direct destruction to GC:

* Since `PyLoadedExecutable`s cached in `WeakRefLRUCache` are now destructed out of band, `PyClient::LiveExecutables()` calls `GlobalPyRefManager()->CollectGarbage()` to make the returned information accurate and up to date.
* `test_jit_reference_dropping` has been updated to call `gc.collect()` before verifying the live executable counts since the destruction of executables owned by weak ref maps is now done out of band as part of `GlobalPyRefManager`'s GC.

PiperOrigin-RevId: 569062402
2023-09-27 22:15:22 -07:00
Matthew Johnson
a9dc3c1ea3 [shard_map] internal change to shard_map CI testing
PiperOrigin-RevId: 569036873
2023-09-27 20:06:24 -07:00
Peter Hawkins
951298df64 Relax cuDNN version compatibility test to ignore patch versions.
PiperOrigin-RevId: 569020492
2023-09-27 18:40:05 -07:00
jax authors
59360794c1 Merge pull request #17792 from jakevdp:mean-cast-f16
PiperOrigin-RevId: 569019549
2023-09-27 18:30:06 -07:00