23116 Commits

Author SHA1 Message Date
Jake VanderPlas
aa551e66c5 Test that jax.numpy docstrings include examples 2024-09-21 07:39:17 -07:00
Ayaka
d63afd8438 [Pallas GPU] Enable Pallas OpsExtraTest in 64-bit mode
This is a follow-up of https://github.com/jax-ml/jax/pull/23747, which enables Pallas `OpsTest` in 64-bit mode.

In order to enable Pallas `OpsExtraTest` in 64-bit mode, some of the code in the tests need to be modified. There are three kinds of modifications:

1. Most of the modifications are just changing `jnp.int32` to `intx` and `jnp.float32` to `floatx`, which uses the same approach as the previous PR https://github.com/jax-ml/jax/pull/23747. `intx` and `floatx` are conventions used in Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode.
2. For the test `test_array_reduce`, the original code uses a simple approach to determine `out_dtype` from `dtype`, which no longer works in 64-bit mode. Therefore, I modified the code to deduct `out_dtype` by executing the operation on a single element first.
3. For the test `test_masked_load_store`, the `idx` variable is expected to be an `int32` array, which is calculated based on `pl.program_id()` and `block_size`. In 64-bit mode, the computation will give out an `int64` array instead. Since `pl.program_id()` always returns an `int32` result, I modified the computation to produce `int32` result. I also modified the `pl.program_id()` docstring to document the behaviour that `pl.program_id()` always returns an `int32` result.

PiperOrigin-RevId: 677007613
2024-09-20 16:18:31 -07:00
Jevin Jiang
6b93b35842 [Mosaic:TPU] Efficient relayout with internal scratch
We should support all different retilings (x*packing1, 128) <-> (y*packing2, 128) with any dtype in this cl at this moment. The efficient relayout with scratch brings significant improvements on current retiling in <= TPUv4 and retiling with (packing, 128) in TPUv5. All missing retiling supports are added in this cl, including increase sublane retiling and packed type retiling.

PiperOrigin-RevId: 676982957
2024-09-20 15:00:58 -07:00
jax authors
a533635898 Update XLA dependency to use revision
44d14566fc.

PiperOrigin-RevId: 676967851
2024-09-20 14:17:51 -07:00
jax authors
9465d427c0 Merge pull request #22302 from yhtang:add-k8s-initialize
PiperOrigin-RevId: 676962862
2024-09-20 14:03:50 -07:00
jax authors
ca97af9d43 Change the default implementation of GeLU to a numerically stable formulation.
The old formulation explicitly computed (1 + erf(x/sqrt(2))), which can be extremely inaccurate for negative x due to cancellation.

PiperOrigin-RevId: 676944344
2024-09-20 13:06:31 -07:00
jax authors
1b3488001b Merge pull request #23734 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 676941019
2024-09-20 12:55:41 -07:00
rajasekharporeddy
6a5553d6be Improve docs for jax.numpy: remainder, mod and fmod 2024-09-21 00:09:42 +05:30
Parker Schuh
1acf9567aa Add get_replication to shard_map.py for verifying if an array is replicated.
PiperOrigin-RevId: 676910872
2024-09-20 11:25:15 -07:00
jax authors
82b0e0e0fb Merge pull request #23788 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 676891040
2024-09-20 10:30:10 -07:00
Yu-Hang Tang
c88c3aecae add k8s cluster environment 2024-09-20 17:26:53 +00:00
jax authors
e2cdb796f9 Merge pull request #23802 from hawkinsp:dumps
PiperOrigin-RevId: 676889415
2024-09-20 10:25:25 -07:00
jax authors
419a0c498a Merge pull request #23790 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 676889232
2024-09-20 10:24:10 -07:00
jax authors
629be0b701 Tighten test tolerances after the underlying issue causing nondeterministic results for _nrm2 in Eigen BLAS was fixed in https://gitlab.com/libeigen/eigen/-/merge_requests/1667 -> cl/663346025
PiperOrigin-RevId: 676881791
2024-09-20 10:03:46 -07:00
rajasekharporeddy
0c87a23a26 Improve docs for jax.numpy: deg2rad, rad2deg, degrees, radians 2024-09-20 22:22:17 +05:30
rajasekharporeddy
81e50118cf Better doc for jax.numpy.i0 2024-09-20 22:19:31 +05:30
jax authors
886aa944fa Merge pull request #23707 from jakevdp:stop-gradient-doc
PiperOrigin-RevId: 676876785
2024-09-20 09:48:08 -07:00
Peter Hawkins
339db2b433 Format MLIR dump names with leading zeros.
This means the dumps sort in order in a directory listing.
2024-09-20 12:35:23 -04:00
jax authors
0d96f39637 Merge pull request #23383 from jax-ml:dependabot/github_actions/actions/upload-artifact-4.4.0
PiperOrigin-RevId: 676871860
2024-09-20 09:32:12 -07:00
Jake VanderPlas
71450cad56 Add docstrings for jnp.blackman, jnp.bartlett, jnp.hamming, jnp.hanning, jnp.kaiser
Part of https://github.com/jax-ml/jax/issues/21461

