8482 Commits

Author SHA1 Message Date
Peter Hawkins
291e52a713 Fix some warnings causing CI failures on ARM.
PiperOrigin-RevId: 678454816
2024-09-24 17:25:26 -07:00
Peter Hawkins
85a466d730 Lower the shard count for sparse_bcoo_bcsr_test on TPU as well.
There are flaky timeouts in CI, and we've already lowered the shard count on multiple other platforms.

PiperOrigin-RevId: 678367575
2024-09-24 13:10:32 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
Parker Schuh
5e3f7618fc Support pmin and pmax in check_rep.
PiperOrigin-RevId: 678336530
2024-09-24 11:46:30 -07:00
Adam Paszke
9114b084fc [Pallas] Update export compatibility tests
The old test was generated before our IR was really stable, which has started
to cause problems when trying to test with Trillium.

PiperOrigin-RevId: 678277755
2024-09-24 09:17:56 -07:00
jax authors
d9a10e13a6 Merge pull request #23873 from hawkinsp:cumsum
PiperOrigin-RevId: 678272211
2024-09-24 09:01:56 -07:00
jax authors
26ba306383 Merge pull request #23860 from jakevdp:better-error
PiperOrigin-RevId: 678271248
2024-09-24 08:58:56 -07:00
Peter Hawkins
562e9e8dff Fix an incorrect output for jnp.cumsum.
If dtype=bool but a non-bool input is passed, we should test for
non-equality with zero rather than performing a cast to integer.
2024-09-24 14:46:44 +00:00
jax authors
6860617ebf Merge pull request #23744 from jakevdp:unary-ufunc
PiperOrigin-RevId: 678235132
2024-09-24 07:11:43 -07:00
Adam Paszke
ae86ef16c7 [Mosaic GPU] Add support for input_output_aliases
PiperOrigin-RevId: 678217775
2024-09-24 06:13:28 -07:00
Jake VanderPlas
6229511f6a Make jnp.negative a ufunc & add unary ufunc tests 2024-09-24 05:23:27 -07:00
Jake VanderPlas
a44e129ae7 Add more informative error when static argument is passed to non-static JIT parameter 2024-09-24 05:22:18 -07:00
Sergei Lebedev
8196c8bf36 Added support for % and select to mgpu.FragmentedArray
PiperOrigin-RevId: 678200940
2024-09-24 05:19:25 -07:00
Peter Hawkins
a0e4448393 Remove warning filters from pyproject.toml, add local warning
suppressions.

We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
2024-09-24 01:38:24 +00:00
jax authors
1b3d8dc451 Merge pull request #23807 from kaixih:fix_cudnn_sdpa_bwd_batcher
PiperOrigin-RevId: 677937290
2024-09-23 14:20:30 -07:00
jax authors
dc1ace5992 Re-enable tsan tests after fix.
PiperOrigin-RevId: 677895934
2024-09-23 12:26:30 -07:00
Chris Jones
712e638ca4 [pallas] Add support for unblocked mode (without padding) in Triton lowering.
PiperOrigin-RevId: 677870258
2024-09-23 11:21:54 -07:00
Ayaka
93203c7574 [Pallas] Simplify sign and erf_inv tests
Removed the method to locally enabling x64 using:

```python
with contextlib.ExitStack() as stack:
  if jnp.dtype(dtype).itemsize == 8:
    stack.enter_context(config.enable_x64(True))
```

This is because we can determine whether a test is running in x64 environment by checking the value of `jax.config.x64_enabled`. There is no need to locally enabling x64.

PiperOrigin-RevId: 677865574
2024-09-23 11:11:09 -07:00
Christos Perivolaropoulos
3e19a28b09 [pallas:mosaic_gpu] Basic implementation of wgmma.
PiperOrigin-RevId: 677864187
2024-09-23 11:06:17 -07:00
kaixih
d29a757e30 fix bwd batcher for unsupported dbias 2024-09-23 17:43:25 +00:00
jax authors
c05706b7a9 Merge pull request #23816 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 677807429
2024-09-23 08:37:15 -07:00
jax authors
6c52ddc97f [Checkify] Add checks for shard_map.
PiperOrigin-RevId: 677798938
2024-09-23 08:11:22 -07:00
Sergei Lebedev
1256e18fd4 Added comparison operators to mgpu.FragmentedArray
PiperOrigin-RevId: 677788023
2024-09-23 07:37:53 -07:00
rajasekharporeddy
6a72c52292 Improve docs for jax.numpy: conjugate, conj, imag and real 2024-09-23 19:40:09 +05:30
Sergei Lebedev
f311e81c02 Added is_signed to mgpu.FragmentedArray
The registers within a fragmented array always use signless types, and instead
the signedness is tracked on the fragmented arrays itself (i.e. in Python).

