22260 Commits

Author SHA1 Message Date
Jake VanderPlas
a9690adf7b jnp.where: explicitly use mode='drop' when out-of-bound indices expected 2024-07-31 14:55:56 -07:00
jax authors
faea216701 Merge pull request #22797 from jakevdp:array-index-error
PiperOrigin-RevId: 658150131
2024-07-31 14:50:24 -07:00
jax authors
057a79d134 Update XLA dependency to use revision
a3bc38ac79.

PiperOrigin-RevId: 658142918
2024-07-31 14:28:33 -07:00
Jake VanderPlas
9a88ecb244 Improve error when indexing with too many indices 2024-07-31 13:57:48 -07:00
jax authors
7d8b8578b5 Merge pull request #22477 from kaixih:support_gqa
PiperOrigin-RevId: 658130108
2024-07-31 13:50:49 -07:00
Parker Schuh
d7d9724e14 If the product of manual axes is of size 1, then skip emitting
any shard_to_full or full_to_shard ops.

PiperOrigin-RevId: 658116164
2024-07-31 13:12:06 -07:00
Jake VanderPlas
c24f20968b [array api] use jnp.astype directly 2024-07-31 11:52:46 -07:00
jax authors
6870d37822 Added test cases for more TPU Mosaic bugs.
PiperOrigin-RevId: 658067355
2024-07-31 10:58:27 -07:00
jax authors
6d34498a9f Merge pull request #22792 from mattjj:vmap-shmap-error-suppression
PiperOrigin-RevId: 658066170
2024-07-31 10:54:49 -07:00
Yash Katariya
e3fc05ad5b Rename mock_num_processes to mock_num_gpu_processes since this flag is only for GPUs. The naming change was a regression introduced in https://github.com/google/jax/pull/22619
PiperOrigin-RevId: 658061107
2024-07-31 10:41:03 -07:00
Matthew Johnson
abfb1ce72d add temporary flag to suppress an error message, to unblock a user 2024-07-31 17:23:47 +00:00
jax authors
d696813b1f Merge pull request #22746 from gnecula:pallas_consts
PiperOrigin-RevId: 658050734
2024-07-31 10:13:34 -07:00
Peter Hawkins
858dc54590 Fix or disable some tests that fail when using a Eigen BLAS with AVX vectorization.
PiperOrigin-RevId: 658047868
2024-07-31 10:06:45 -07:00
jax authors
c7eb023746 Merge pull request #22735 from dfm:custom-vjp-remat-opt
PiperOrigin-RevId: 658043956
2024-07-31 09:56:29 -07:00
jax authors
c0b986caaf Merge pull request #22744 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 658042543
2024-07-31 09:52:10 -07:00
Bart Chrzaszcz
4cf1fbe4cc Hide the SDY dialect right before MLIR->HLO conversion in the XLA pipeline.
Since Shardy is inside the middle of the XLA pipeline, after converting down to HLO, we need to run the Shardy export pipeline to preserve the SDY ops and sharding attributes for when we come back from HLO to MLIR when Shardy propagation is run.

PiperOrigin-RevId: 658040672
2024-07-31 09:45:43 -07:00
Justin Fu
3c7c9ffbbc [Pallas] Correctly handle asymmetrical remote DMA dst_ref indexing in interpret mode.
PiperOrigin-RevId: 658029297
2024-07-31 09:10:53 -07:00
Justin Fu
acacbe297f [Pallas] Move distributed TPU tests to their own file.
PiperOrigin-RevId: 658014149
2024-07-31 08:19:19 -07:00
George Necula
65450d165e Remove forward compatibility mode for old PRGN custom call on GPU
The backend support for the new custom call was added on June 28th.
Also add backwards compatibility test for the new custom call.

PiperOrigin-RevId: 658011228
2024-07-31 08:10:17 -07:00
Dan Foreman-Mackey
30d5a78b1c Add optional automatic remat optimization to custom_vjp.
As reported in https://github.com/google/jax/issues/21303, using `remat`
with `custom_vjp` can produce inefficient results. The high level
summary is that computing the grad of such a function results in the
`fwd` function of the `custom_vjp` being evaluated twice, even though
the first time the residuals are not actually used. In many cases this
isn't a problem because DCE will clean up the unnecessary computations.
But, when the fwd function requires an opaque call (e.g. pallas_call or
ffi_call), this no longer saves the day.

