8695 Commits

Author SHA1 Message Date
Peter Hawkins
a43c7f2ace Enable more H100 tests in CI.
Rename "gpu" config CI tag to "gpu_v100".

PiperOrigin-RevId: 678695003
2024-09-25 07:37:48 -07:00
Christos Perivolaropoulos
390b0ba4a6 [pallas::mosaic_gpu] Support for tiled transpose transforms.
For the time being this feature only supports 2D on the GMEM side and 4D after
tiling on the SMEM side.

PiperOrigin-RevId: 678683983
2024-09-25 07:00:09 -07:00
Sergei Lebedev
cdea3d4050 lax.fori_loop now allows scalars in its carry when lowering to Mosaic GPU
PiperOrigin-RevId: 678677508
2024-09-25 06:35:23 -07:00
Dan Foreman-Mackey
bc1e1a0220 Add support for setting a dot product "algorithm" for lax.dot_general.
The StableHLO spec has a new "algorithm" parameter that allows specifying the algorithm that is used to execute a matrix multiplication, and it can tune the trade-off between performance and computational cost. Historically, in JAX, the precision and preferred_element_type parameters have been used to expose some level of control, but their behavior is platform dependent and not sufficiently flexible for performance use cases. This change adds a new "algorithm" parameter to dot_general to add support for the new explicit API.

This parameter can be a member of the `SupportedDotAlgorithm` `Enum` to use an algorithm that is known to be supported on at least some hardware. Otherwise, it can be specified using the `DotAlgorithm` data structure which exposes the full generality of the StableHLO spec.

Transposition is supported using the `transpose_algorithm` argument.

PiperOrigin-RevId: 678672686
2024-09-25 06:17:09 -07:00
Tom Natan
eff00cc449 [JAX] add support for gather/scatter batching dims following the new attributes in stablehlo.
This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See https://github.com/openxla/stablehlo/pull/2259

PiperOrigin-RevId: 678649138
2024-09-25 04:53:11 -07:00
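
To make the "concatenating the indices with iota" remark in the gather/scatter commit above concrete, here is a rough pure-Python sketch on nested lists. It is not the JAX batching rule or the StableHLO op; both function names are hypothetical, and the only point is that pairing the batch dimensions directly gives the same result without materializing the extra batch-index column.

```python
def gather_with_iota(operand, indices):
    """Batched gather via explicit (batch, index) pairs.

    Each index is first paired with its batch position (the "iota"),
    and the pairs are then looked up in the 2-D operand.
    """
    full_indices = [[(b, i) for i in row] for b, row in enumerate(indices)]
    return [[operand[b][i] for (b, i) in row] for row in full_indices]


def gather_with_batching_dim(operand, indices):
    """Batched gather where the batch dimensions are matched up directly.

    Row b of `indices` only ever addresses row b of `operand`, so no
    batch-index column has to be built.
    """
    return [[op_row[i] for i in idx_row]
            for op_row, idx_row in zip(operand, indices)]


operand = [[10, 11, 12], [20, 21, 22]]
indices = [[2, 0], [1, 1]]
same = gather_with_iota(operand, indices) == gather_with_batching_dim(operand, indices)
```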
Lu Teng
a31f79ce0b Flush stdout buffer before checking. 2024-09-25 10:30:42 +08:00
Yash Katariya
1fe0c5dad5 Fix printing of saved_residual for jit by looking for pjit as the primitive instead of xla_call which was removed 2 years ago
PiperOrigin-RevId: 678479141
2024-09-24 19:01:19 -07:00
jax authors
cfb4e85fcd Merge pull request #23823 from mattjj:simplify-extended-dtype-convert-logic
PiperOrigin-RevId: 678456216
2024-09-24 17:29:32 -07:00
Peter Hawkins
291e52a713 Fix some warnings causing CI failures on ARM.
PiperOrigin-RevId: 678454816
2024-09-24 17:25:26 -07:00
Matthew Johnson
0a73d74a4e simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).

This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
2024-09-25 00:10:01 +00:00
Peter Hawkins
85a466d730 Lower the shard count for sparse_bcoo_bcsr_test on TPU as well.
There are flaky timeouts in CI, and we've already lowered the shard count on multiple other platforms.

PiperOrigin-RevId: 678367575
2024-09-24 13:10:32 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
tchatow
520980171f Fix jax.numpy.linalg.inv with shape polymorphism 2024-09-24 12:03:06 -07:00
Parker Schuh
5e3f7618fc Support pmin and pmax in check_rep.
PiperOrigin-RevId: 678336530
2024-09-24 11:46:30 -07:00
Adam Paszke
9114b084fc [Pallas] Update export compatibility tests
The old test was generated before our IR was really stable, which has started
to cause problems when trying to test with Trillium.

PiperOrigin-RevId: 678277755
2024-09-24 09:17:56 -07:00
jax authors
d9a10e13a6 Merge pull request #23873 from hawkinsp:cumsum
PiperOrigin-RevId: 678272211
2024-09-24 09:01:56 -07:00
jax authors
26ba306383 Merge pull request #23860 from jakevdp:better-error
PiperOrigin-RevId: 678271248
2024-09-24 08:58:56 -07:00
Peter Hawkins
562e9e8dff Fix an incorrect output for jnp.cumsum.
If dtype=bool but a non-bool input is passed, we should test for
non-equality with zero rather than performing a cast to integer.
2024-09-24 14:46:44 +00:00
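
The difference described in the jnp.cumsum fix above can be sketched in plain Python. This is an illustration only, not the JAX implementation: the helper names are hypothetical, and it assumes that a boolean cumulative sum behaves like a running logical OR (as it does in NumPy).

```python
from itertools import accumulate
import operator


def cumsum_bool_via_int_cast(xs):
    # Problematic route: casting to integer first truncates 0.5 to 0,
    # so a nonzero input is silently treated as False.
    return list(accumulate((bool(int(x)) for x in xs), operator.or_))


def cumsum_bool_via_nonzero(xs):
    # Route described in the commit: test each element for
    # non-equality with zero, then accumulate.
    return list(accumulate((x != 0 for x in xs), operator.or_))


values = [0.0, 0.5, 0.0]
via_cast = cumsum_bool_via_int_cast(values)
via_nonzero = cumsum_bool_via_nonzero(values)
```

With fractional inputs such as 0.5 the two routes disagree, which is the kind of incorrect output the commit message refers to.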
jax authors
6860617ebf Merge pull request #23744 from jakevdp:unary-ufunc
PiperOrigin-RevId: 678235132
2024-09-24 07:11:43 -07:00
Adam Paszke
ae86ef16c7 [Mosaic GPU] Add support for input_output_aliases
PiperOrigin-RevId: 678217775
2024-09-24 06:13:28 -07:00
Jake VanderPlas
6229511f6a Make jnp.negative a ufunc & add unary ufunc tests 2024-09-24 05:23:27 -07:00
Jake VanderPlas
a44e129ae7 Add more informative error when static argument is passed to non-static JIT parameter 2024-09-24 05:22:18 -07:00
Sergei Lebedev
8196c8bf36 Added support for % and select to mgpu.FragmentedArray
PiperOrigin-RevId: 678200940
2024-09-24 05:19:25 -07:00
Jane Liu
adaf54a4bb enable the activation offloading test 2024-09-23 23:07:03 -07:00
Peter Hawkins
a0e4448393 Remove warning filters from pyproject.toml, add local warning
suppressions.

We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
2024-09-24 01:38:24 +00:00
jax authors
1b3d8dc451 Merge pull request #23807 from kaixih:fix_cudnn_sdpa_bwd_batcher
PiperOrigin-RevId: 677937290
2024-09-23 14:20:30 -07:00
jax authors
dc1ace5992 Re-enable tsan tests after fix.
PiperOrigin-RevId: 677895934
2024-09-23 12:26:30 -07:00
Chris Jones
712e638ca4 [pallas] Add support for unblocked mode (without padding) in Triton lowering.
PiperOrigin-RevId: 677870258
2024-09-23 11:21:54 -07:00
Ayaka
93203c7574 [Pallas] Simplify sign and erf_inv tests
Removed the pattern of locally enabling x64 using:

```python
with contextlib.ExitStack() as stack:
  if jnp.dtype(dtype).itemsize == 8:
    stack.enter_context(config.enable_x64(True))
```

This is because we can determine whether a test is running in an x64 environment by checking the value of `jax.config.x64_enabled`. There is no need to enable x64 locally.

PiperOrigin-RevId: 677865574
2024-09-23 11:11:09 -07:00
Christos Perivolaropoulos
3e19a28b09 [pallas:mosaic_gpu] Basic implementation of wgmma.
PiperOrigin-RevId: 677864187
2024-09-23 11:06:17 -07:00
kaixih
d29a757e30 fix bwd batcher for unsupported dbias 2024-09-23 17:43:25 +00:00
jax authors
c05706b7a9 Merge pull request #23816 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 677807429
2024-09-23 08:37:15 -07:00
jax authors
6c52ddc97f [Checkify] Add checks for shard_map.
PiperOrigin-RevId: 677798938
2024-09-23 08:11:22 -07:00
Sergei Lebedev
1256e18fd4 Added comparison operators to mgpu.FragmentedArray
PiperOrigin-RevId: 677788023
2024-09-23 07:37:53 -07:00
rajasekharporeddy
6a72c52292 Improve docs for jax.numpy: conjugate, conj, imag and real 2024-09-23 19:40:09 +05:30
Sergei Lebedev
f311e81c02 Added is_signed to mgpu.FragmentedArray
The registers within a fragmented array always use signless types, and instead
the signedness is tracked on the fragmented array itself (i.e. in Python).

PiperOrigin-RevId: 677776009
2024-09-23 06:59:41 -07:00
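
The idea in the is_signed commit above (signless storage, with signedness kept as Python-side metadata) can be illustrated with a small stand-alone sketch. `SignlessWord` is a hypothetical class invented for this example and has nothing to do with the real mgpu.FragmentedArray API; it assumes 8-bit words only to keep the numbers small.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SignlessWord:
    bits: int        # raw 8-bit pattern in the range 0..255, no sign attached
    is_signed: bool  # interpretation, tracked outside the stored bits

    def value(self) -> int:
        # Interpret the same bit pattern as two's complement or as unsigned.
        if self.is_signed and self.bits >= 0x80:
            return self.bits - 0x100
        return self.bits

    def __lt__(self, other: "SignlessWord") -> bool:
        # The comparison has to consult the flag: the bits alone are ambiguous.
        if self.is_signed != other.is_signed:
            raise TypeError("mixed signedness")
        return self.value() < other.value()


signed_less = SignlessWord(0xFF, True) < SignlessWord(0x01, True)
unsigned_less = SignlessWord(0xFF, False) < SignlessWord(0x01, False)
```

The same two bit patterns compare differently depending on the flag, which is why operations such as comparisons need the signedness to be carried alongside the registers.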
jax authors
ba29d5a022 Merge pull request #23821 from jakevdp:jnp-doc-examples
PiperOrigin-RevId: 677770780
2024-09-23 06:41:28 -07:00
Vadym Matsishevskyi
2199685437 Ignore scipy.stats._axis_nan_policy.SmallSampleWarning for LaxBackedScipyStatsTests.testMode
This fixes our CI. The warning started occurring with scipy 1.14 due to https://github.com/scipy/scipy/pull/20694, which introduced SmallSampleWarning and emits it when the input is an empty array (the `a` variable in the randomized parameterized test LaxBackedScipyStatsTests.testMode is sometimes an empty array).

Note: the warning actually ignored is RuntimeWarning (the superclass of SmallSampleWarning), for backward compatibility: scipy.stats._axis_nan_policy.SmallSampleWarning does not exist in scipy prior to 1.14, and it is declared in a private (_axis_nan_policy) namespace.

PiperOrigin-RevId: 677629866
2024-09-22 22:26:33 -07:00
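
Filtering on the superclass, as in the SmallSampleWarning commit above, works because Python warning filters match subclasses of the given category. The sketch below uses only the standard library; the `SmallSampleWarning` class and both functions are stand-ins defined here for illustration, not scipy's or JAX's code.

```python
import warnings


class SmallSampleWarning(RuntimeWarning):
    """Stand-in for a warning class that only exists in newer library versions."""


def mode_of_empty_sample():
    # Hypothetical function that warns when given too little data.
    warnings.warn("sample too small", SmallSampleWarning)
    return float("nan")


def run_with_warnings_as_errors(ignore_runtime_warnings):
    with warnings.catch_warnings():
        # Comparable in spirit to running with PYTHONWARNINGS=error.
        warnings.simplefilter("error")
        if ignore_runtime_warnings:
            # Refers only to the built-in superclass, so this line also
            # works where SmallSampleWarning cannot be imported.
            warnings.filterwarnings("ignore", category=RuntimeWarning)
        return mode_of_empty_sample()
```

Calling `run_with_warnings_as_errors(True)` returns normally, while `run_with_warnings_as_errors(False)` raises the warning as an exception.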
Ayaka
b6fe793909 [Pallas] Skip atomic_cas and atomic_counter tests on GPU in 64-bit mode
These tests are failing on GPU in 64-bit mode.

This fixes test failures introduced by https://github.com/jax-ml/jax/pull/23798

PiperOrigin-RevId: 677583606
2024-09-22 18:55:39 -07:00
Christos Perivolaropoulos
48c29f62e1 [pallas:mosaic_gpu] Fragmented array debug printing.
PiperOrigin-RevId: 677537364
2024-09-22 14:30:53 -07:00
jax authors
bceceabae0 Merge pull request #23812 from mattjj:custom-primal-tangent-dtype-helper
PiperOrigin-RevId: 677269012
2024-09-21 13:50:55 -07:00
Matthew Johnson
43cc70b7a1 add jax.experimental.primal_tangent_dtype helper
useful for constructing new dtypes which have a distinct tangent type (e.g. for
quantization)
2024-09-21 20:35:20 +00:00
Yash Katariya
a2b39192d2 Make make_array_from_process_local_data go via device_put if there is only 1 process.
PiperOrigin-RevId: 677232996
2024-09-21 10:23:20 -07:00
Jake VanderPlas
aa551e66c5 Test that jax.numpy docstrings include examples 2024-09-21 07:39:17 -07:00
Ayaka
d63afd8438 [Pallas GPU] Enable Pallas OpsExtraTest in 64-bit mode
This is a follow-up of https://github.com/jax-ml/jax/pull/23747, which enables Pallas `OpsTest` in 64-bit mode.

In order to enable Pallas `OpsExtraTest` in 64-bit mode, some of the code in the tests needs to be modified. There are three kinds of modifications:

1. Most of the modifications are just changing `jnp.int32` to `intx` and `jnp.float32` to `floatx`, which uses the same approach as the previous PR https://github.com/jax-ml/jax/pull/23747. `intx` and `floatx` are conventions used in Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode.
2. For the test `test_array_reduce`, the original code uses a simple approach to determine `out_dtype` from `dtype`, which no longer works in 64-bit mode. Therefore, I modified the code to deduce `out_dtype` by executing the operation on a single element first.
3. For the test `test_masked_load_store`, the `idx` variable is expected to be an `int32` array, which is calculated based on `pl.program_id()` and `block_size`. In 64-bit mode, the computation would produce an `int64` array instead. Since `pl.program_id()` always returns an `int32` result, I modified the computation to produce an `int32` result. I also modified the `pl.program_id()` docstring to document that `pl.program_id()` always returns an `int32` result.

PiperOrigin-RevId: 677007613
2024-09-20 16:18:31 -07:00
Jevin Jiang
6b93b35842 [Mosaic:TPU] Efficient relayout with internal scratch
With this CL we should support all retilings (x*packing1, 128) <-> (y*packing2, 128) with any dtype. The efficient relayout with scratch brings significant improvements over the current retiling on TPUv4 and earlier, and over retiling with (packing, 128) on TPUv5. All previously missing retiling support is added in this CL, including increase sublane retiling and packed type retiling.

PiperOrigin-RevId: 676982957
2024-09-20 15:00:58 -07:00
jax authors
ca97af9d43 Change the default implementation of GeLU to a numerically stable formulation.
The old formulation explicitly computed (1 + erf(x/sqrt(2))), which can be extremely inaccurate for negative x due to cancellation.

PiperOrigin-RevId: 676944344
2024-09-20 13:06:31 -07:00
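
The cancellation mentioned in the GeLU commit above is easy to reproduce with the standard library. The sketch below is in double precision and uses hypothetical function names. The `erfc`-based form is one common stable formulation, assumed here for illustration, and may not match the exact expression JAX adopted.

```python
import math


def gelu_naive(x):
    # For very negative x, erf(...) rounds to -1.0, so (1 + erf) cancels
    # and all significant digits are lost.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_stable(x):
    # Uses the identity 1 + erf(z) == erfc(-z), which evaluates the small
    # tail directly instead of subtracting two nearly equal numbers.
    return 0.5 * x * math.erfc(-x / math.sqrt(2.0))


naive_tail = gelu_naive(-10.0)
stable_tail = gelu_stable(-10.0)
```

For moderate inputs the two functions agree closely. They diverge only in the negative tail, where the naive form collapses to zero and the stable form keeps a small negative value.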
jax authors
1b3488001b Merge pull request #23734 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 676941019
2024-09-20 12:55:41 -07:00
rajasekharporeddy
6a5553d6be Improve docs for jax.numpy: remainder, mod and fmod 2024-09-21 00:09:42 +05:30
Parker Schuh
1acf9567aa Add get_replication to shard_map.py for verifying if an array is replicated.
PiperOrigin-RevId: 676910872
2024-09-20 11:25:15 -07:00