23195 Commits

Author SHA1 Message Date
Peter Hawkins
1949413739 Increase sharding of checkify_test on TPU to fix CI flakes.
PiperOrigin-RevId: 678720498
2024-09-25 08:54:29 -07:00
Sergei Lebedev
a373e37be2 Fixed mgpu.FragmentedArray.reduce_sum for integer types
The implementation previously assumed a floating-point element type and used addf.
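
A minimal sketch of the kind of dispatch this fix implies, not the actual Mosaic GPU code: a reduction has to choose the MLIR `arith` add op from the element type. The helper name `add_op_for` and the string spelling of the element types are assumptions made for illustration.

```python
def add_op_for(dtype: str) -> str:
    """Pick an MLIR arith add op name for an element type (hypothetical helper)."""
    if dtype.startswith(("i", "u")):    # integer types, e.g. "i32", "ui8"
        return "arith.addi"
    if dtype.startswith(("f", "bf")):   # floating-point types, e.g. "f32", "bf16"
        return "arith.addf"
    raise NotImplementedError(f"unsupported element type: {dtype}")


print(add_op_for("i32"))   # integer reductions need the integer add
print(add_op_for("f32"))   # floating-point reductions keep using addf
```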

PiperOrigin-RevId: 678718871
2024-09-25 08:50:24 -07:00
Peter Hawkins
a43c7f2ace Enable more H100 tests in CI.
Rename "gpu" config CI tag to "gpu_v100".

PiperOrigin-RevId: 678695003
2024-09-25 07:37:48 -07:00
jax authors
9d277e61ce Merge pull request #23409 from dfm:ffi-examples
PiperOrigin-RevId: 678690801
2024-09-25 07:23:26 -07:00
Christos Perivolaropoulos
390b0ba4a6 [pallas::mosaic_gpu] Support for tiled transpose transforms.
For the time being this feature only supports 2D on the GMEM side and 4D after
tiling on the SMEM side.

PiperOrigin-RevId: 678683983
2024-09-25 07:00:09 -07:00
Sergei Lebedev
cdea3d4050 lax.fori_loop now allows scalars in its carry when lowering to Mosaic GPU
PiperOrigin-RevId: 678677508
2024-09-25 06:35:23 -07:00
Dan Foreman-Mackey
bc1e1a0220 Add support for setting a dot product "algorithm" for lax.dot_general.
The StableHLO spec has a new "algorithm" parameter that specifies the algorithm used to execute a matrix multiplication, and that can tune the trade-off between numerical precision and computational cost. Historically, in JAX, the precision and preferred_element_type parameters have been used to expose some level of control, but their behavior is platform dependent and not sufficiently flexible for performance use cases. This change adds a new "algorithm" parameter to dot_general to add support for the new explicit API.

This parameter can be a member of the `SupportedDotAlgorithm` `Enum` to use an algorithm that is known to be supported on at least some hardware. Otherwise, it can be specified using the `DotAlgorithm` data structure which exposes the full generality of the StableHLO spec.

Transposition is supported using the `transpose_algorithm` argument.

PiperOrigin-RevId: 678672686
2024-09-25 06:17:09 -07:00
Tom Natan
eff00cc449 [JAX] add support for gather/scatter batching dims following the new attributes in stablehlo.
This change also uses the new batching dims for gather/scatter batching rules, to avoid concatenating the indices with iota.

See https://github.com/openxla/stablehlo/pull/2259
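
The difference between the two batching strategies can be sketched with plain nested lists. This is a toy model, not the StableHLO or JAX implementation; the function names are hypothetical, and the operand is assumed to have shape `[batch][n]` with indices of shape `[batch][k]`.

```python
def gather_with_iota(operand, indices):
    """Old style: prepend the batch index (an iota) to every index, then gather."""
    full_indices = [[(b, i) for i in row] for b, row in enumerate(indices)]
    return [[operand[b][i] for (b, i) in row] for row in full_indices]


def gather_with_batching_dims(operand, indices):
    """New style: pair the batch dimensions of operand and indices directly."""
    return [[op_row[i] for i in idx_row]
            for op_row, idx_row in zip(operand, indices)]


operand = [[10, 11, 12], [20, 21, 22]]
indices = [[2, 0], [1, 1]]

# Both strategies select the same elements; the second never materializes
# the concatenated (batch, index) pairs.
assert gather_with_iota(operand, indices) == gather_with_batching_dims(operand, indices)
```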

PiperOrigin-RevId: 678649138
2024-09-25 04:53:11 -07:00
Yash Katariya
1fe0c5dad5 Fix printing of saved_residual for jit by looking for pjit as the primitive instead of xla_call which was removed 2 years ago
PiperOrigin-RevId: 678479141
2024-09-24 19:01:19 -07:00
jax authors
cfb4e85fcd Merge pull request #23823 from mattjj:simplify-extended-dtype-convert-logic
PiperOrigin-RevId: 678456216
2024-09-24 17:29:32 -07:00
Peter Hawkins
291e52a713 Fix some warnings causing CI failures on ARM.
PiperOrigin-RevId: 678454816
2024-09-24 17:25:26 -07:00
Peter Hawkins
7603901608 [mosaic] Fix a warning causing CI failures.
An array ref object was passed where a dtype was expected.

PiperOrigin-RevId: 678451446
2024-09-24 17:11:52 -07:00
Matthew Johnson
0a73d74a4e simplify conversion logic involving extended dtypes
Previously, the idea was that we would use the `convert_element_type` primitive
to cast to/from extended dtypes. Extended dtype rules specified
`convert_from(dtype1, dtype2) -> bool` and `convert_to(dtype1, dtype2) -> bool`
functions. They were meant to do something like indicate whether a
convert_element_type was legal. But I'm not sure if they really made sense.
The implementation was certainly buggy for non-scalar representation types
(physical element types).

This PR simplifies and fixes things:
1. Instead of overloading the `convert_element_type_p` primitive with more cases
involving casts to/from extended dtypes, let's just have distinct `to_edtype_p`
and `from_edtype_p` primitives, which can be much simpler. We still reuse the
`jax.lax.convert_element_type` API function, so there's no API change to the
few existing users who know about this stuff.
2. Instead of extended dtype rules including `convert_from`/`convert_to`
functions with questionable semantics, let's only allow casts to/from the
representation type, which is already specified by the rules'
`physical_element_aval`. (Indeed that should be roughly _all_ we need, and this
PR is just one step towards realizing that goal.) We still have a boolean
`allow_conversion` on extended dtype rules just so we can handle the PRNGKey
case, where we don't want to allow any casts.
3. Fix the conversion logic to handle non-scalar representation types (physical
element types).
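
The rule in point 2 can be illustrated with a toy model. This is only a sketch under stated assumptions, not JAX's implementation: `ToyExtendedDType`, the string dtype names, and the Python functions `to_edtype` / `from_edtype` are hypothetical stand-ins, and the representation type is reduced to a single dtype name.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyExtendedDType:
    """Hypothetical stand-in for an extended dtype and its rules."""
    name: str
    physical_dtype: str            # representation (physical element) type
    allow_conversion: bool = True  # False for key-like types that forbid casts


def to_edtype(source_dtype: str, edtype: ToyExtendedDType) -> str:
    """Cast from the representation type to the extended dtype."""
    _check(edtype, source_dtype)
    return edtype.name


def from_edtype(edtype: ToyExtendedDType, target_dtype: str) -> str:
    """Cast from the extended dtype back to its representation type."""
    _check(edtype, target_dtype)
    return target_dtype


def _check(edtype: ToyExtendedDType, other_dtype: str) -> None:
    # Casts are only legal to/from the representation type, and only
    # when the rules allow conversion at all.
    if not edtype.allow_conversion:
        raise ValueError(f"casts to/from {edtype.name} are not allowed")
    if other_dtype != edtype.physical_dtype:
        raise ValueError(
            f"{edtype.name} only converts to/from {edtype.physical_dtype}, "
            f"got {other_dtype}")


fixed_point = ToyExtendedDType("toy_fixed_point", physical_dtype="int32")
key = ToyExtendedDType("toy_key", physical_dtype="uint32", allow_conversion=False)
```

In this model, `to_edtype("int32", fixed_point)` succeeds, while any cast involving `key` raises, mirroring the PRNGKey case described above.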
2024-09-25 00:10:01 +00:00
Ayaka
02cfaa858f [Pallas TPU] Improve error message when trying to store a scalar to VMEM
Fixes https://github.com/jax-ml/jax/issues/23884

PiperOrigin-RevId: 678448445
2024-09-24 17:01:41 -07:00
Enrique Piqueras
aa73aa0021 Pallas pipeline API tweaks for more advanced pipelining patterns.
PiperOrigin-RevId: 678426679
2024-09-24 15:52:11 -07:00
Dan Foreman-Mackey
e1a68eee5e Add FFI example project and test on CI.
This PR includes an end-to-end example project which demonstrates the
use of the FFI. This complements [the FFI
tutorial](https://jax.readthedocs.io/en/latest/ffi.html) by putting all
of the code in one place, as well as demonstrating how FFI extensions
can be packaged. Alongside the example project, I have also added a new
GitHub Actions workflow to test the example as part of CI.

For now, the tests only run on CPU, but once we have GPU runners for
GitHub Actions (soon!), I plan on migrating the custom call examples
from `docs/gpu_ops` and `docs/cuda_custom_call` into this test case.

Similarly, I wanted to start small and this example project only
includes exactly the same functions as the tutorial for now, but I think
this could be a good place to showcase more advanced examples (including
custom calls with state).
2024-09-24 17:23:13 -04:00
Peter Hawkins
85a466d730 Lower the shard count for sparse_bcoo_bcsr_test on TPU as well.
There are flaky timeouts in CI, and we've already lowered the shard count on multiple other platforms.

PiperOrigin-RevId: 678367575
2024-09-24 13:10:32 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.
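
For readers unfamiliar with the setting, the effect of `PYTHONWARNINGS=error` can be reproduced in-process with the standard `warnings` module. The sketch below is illustrative only; `legacy_api` and `call_strictly` are hypothetical names, not part of JAX or its test macros.

```python
import warnings


def legacy_api():
    """A function that emits a warning, as newly introduced code might."""
    warnings.warn("legacy_api is deprecated", DeprecationWarning, stacklevel=2)
    return 42


def call_strictly():
    """Call legacy_api with every warning promoted to an exception."""
    with warnings.catch_warnings():
        # In-process analogue of running with PYTHONWARNINGS=error.
        warnings.simplefilter("error")
        try:
            return legacy_api()
        except DeprecationWarning as e:
            return f"failed: {e}"


print(call_strictly())
```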

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
jax authors
d58a09faed Update XLA dependency to use revision
bcc98dcd1c.

PiperOrigin-RevId: 678338581
2024-09-24 11:53:00 -07:00
Parker Schuh
5e3f7618fc Support pmin and pmax in check_rep.
PiperOrigin-RevId: 678336530
2024-09-24 11:46:30 -07:00
jax authors
6e116491c1 Add --use_cuda_nvcc flag to enable or disable compilation of CUDA code using NVCC.
If the `--use_cuda_nvcc` flag is set, the NVCC compiler driver is used to build the CUDA code (the default behavior). If `--nouse_cuda_nvcc` is set instead, only the clang compiler is used to build the CUDA code, which effectively disables NVCC.

Mark the `--use_clang` flag as deprecated.

Refactor the `.bazelrc` configs to match the new flag and to clean up the previously confusing names.

PiperOrigin-RevId: 678332548
2024-09-24 11:37:00 -07:00
Jevin Jiang
407dc774f7 [Mosaic TPU] Support all cases for extui.
PiperOrigin-RevId: 678331795
2024-09-24 11:35:03 -07:00
Chris Jones
be7fe878c3 [pallas:triton] Elide program_id calls where launch grid dimension is 1.
This may allow for parts of indexing calculations to be optimized away.
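
The idea can be sketched symbolically: when the launch grid has a single program along an axis, that axis's program id is always 0, so index expressions built from it fold away. This is a toy model with hypothetical helper names, not the Pallas Triton lowering itself.

```python
def program_id_expr(axis: int, grid: tuple[int, ...]) -> str:
    """Symbolic program id, folded to the constant 0 when the grid size is 1."""
    return "0" if grid[axis] == 1 else f"program_id({axis})"


def block_start_expr(axis: int, grid: tuple[int, ...], block_size: int) -> str:
    """Symbolic start offset of the block handled along `axis`."""
    pid = program_id_expr(axis, grid)
    # 0 * block_size is just 0, so the multiply disappears too.
    return "0" if pid == "0" else f"{pid} * {block_size}"


grid = (1, 8)
print(block_start_expr(0, grid, 128))  # folded to a constant
print(block_start_expr(1, grid, 128))  # still depends on the program id
```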

PiperOrigin-RevId: 678321871
2024-09-24 11:11:43 -07:00
Dougal Maclaurin
d2ac88c193 Expose some APIs for querying trace state. This will let us move users away from
depending on our internals. Prep work for "stackless".

PiperOrigin-RevId: 678288660
2024-09-24 09:48:41 -07:00
Adam Paszke
9114b084fc [Pallas] Update export compatibility tests
The old test was generated before our IR was really stable, which has started
to cause problems when trying to test with Trillium.

PiperOrigin-RevId: 678277755
2024-09-24 09:17:56 -07:00
jax authors
2c85465ebe Merge pull request #23806 from gspschmid:gschmid/ffi-ext-bundle
PiperOrigin-RevId: 678273475
2024-09-24 09:05:20 -07:00
jax authors
d9a10e13a6 Merge pull request #23873 from hawkinsp:cumsum
PiperOrigin-RevId: 678272211
2024-09-24 09:01:56 -07:00
jax authors
26ba306383 Merge pull request #23860 from jakevdp:better-error
PiperOrigin-RevId: 678271248
2024-09-24 08:58:56 -07:00
Chris Jones
8d86a04727 [pallas] Allow TransformedRef to be passed to pl.load / pl.store, when idx = ().
PiperOrigin-RevId: 678257485
2024-09-24 08:17:21 -07:00
Peter Hawkins
562e9e8dff Fix an incorrect output for jnp.cumsum.
If dtype=bool but a non-bool input is passed, we should test for
non-equality with zero rather than performing a cast to integer.
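
A pure-Python sketch of why the two approaches can disagree, assuming that a cumulative sum in a boolean dtype behaves as a cumulative logical OR. The function names are hypothetical and the cast-first variant only models the kind of behavior the message describes, not the exact previous JAX code.

```python
from itertools import accumulate
import operator


def cumsum_bool_via_int_cast(xs):
    """Cast-first variant: truncate each element to an integer, then accumulate."""
    return [bool(total) for total in accumulate(int(x) for x in xs)]


def cumsum_bool_via_nonzero(xs):
    """Compare each element with zero first, then accumulate with logical OR."""
    return list(accumulate((x != 0 for x in xs), operator.or_))


xs = [0.0, 0.5, 0.0]
print(cumsum_bool_via_int_cast(xs))  # 0.5 truncates to 0, so it is lost
print(cumsum_bool_via_nonzero(xs))   # 0.5 is nonzero, so it counts as True
```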
2024-09-24 14:46:44 +00:00
jax authors
6860617ebf Merge pull request #23744 from jakevdp:unary-ufunc
PiperOrigin-RevId: 678235132
2024-09-24 07:11:43 -07:00
Adam Paszke
ae86ef16c7 [Mosaic GPU] Add support for input_output_aliases
PiperOrigin-RevId: 678217775
2024-09-24 06:13:28 -07:00
Jake VanderPlas
6229511f6a Make jnp.negative a ufunc & add unary ufunc tests
2024-09-24 05:23:27 -07:00
Jake VanderPlas
a44e129ae7 Add more informative error when static argument is passed to non-static JIT parameter
2024-09-24 05:22:18 -07:00
Sergei Lebedev
8196c8bf36 Added support for % and select to mgpu.FragmentedArray
PiperOrigin-RevId: 678200940
2024-09-24 05:19:25 -07:00
jax authors
6d35113686 Merge pull request #23861 from hawkinsp:warnings
PiperOrigin-RevId: 678026551
2024-09-23 18:58:27 -07:00
Peter Hawkins
a0e4448393 Remove warning filters from pyproject.toml, add local warning
suppressions.

We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
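
A local suppression of this kind might look like the following sketch, which uses only the standard `unittest` and `warnings` modules. The function `noisy` and the test class are hypothetical and are not taken from the JAX test suite.

```python
import unittest
import warnings


def noisy():
    """Stand-in for library code that emits a known, accepted warning."""
    warnings.warn("noisy is deprecated", DeprecationWarning)
    return 1


class NoisyTest(unittest.TestCase):

    def test_noisy(self):
        with warnings.catch_warnings():
            # Simulate a strict environment where warnings are errors...
            warnings.simplefilter("error")
            # ...and suppress only the one warning this test expects.
            warnings.filterwarnings(
                "ignore", message="noisy is deprecated",
                category=DeprecationWarning)
            self.assertEqual(noisy(), 1)
```

Because the filter lives inside the test, it does not hide the same warning anywhere else in the suite.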
2024-09-24 01:38:24 +00:00
jax authors
80cb821a79 Merge pull request #23862 from jakevdp:array-api-ci
PiperOrigin-RevId: 678019472
2024-09-23 18:32:08 -07:00
Jake VanderPlas
45af8742ca trigger array API tests for all PRs.
We should have done this when we deprecated jax.experimental.array_api
2024-09-23 18:18:15 -07:00
jax authors
2ac1d0b8d0 Merge pull request #23741 from awshaichen:neuron
PiperOrigin-RevId: 678006682
2024-09-23 17:41:11 -07:00
Yash Katariya
a99ea73336 Use jax.make_array_from_process_local_data API in distributed data loading doc
PiperOrigin-RevId: 677973689
2024-09-23 16:03:34 -07:00
Dongseong Hwang
e4091a6752 Fix another erratum in block-sparse kernel tutorial.
PiperOrigin-RevId: 677952796
2024-09-23 15:04:29 -07:00
jax authors
1b3d8dc451 Merge pull request #23807 from kaixih:fix_cudnn_sdpa_bwd_batcher
PiperOrigin-RevId: 677937290
2024-09-23 14:20:30 -07:00
jax authors
46867dc495 Merge pull request #23739 from jakevdp:add-at
PiperOrigin-RevId: 677924623
2024-09-23 13:47:27 -07:00
jax authors
b0aeae4918 Merge pull request #23801 from jakevdp:meshgrid-doc
PiperOrigin-RevId: 677919342
2024-09-23 13:32:48 -07:00
jax authors
ae2c5958e0 Merge pull request #23852 from ROCm:ci_typename
PiperOrigin-RevId: 677916371
2024-09-23 13:24:27 -07:00
Jake VanderPlas
cc885ff875 Better docs for jnp.meshgrid
2024-09-23 12:43:14 -07:00
Ruturaj4
29a1cb766e [ROCM] add missing typename keyword to work with gcc
2024-09-23 14:42:01 -05:00
jax authors
4fccd64c8b Update XLA dependency to use revision
1162b7e30d.

PiperOrigin-RevId: 677897482
2024-09-23 12:31:44 -07:00
jax authors
dc1ace5992 Re-enable tsan tests after fix.
PiperOrigin-RevId: 677895934
2024-09-23 12:26:30 -07:00