In this PR, I have added a parameter to `custom_vjp` called
`optimize_remat` (open for discussion!), which can be used to opt-in to
automatic optimization of this operation. Setting this flag to true
results in the `fwd` function being wrapped in a new custom primitive
which will DCE into a call to the primal function whenever the residuals
are unused.

This can be used to fix https://github.com/google/jax/issues/21303, and
I think it would make sense to eventually make this behavior the
default, but this implementation comes with a few caveats:

1. This feature is currently implemented in "initial style", which means
   that the `fwd` function is traced to a jaxpr when it is initially
   called. This means that when `optimize_remat=True`, the `custom_vjp`
   function doesn't support data dependent conditionals within `fwd`.
   This isn't a fundamental limitation of the method, but this
   implementation is much simpler so it seemed like a good place to
   start, and much of the complexity of the "final style" version of
   this logic should be simplified by work that @dougalm is doing.
   Furthermore, for the immediate use case of opaque calls, initial
   style is not a serious limitation.
2. When `optimize_remat=True`, symbolic zeros are not supported. Again
   this isn't a required restriction, but I chose to start without this
   added complexity and we can add support for symbolic zeros as needed
   in the future.
3. More subtly, while this new primitive supports `vmap`, it doesn't
   currently implement rules for composing with the AD system. This
   means that a `custom_vjp` constructed with `optimize_remat=True`
   won't currently work with some approaches to higher-order AD. I
   expect I know how to fix that and will either include that here or in
   a follow-up.
2024-07-31 10:48:29 -04:00
Dan Foreman-Mackey
618754d829 Move some common helper functions from lapack_kernels to ffi_helpers.
There were two helper functions for implementing FFI calls that were included directly alongside jaxlib's CPU kernels that will be useful for the GPU kernels as well. This moves those functions into ffi_helpers so that they are accessible from there too.

PiperOrigin-RevId: 658002501
2024-07-31 07:38:33 -07:00
jax authors
1e5bb9e652 Merge pull request #22781 from pearu:pearu/arcsinh-mpmath-787
PiperOrigin-RevId: 657995523
2024-07-31 07:12:59 -07:00
jax authors
a207fe9b77 Export KeyPath and related types to jax.tree_util
These types lie on the APIs in `jax.tree_util`, so it makes sense to export them.

PiperOrigin-RevId: 657987755
2024-07-31 06:41:33 -07:00
Adam Paszke
9dba6eb16a [Mosaic TPU] Add support for 1D windows
PiperOrigin-RevId: 657976726
2024-07-31 05:58:19 -07:00
Adam Paszke
4c13594bdd [Mosaic GPU] Add a missing return in wait_parity
Technically it's not an error to wait twice... but why would you?

PiperOrigin-RevId: 657968887
2024-07-31 05:25:17 -07:00
Adam Paszke
e0415c1865 [Mosaic TPU] Don't fold the accumulator into matmul if it has multiple uses
PiperOrigin-RevId: 657967724
2024-07-31 05:19:52 -07:00
Pearu Peterson
54b0cb86f3 Fix arcsinh accuracy test 2024-07-31 13:23:12 +03:00
jax authors
5c9bb612a7 mesh_utils: allow meshes that do not include device at (0, 0, 0).
This is required to allow the use of subslices: e.g., the two halves
of a TPU slice.  One of them will not include the device at
coordinates (0, 0, 0).

E.g., assume we have a TPU v4 1x2x1 slice.

BEFORE THIS CL, if we call _get_physical_tpu_mesh() (an auxiliary for
the public create_device_mesh()) with

