23181 Commits

Author SHA1 Message Date
Mathew Odden
9ff891dfa1 [ROCm] Remove broken legacy env vars
These env vars are no longer used or need and were
being set incorrectly.

[ROCm] Use specific amdgpu version for EL8 systems

We were always installing the latest driver versions
but this had some side effects when yum would try
to download index files from a URL with changing content.

[ROCm] Fix formatting on python files

Reformatted with black
2024-09-30 12:39:51 -05:00
Enrique Piqueras
aa73aa0021 Pallas pipeline API tweaks for more advanced pipelining patterns.
PiperOrigin-RevId: 678426679
2024-09-24 15:52:11 -07:00
Peter Hawkins
85a466d730 Lower the shard count for sparse_bcoo_bcsr_test on TPU as well.
There are flaky timeouts in CI, and we've already lowered the shard count on multiple other platforms.

PiperOrigin-RevId: 678367575
2024-09-24 13:10:32 -07:00
Peter Hawkins
70f91db853 Set PYTHONWARNINGS=error in bazel tests.
The goal of this change is to catch PRs that introduce new warnings sooner.

To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.

Add code to suppress some new warnings uncovered in CI.

PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
jax authors
d58a09faed Update XLA dependency to use revision
bcc98dcd1c.

PiperOrigin-RevId: 678338581
2024-09-24 11:53:00 -07:00
Parker Schuh
5e3f7618fc Support pmin and pmax in check_rep.
PiperOrigin-RevId: 678336530
2024-09-24 11:46:30 -07:00
jax authors
6e116491c1 Add --use_cuda_nvcc flag to enable or disable compilation of CUDA code using NVCC.
If `--use_cuda_nvcc` flag is set the NVCC compiler driver will be used to build the CUDA code (default behavior). Otherwise, if the flag `--nouse_cuda_nvcc` is set, only the clang compiler will be used to build the CUDA code (effectively disabling NVCC).

Mark `--use_clang` flag as deprecated.

Refactor `.bazelrc` configs to match the new flag and to cleanup all previous confusing names.

PiperOrigin-RevId: 678332548
2024-09-24 11:37:00 -07:00
Jevin Jiang
407dc774f7 [Mosaic TPU] Support all cases for extui.
PiperOrigin-RevId: 678331795
2024-09-24 11:35:03 -07:00
Chris Jones
be7fe878c3 [pallas:triton] Elide program_id calls where launch grid dimension is 1.
This may allow for parts of indexing calculations to be optimized away.

PiperOrigin-RevId: 678321871
2024-09-24 11:11:43 -07:00
Dougal Maclaurin
d2ac88c193 Expose some APIs for querying trace state. This will let us move users away from
depending on our internals. Prep work for "stackless".

PiperOrigin-RevId: 678288660
2024-09-24 09:48:41 -07:00
Adam Paszke
9114b084fc [Pallas] Update export compatibility tests
The old test was generated before our IR was really stable, which has started
to cause problems when trying to test with Trillium.

PiperOrigin-RevId: 678277755
2024-09-24 09:17:56 -07:00
jax authors
2c85465ebe Merge pull request #23806 from gspschmid:gschmid/ffi-ext-bundle
PiperOrigin-RevId: 678273475
2024-09-24 09:05:20 -07:00
jax authors
d9a10e13a6 Merge pull request #23873 from hawkinsp:cumsum
PiperOrigin-RevId: 678272211
2024-09-24 09:01:56 -07:00
jax authors
26ba306383 Merge pull request #23860 from jakevdp:better-error
PiperOrigin-RevId: 678271248
2024-09-24 08:58:56 -07:00
Chris Jones
8d86a04727 [pallas] Allow TransformedRef to be passed to pl.load / pl.store, when idx = ().
PiperOrigin-RevId: 678257485
2024-09-24 08:17:21 -07:00
Peter Hawkins
562e9e8dff Fix an incorrect output for jnp.cumsum.
If dtype=bool but a non-bool input is passed, we should test for
non-equality with zero rather than performing a cast to integer.
2024-09-24 14:46:44 +00:00
jax authors
6860617ebf Merge pull request #23744 from jakevdp:unary-ufunc
PiperOrigin-RevId: 678235132
2024-09-24 07:11:43 -07:00
Adam Paszke
ae86ef16c7 [Mosaic GPU] Add support for input_output_aliases
PiperOrigin-RevId: 678217775
2024-09-24 06:13:28 -07:00
Jake VanderPlas
6229511f6a Make jnp.negative a ufunc & add unary ufunc tests 2024-09-24 05:23:27 -07:00
Jake VanderPlas
a44e129ae7 Add more informative error when static argument is passed to non-static JIT parameter 2024-09-24 05:22:18 -07:00
Sergei Lebedev
8196c8bf36 Added support for % and select to mgpu.FragmentedArray
PiperOrigin-RevId: 678200940
2024-09-24 05:19:25 -07:00
jax authors
6d35113686 Merge pull request #23861 from hawkinsp:warnings
PiperOrigin-RevId: 678026551
2024-09-23 18:58:27 -07:00
Peter Hawkins
a0e4448393 Remove warning filters from pyproject.toml, add local warning
suppressions.