PiperOrigin-RevId: 677776009
2024-09-23 06:59:41 -07:00
jax authors
ba29d5a022 Merge pull request #23821 from jakevdp:jnp-doc-examples
PiperOrigin-RevId: 677770780
2024-09-23 06:41:28 -07:00
Vadym Matsishevskyi
2199685437 Ignore scipy.stats._axis_nan_policy.SmallSampleWarning for LaxBackedScipyStatsTests.testMode
It is to fix our CI, the warning itself started occurring on scipy 1.14 due to this change https://github.com/scipy/scipy/pull/20694, which introduced SmallSampleWarning and started emitting it if the input is an empty array (the `a` variable in the randomized parametrized test LaxBackedScipyStatsTests.testMode sometimes happens to be an empty array).

Note, the actual ignored warning is RungimeWarning (the superclass of SmallSampleWarning) to make it backward compatible (scipy.stats._axis_nan_policy.SmallSampleWarning does not exist in scipy prior 1.14, not to mention it being under private declared in a private (_axis_nan_policy) namespace.

PiperOrigin-RevId: 677629866
2024-09-22 22:26:33 -07:00
Ayaka
b6fe793909 [Pallas] Skip atomic_cas and atomic_counter tests on GPU in 64-bit mode
These tests are failing on GPU in 64-bit mode.

This fixes test failures introduced by https://github.com/jax-ml/jax/pull/23798

PiperOrigin-RevId: 677583606
2024-09-22 18:55:39 -07:00
Christos Perivolaropoulos
48c29f62e1 [pallas:mosaic_gpu] Fragmented array debug printing.
PiperOrigin-RevId: 677537364
2024-09-22 14:30:53 -07:00
jax authors
bceceabae0 Merge pull request #23812 from mattjj:custom-primal-tangent-dtype-helper
PiperOrigin-RevId: 677269012
2024-09-21 13:50:55 -07:00
Matthew Johnson
43cc70b7a1 add jax.experimental.primal_tangent_dtype helper
useful for constructing new dtypes which have a distinct tangent type (e.g. for
quantization)
2024-09-21 20:35:20 +00:00
Yash Katariya
a2b39192d2 Make make_array_from_process_local_data go via device_put if there is only 1 process.
PiperOrigin-RevId: 677232996
2024-09-21 10:23:20 -07:00
Jake VanderPlas
aa551e66c5 Test that jax.numpy docstrings include examples 2024-09-21 07:39:17 -07:00
Ayaka
d63afd8438 [Pallas GPU] Enable Pallas OpsExtraTest in 64-bit mode
This is a follow-up of https://github.com/jax-ml/jax/pull/23747, which enables Pallas `OpsTest` in 64-bit mode.

In order to enable Pallas `OpsExtraTest` in 64-bit mode, some of the code in the tests need to be modified. There are three kinds of modifications:

1. Most of the modifications are just changing `jnp.int32` to `intx` and `jnp.float32` to `floatx`, which uses the same approach as the previous PR https://github.com/jax-ml/jax/pull/23747. `intx` and `floatx` are conventions used in Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode.
2. For the test `test_array_reduce`, the original code uses a simple approach to determine `out_dtype` from `dtype`, which no longer works in 64-bit mode. Therefore, I modified the code to deduct `out_dtype` by executing the operation on a single element first.
3. For the test `test_masked_load_store`, the `idx` variable is expected to be an `int32` array, which is calculated based on `pl.program_id()` and `block_size`. In 64-bit mode, the computation will give out an `int64` array instead. Since `pl.program_id()` always returns an `int32` result, I modified the computation to produce `int32` result. I also modified the `pl.program_id()` docstring to document the behaviour that `pl.program_id()` always returns an `int32` result.

PiperOrigin-RevId: 677007613
2024-09-20 16:18:31 -07:00
Jevin Jiang
6b93b35842 [Mosaic:TPU] Efficient relayout with internal scratch
We should support all different retilings (x*packing1, 128) <-> (y*packing2, 128) with any dtype in this cl at this moment. The efficient relayout with scratch brings significant improvements on current retiling in <= TPUv4 and retiling with (packing, 128) in TPUv5. All missing retiling supports are added in this cl, including increase sublane retiling and packed type retiling.

PiperOrigin-RevId: 676982957
2024-09-20 15:00:58 -07:00
jax authors
ca97af9d43 Change the default implementation of GeLU to a numerically stable formulation.
The old formulation explicitly computed (1 + erf(x/sqrt(2))), which can be extremely inaccurate for negative x due to cancellation.

PiperOrigin-RevId: 676944344
2024-09-20 13:06:31 -07:00
jax authors
1b3488001b Merge pull request #23734 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 676941019
2024-09-20 12:55:41 -07:00
rajasekharporeddy
6a5553d6be Improve docs for jax.numpy: remainder, mod and fmod 2024-09-21 00:09:42 +05:30
Parker Schuh
1acf9567aa Add get_replication to shard_map.py for verifying if an array is replicated.
PiperOrigin-RevId: 676910872
2024-09-20 11:25:15 -07:00
jax authors
82b0e0e0fb Merge pull request #23788 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 676891040
2024-09-20 10:30:10 -07:00
jax authors
629be0b701 Tighten test tolerances after the underlying issue causing nondeterministic results for _nrm2 in Eigen BLAS was fixed in https://gitlab.com/libeigen/eigen/-/merge_requests/1667 -> cl/663346025
PiperOrigin-RevId: 676881791
2024-09-20 10:03:46 -07:00
rajasekharporeddy
0c87a23a26 Improve docs for jax.numpy: deg2rad, rad2deg, degrees, radians 2024-09-20 22:22:17 +05:30
Adam Paszke
81b8b4b7b4 [Mosaic GPU] Clean up the module structure
Previously the code was awkwardly split between the `jax.experimental.mosaic.gpu`
and `jax.experimental.mosaic.gpu.dsl` namespaces. I've now merged both so that
all user-visible APIs are accessible from `jax.experimental.mosaic.gpu`.

PiperOrigin-RevId: 676857257
2024-09-20 08:42:13 -07:00
Adam Paszke
99195ead83 [Mosaic TPU] Try reducing sublane tiling to support more vector.shape_casts
In particular, 32-bit values should now support all reshapes that do not modify the
last dimension.

PiperOrigin-RevId: 676855401
2024-09-20 08:36:22 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Sharad Vikram
1db47fd85d [Pallas] Minor cleanup of memory spaces. Also add ANY as a general memory space
PiperOrigin-RevId: 676650904
2024-09-19 19:08:18 -07:00
Yash Katariya
e209abfb2c Improve the coverage of shard map tests for < 8 devices. Due to the skip in SetupModule before this change, we lost a lot of coverage on latest hardware.
PiperOrigin-RevId: 676571965
2024-09-19 14:49:08 -07:00
Yash Katariya
c9bbf71ec6 Cleanup ParsedPartitionSpec and remove CanonicalizedParsedPartitionSpec. Also mark user_spec as private.
PiperOrigin-RevId: 676498946
2024-09-19 11:38:48 -07:00
jax authors
73bbd80b80 Merge pull request #22310 from ayaka14732:ayx/lowering/erf_inv_64
PiperOrigin-RevId: 676491259
2024-09-19 11:20:56 -07:00
Vadym Matsishevskyi
cc927dd322 Ignore RuntimeWarning "invalid value encountered in cast" for LaxBackedNumpyTests.testUniqueEqualNan
This is to fix Mac arm64 pytests on CI. The tests started failing after integrating ml-dtypes-0.5.0. Ignoring warnings is probably Ok, as it is inspired by a similar PR in ml-dtypes repo itself: https://github.com/jax-ml/ml_dtypes/pull/186

PiperOrigin-RevId: 676458202
2024-09-19 10:03:06 -07:00