jax_devices=[device(0,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]

we get the expected result

[[[device(0,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]

However, if we call it with

jax_devices=[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]

we get the wrong mesh

[[[None]
  [device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]

That's because the code before this CL assumed the the incoming
jax_devices are arranged in a cuboid that starts at (0, 0, 0).  When
working with subslices (e.g., half of a TPU slice) that is not always
the case.

AFTER THIS CL, the second case will return
[[[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]

For each dimension from the TPU coordinates, this CL computes the min
/ max; we expect the provided devices to fill the [min, max] interval
(in that dimension).  By requesting this for each dimension, we
request that the set of provided devices constitute a cuboid, but,
unlike before this CL, that cuboid does not need to include (0, 0, 0):
it can be "translated", which allows e.g., both half-slices of a big
slice.

PiperOrigin-RevId: 657902201
2024-07-31 01:11:31 -07:00
George Necula
987bf33e85 [pallas] Disallow capturing of consts by kernel functions.
Previously this was allowed, but until recently (#22550) it was
not working correctly in many cases. Now we disallow const
capturing because it can lead to surprises. Instead, the
kernel function must receive all the arrays it needs as explicit
inputs, with proper block specs.
2024-07-31 09:06:29 +02:00
rajasekharporeddy
3a0e4376cd Fix betabinom.logpmf and binom.logpmf for JAX to emulate SciPy's behavior when k=n=0 2024-07-31 07:58:43 +05:30
Gleb Pobudzey
35ba6f78bb Add /dev/nvidiactl to the list of NVIDIA GPU devices. This is to cover the use case where a subset of GPUs are exposed to Docker/Kubernetes; the container might not necessarily see /dev/nvidia0.
Testing:

Tested on a A100 16 GPU VM with only 8 GPU exposed to the container.

PiperOrigin-RevId: 657801141
2024-07-30 18:25:10 -07:00
jax authors
bd14d6ab64 Merge pull request #22768 from jakevdp:array-api-reshape
PiperOrigin-RevId: 657782900
2024-07-30 17:15:02 -07:00
jax authors
7583cbb438 Merge pull request #22766 from jakevdp:array-api-cleanup
PiperOrigin-RevId: 657775853
2024-07-30 16:50:15 -07:00
jax authors
9c669a74d6 Update XLA dependency to use revision
78418c6a4a.

PiperOrigin-RevId: 657742817
2024-07-30 15:04:30 -07:00
Jake VanderPlas
5198db9fdb jnp.repeat: add copy argument for Array API 2024-07-30 14:07:08 -07:00
Jake VanderPlas
1259322f86 [array api] remove redundant definitions for clip() & hypot() 2024-07-30 13:51:04 -07:00
Jake VanderPlas
8bcd288621 Raise ValueError for complex inputs to jnp.clip and jnp.hypot.
Such inputs were deprecated in JAX v0.4.27, and have been raising a DeprecationWarning for the last several releases.

PiperOrigin-RevId: 657717875
2024-07-30 13:49:37 -07:00
jax authors
b996612865 Merge pull request #22085 from vfdev-5:add-device-kwarg-fftfreq
PiperOrigin-RevId: 657707966
2024-07-30 13:20:26 -07:00
jax authors
b1066ee413 Merge pull request #22764 from jakevdp:array-api-methods
PiperOrigin-RevId: 657697622
2024-07-30 12:50:48 -07:00
Jake VanderPlas
c2f2b0ed28 [array API] move api metadata into jax.numpy namespace 2024-07-30 12:15:24 -07:00
Kanglan Tang
d7c2b49c5a Skip test_concrete_layout_in_shardings on GPU backend.
PiperOrigin-RevId: 657661214
2024-07-30 11:10:29 -07:00
Jake VanderPlas
ff8e8ad2fe revert #22734
Reverts 5ce66dc1aae67a88a8ed72584bdc3f5a7f712507

PiperOrigin-RevId: 657638187
2024-07-30 10:17:34 -07:00
jax authors
256956ad58 Merge pull request #22704 from gnecula:pallas_better_errors
PiperOrigin-RevId: 657571604
2024-07-30 06:39:44 -07:00
Peter Hawkins
c1cd7f9e2d Drop support for mhlo in JAX's public API.
PiperOrigin-RevId: 657551590
2024-07-30 05:29:52 -07:00
George Necula
6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names
  * Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
vfdev-5
bb1fb3ba45 Follow-up to #22736
On adding  device kwarg to jnp.fft.fftfreq and jnp.fft.rfftfreq
2024-07-30 05:39:19 +02:00
jax authors
cc212457d2 Merge pull request #22481 from zhenying-liu:offloading
PiperOrigin-RevId: 657413977
2024-07-29 19:43:35 -07:00
Yash Katariya
30037547d7 Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Yash Katariya
2106a25977 Finish jax and jaxlib v0.4.31 release
PiperOrigin-RevId: 657388782
2024-07-29 17:57:37 -07:00
Bixia Zheng
c81f5cd2fc [xla] Replace debug option xla_use_shardy with execution option
use_shardy_partitioner.

Replace the use of xla_use_shardy with use_shardy_partitioner and remove
xla_use_shardy.

PiperOrigin-RevId: 657359119
2024-07-29 16:11:36 -07:00