PiperOrigin-RevId: 676866721
2024-09-20 09:15:30 -07:00
Adam Paszke
81b8b4b7b4 [Mosaic GPU] Clean up the module structure
Previously the code was awkwardly split between the `jax.experimental.mosaic.gpu`
and `jax.experimental.mosaic.gpu.dsl` namespaces. I've now merged both so that
all user-visible APIs are accessible from `jax.experimental.mosaic.gpu`.

PiperOrigin-RevId: 676857257
2024-09-20 08:42:13 -07:00
Adam Paszke
99195ead83 [Mosaic TPU] Try reducing sublane tiling to support more vector.shape_casts
In particular, 32-bit values should now support all reshapes that do not modify the
last dimension.

PiperOrigin-RevId: 676855401
2024-09-20 08:36:22 -07:00
Dan Foreman-Mackey
bc80ecbbe4 Remove forward compatibility checks from cholesky_update lowering.
The forward compatibility window has ended and it should be safe to remove these checks.

PiperOrigin-RevId: 676853740
2024-09-20 08:32:25 -07:00
dependabot[bot]
c4c30e1cfd
Bump actions/upload-artifact from 4.3.6 to 4.4.0
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.3.6 to 4.4.0.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](834a144ee9...50769540e7)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-09-20 14:53:30 +00:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Dan Foreman-Mackey
afaa3bf43c Port GPU kernels for SVD to the FFI.
Unlike the other GPU linear algebra kernels that I've ported so far, this one isn't straightforward to implement as a single kernel, and while it does support lowering without access to a GPU (no more descriptor!), it only supports dynamics shapes in the batch dimensions. There are two main technical challenges:

1. The main `gesvd` kernels in cuSolver/hipSolver only support matrices with shape `(m, n)` with `m >= n`. This means that we need to transpose the inputs and outputs as part of the lowering rule when `m < n`. (Note: we actually just use C layouts instead of Fortran layouts to implement this case.) While this could be handled in the kernel, this seemed like a lot of work for somewhat limited benefit, and it would probably have performance implications.

2. The `gesvd` and `gesvdj` kernels return `V^H` and `V` respectively, and the batched version of `gesvdj` doesn't support `full_matrices=False`. This means that we need logic in the lowering rule to handle transposition and slicing. This makes it hard to have the algorithm selection be a parameter to the kernel.

Another note: cuSolver has a 64-bit implementation of the SVD, and we always use that implementation on the CUDA backend. The 32-bit interface is included for ROCM support, and I have tested it manually. This was a feature request from https://github.com/jax-ml/jax/issues/23413.

PiperOrigin-RevId: 676839182
2024-09-20 07:34:50 -07:00
Michael Hudgins
7f3a90c63b Change references in setup.py and utilities to reference the JAX repo move to the JAX-ML org
PiperOrigin-RevId: 676838502
2024-09-20 07:32:15 -07:00
Sharad Vikram
1db47fd85d [Pallas] Minor cleanup of memory spaces. Also add ANY as a general memory space
PiperOrigin-RevId: 676650904
2024-09-19 19:08:18 -07:00
jax authors
bebcc39549 Merge pull request #23779 from mattjj:cusotm-vjp-dont-drop-tangents-when-different-dtype-from-primal
PiperOrigin-RevId: 676620126
2024-09-19 17:13:05 -07:00
Matthew Johnson
7571b9e7f8 custom_vjp: don't drop tangents just because they have a different dtype than the primal
instead, drop them when primal_aval.to_tangent_aval().dtype == float0