We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
2024-09-24 01:38:24 +00:00
jax authors
80cb821a79 Merge pull request #23862 from jakevdp:array-api-ci
PiperOrigin-RevId: 678019472
2024-09-23 18:32:08 -07:00
Jake VanderPlas
45af8742ca trigger array API tests for all PRs.
We should have done this when we deprecated jax.experimental.array_api
2024-09-23 18:18:15 -07:00
jax authors
2ac1d0b8d0 Merge pull request #23741 from awshaichen:neuron
PiperOrigin-RevId: 678006682
2024-09-23 17:41:11 -07:00
Yash Katariya
a99ea73336 Use jax.make_array_from_process_local_data API in distributed data loading doc
PiperOrigin-RevId: 677973689
2024-09-23 16:03:34 -07:00
Dongseong Hwang
e4091a6752 Fix another errata in block-sparse kernel tutorial.
PiperOrigin-RevId: 677952796
2024-09-23 15:04:29 -07:00
jax authors
1b3d8dc451 Merge pull request #23807 from kaixih:fix_cudnn_sdpa_bwd_batcher
PiperOrigin-RevId: 677937290
2024-09-23 14:20:30 -07:00
jax authors
46867dc495 Merge pull request #23739 from jakevdp:add-at
PiperOrigin-RevId: 677924623
2024-09-23 13:47:27 -07:00
jax authors
b0aeae4918 Merge pull request #23801 from jakevdp:meshgrid-doc
PiperOrigin-RevId: 677919342
2024-09-23 13:32:48 -07:00
jax authors
ae2c5958e0 Merge pull request #23852 from ROCm:ci_typename
PiperOrigin-RevId: 677916371
2024-09-23 13:24:27 -07:00
Jake VanderPlas
cc885ff875 Better docs for jnp.meshgrid 2024-09-23 12:43:14 -07:00
Ruturaj4
29a1cb766e [ROCM] add missing typename keyword to work with gcc 2024-09-23 14:42:01 -05:00
jax authors
4fccd64c8b Update XLA dependency to use revision
1162b7e30d.

PiperOrigin-RevId: 677897482
2024-09-23 12:31:44 -07:00
jax authors
dc1ace5992 Re-enable tsan tests after fix.
PiperOrigin-RevId: 677895934
2024-09-23 12:26:30 -07:00
Chris Jones
712e638ca4 [pallas] Add support for unblocked mode (without padding) in Triton lowering.
PiperOrigin-RevId: 677870258
2024-09-23 11:21:54 -07:00
Ayaka
93203c7574 [Pallas] Simplify sign and erf_inv tests
Removed the method to locally enabling x64 using:

```python
with contextlib.ExitStack() as stack:
  if jnp.dtype(dtype).itemsize == 8:
    stack.enter_context(config.enable_x64(True))
```

This is because we can determine whether a test is running in x64 environment by checking the value of `jax.config.x64_enabled`. There is no need to locally enabling x64.

PiperOrigin-RevId: 677865574
2024-09-23 11:11:09 -07:00
Christos Perivolaropoulos
3e19a28b09 [pallas:mosaic_gpu] Basic implementation of wgmma.
PiperOrigin-RevId: 677864187
2024-09-23 11:06:17 -07:00
kaixih
d29a757e30 fix bwd batcher for unsupported dbias 2024-09-23 17:43:25 +00:00
jax authors
8362ab7490 Merge pull request #23837 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 677843854
2024-09-23 10:16:27 -07:00
jax authors
63a890f2d8 Merge pull request #23834 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 677843049
2024-09-23 10:14:41 -07:00
Jake VanderPlas
3134ece9b7 ufuncs: improve jnp.add.at & jnp.multiply.at 2024-09-23 09:15:58 -07:00
Dongseong Hwang
91f16419bb Fix errata in block-sparse kernel tutorial.
Correct M//blk_M to N//blk_N. It was ok because both values happen to be same.
In addition, grid order is (num_blocks, j) as 'num_blocks' replaces 'i'.

PiperOrigin-RevId: 677817478
2024-09-23 09:07:28 -07:00
rajasekharporeddy
e976dee4de Improve docs for jax.numpy: square, sqrt and modf 2024-09-23 21:10:26 +05:30
jax authors
c05706b7a9 Merge pull request #23816 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 677807429
2024-09-23 08:37:15 -07:00
jax authors
6c52ddc97f [Checkify] Add checks for shard_map.
PiperOrigin-RevId: 677798938
2024-09-23 08:11:22 -07:00
rajasekharporeddy
41eccd925d Improve docs for jnp.logspace and jnp.geomspace 2024-09-23 20:09:12 +05:30
Sergei Lebedev
1256e18fd4 Added comparison operators to mgpu.FragmentedArray
PiperOrigin-RevId: 677788023
2024-09-23 07:37:53 -07:00
rajasekharporeddy
6a72c52292 Improve docs for jax.numpy: conjugate, conj, imag and real 2024-09-23 19:40:09 +05:30