TODO: don't do that either. we shouldn't drop the user's output on the floor;
we should require that their rule produce a value of the correct float0 dtype,
or else produce a special symbol that means "zero of whatever type I need" (and
that symbol should probably be a None). but i'm not doing that TODO right now...
2024-09-19 23:31:40 +00:00
Yash Katariya
e209abfb2c Improve the coverage of shard map tests for < 8 devices. Due to the skip in SetupModule before this change, we lost a lot of coverage on latest hardware.
PiperOrigin-RevId: 676571965
2024-09-19 14:49:08 -07:00
Jevin Jiang
47b177bd03 [Mosaic TPU][NFC] Remove FailureOr in getNativeVregOrVmaskTypeImpl
PiperOrigin-RevId: 676566796
2024-09-19 14:35:41 -07:00
jax authors
815dc3ba63 Update XLA dependency to use revision
a0cb798737.

PiperOrigin-RevId: 676545740
2024-09-19 13:40:20 -07:00
Dougal Maclaurin
63e7b7d364 Remove some untested dynamic shapes paths (prep work for stackless).
PiperOrigin-RevId: 676529297
2024-09-19 12:59:48 -07:00
jax authors
5f044a67d8 Merge pull request #23674 from justinjfu:pallas_prefetch_docs
PiperOrigin-RevId: 676525366
2024-09-19 12:49:28 -07:00
jax authors
6c29b0be53 Merge pull request #23769 from hawkinsp:thunks
PiperOrigin-RevId: 676523715
2024-09-19 12:45:19 -07:00
Peter Hawkins
6a3736a1d7 Add a note to the changelog about the new CPU thunks backend, enabled in 0.4.32. 2024-09-19 15:38:52 -04:00
Justin Fu
4bce4f6452 [Pallas] Add block-sparse kernel tutorial 2024-09-19 12:23:03 -07:00
jax authors
cb866aa640 Merge pull request #23761 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 676502378
2024-09-19 11:47:25 -07:00
Yash Katariya
c9bbf71ec6 Cleanup ParsedPartitionSpec and remove CanonicalizedParsedPartitionSpec. Also mark user_spec as private.
PiperOrigin-RevId: 676498946
2024-09-19 11:38:48 -07:00
jax authors
73bbd80b80 Merge pull request #22310 from ayaka14732:ayx/lowering/erf_inv_64
PiperOrigin-RevId: 676491259
2024-09-19 11:20:56 -07:00
Loren Maggiore
f75c5c6b2d [jax] config option to disable using a mesh as a context manager.
PiperOrigin-RevId: 676475039
2024-09-19 10:42:41 -07:00
Peter Hawkins
df781e455a [JAX] Switch host_callback to use MLIR lowering instead of the older direct HLO translation rules.
Change in preparation for removing XlaBuilder from Python bindings.

PiperOrigin-RevId: 676465019
2024-09-19 10:17:17 -07:00
rajasekharporeddy
ef2f2fff06 Improved doc for jnp.vander 2024-09-19 22:42:56 +05:30
Vadym Matsishevskyi
cc927dd322 Ignore RuntimeWarning "invalid value encountered in cast" for LaxBackedNumpyTests.testUniqueEqualNan
This is to fix Mac arm64 pytests on CI. The tests started failing after integrating ml-dtypes-0.5.0. Ignoring warnings is probably Ok, as it is inspired by a similar PR in ml-dtypes repo itself: https://github.com/jax-ml/ml_dtypes/pull/186

PiperOrigin-RevId: 676458202
2024-09-19 10:03:06 -07:00
Dougal Maclaurin
3b89a2e573 Add a utility function to create a tangent zero value from a primal value.
PiperOrigin-RevId: 676449863
2024-09-19 09:42:12 -07:00
Dan Foreman-Mackey
56d0c695c9 Condition tan lowering on jaxlib version rather than forward compatibility mode.
PiperOrigin-RevId: 676436269
2024-09-19 09:03:51 -07:00
Ayaka
de23fdb5ad [Pallas TPU] Add lowering for 64 bit 2024-09-19 16:42:45 +01:00
Sergei Lebedev
22a7c73d27 Added support for lax.fori_loop in the Pallas Mosaic GPU lowering
This, coupled with `plgpu.async_copy` and barriers, should be enough to sketch
a simple pipelined loop in the kernel.

PiperOrigin-RevId: 676374408
2024-09-19 05:30:45 -07:00
Ayaka
3f23866f75 Enable Pallas ops_test on GPU in 64-bit mode.
Previously, the 64-bit tests are skipped in `PallasBaseTest`, which disables both `OpsTest` and `OpsExtraTest`. This PR enables the 64-bit tests for `OpsTest`, and only disables it for `OpsExtraTest`.

PiperOrigin-RevId: 676373904
2024-09-19 05:29:38 -07:00