Mathew Odden
9ff891dfa1
[ROCm] Remove broken legacy env vars
...
These env vars are no longer used or need and were
being set incorrectly.
[ROCm] Use specific amdgpu version for EL8 systems
We were always installing the latest driver versions
but this had some side effects when yum would try
to download index files from a URL with changing content.
[ROCm] Fix formatting on python files
Reformatted with black
2024-09-30 12:39:51 -05:00
Enrique Piqueras
aa73aa0021
Pallas pipeline API tweaks for more advanced pipelining patterns.
...
PiperOrigin-RevId: 678426679
2024-09-24 15:52:11 -07:00
Peter Hawkins
85a466d730
Lower the shard count for sparse_bcoo_bcsr_test on TPU as well.
...
There are flaky timeouts in CI, and we've already lowered the shard count on multiple other platforms.
PiperOrigin-RevId: 678367575
2024-09-24 13:10:32 -07:00
Peter Hawkins
70f91db853
Set PYTHONWARNINGS=error in bazel tests.
...
The goal of this change is to catch PRs that introduce new warnings sooner.
To help pass the environment variable more easily, rename the jax_test Bazel test macro to jax_multiplatform_test, and introduce a new jax_py_test macro that wraps py_test. Add code to both to set the environment variable.
Add code to suppress some new warnings uncovered in CI.
PiperOrigin-RevId: 678352286
2024-09-24 12:30:11 -07:00
jax authors
d58a09faed
Update XLA dependency to use revision
...
bcc98dcd1c
.
PiperOrigin-RevId: 678338581
2024-09-24 11:53:00 -07:00
Parker Schuh
5e3f7618fc
Support pmin and pmax in check_rep.
...
PiperOrigin-RevId: 678336530
2024-09-24 11:46:30 -07:00
jax authors
6e116491c1
Add --use_cuda_nvcc
flag to enable or disable compilation of CUDA code using NVCC.
...
If `--use_cuda_nvcc` flag is set the NVCC compiler driver will be used to build the CUDA code (default behavior). Otherwise, if the flag `--nouse_cuda_nvcc` is set, only the clang compiler will be used to build the CUDA code (effectively disabling NVCC).
Mark `--use_clang` flag as deprecated.
Refactor `.bazelrc` configs to match the new flag and to cleanup all previous confusing names.
PiperOrigin-RevId: 678332548
2024-09-24 11:37:00 -07:00
Jevin Jiang
407dc774f7
[Mosaic TPU] Support all cases for extui.
...
PiperOrigin-RevId: 678331795
2024-09-24 11:35:03 -07:00
Chris Jones
be7fe878c3
[pallas:triton] Elide program_id
calls where launch grid dimension is 1.
...
This may allow for parts of indexing calculations to be optimized away.
PiperOrigin-RevId: 678321871
2024-09-24 11:11:43 -07:00
Dougal Maclaurin
d2ac88c193
Expose some APIs for querying trace state. This will let us move users away from
...
depending on our internals. Prep work for "stackless".
PiperOrigin-RevId: 678288660
2024-09-24 09:48:41 -07:00
Adam Paszke
9114b084fc
[Pallas] Update export compatibility tests
...
The old test was generated before our IR was really stable, which has started
to cause problems when trying to test with Trillium.
PiperOrigin-RevId: 678277755
2024-09-24 09:17:56 -07:00
jax authors
2c85465ebe
Merge pull request #23806 from gspschmid:gschmid/ffi-ext-bundle
...
PiperOrigin-RevId: 678273475
2024-09-24 09:05:20 -07:00
jax authors
d9a10e13a6
Merge pull request #23873 from hawkinsp:cumsum
...
PiperOrigin-RevId: 678272211
2024-09-24 09:01:56 -07:00
jax authors
26ba306383
Merge pull request #23860 from jakevdp:better-error
...
PiperOrigin-RevId: 678271248
2024-09-24 08:58:56 -07:00
Chris Jones
8d86a04727
[pallas] Allow TransformedRef
to be passed to pl.load
/ pl.store
, when idx = ()
.
...
PiperOrigin-RevId: 678257485
2024-09-24 08:17:21 -07:00
Peter Hawkins
562e9e8dff
Fix an incorrect output for jnp.cumsum.
...
If dtype=bool but a non-bool input is passed, we should test for
non-equality with zero rather than performing a cast to integer.
2024-09-24 14:46:44 +00:00
jax authors
6860617ebf
Merge pull request #23744 from jakevdp:unary-ufunc
...
PiperOrigin-RevId: 678235132
2024-09-24 07:11:43 -07:00
Adam Paszke
ae86ef16c7
[Mosaic GPU] Add support for input_output_aliases
...
PiperOrigin-RevId: 678217775
2024-09-24 06:13:28 -07:00
Jake VanderPlas
6229511f6a
Make jnp.negative a ufunc & add unary ufunc tests
2024-09-24 05:23:27 -07:00
Jake VanderPlas
a44e129ae7
Add more informative error when static argument is passed to non-static JIT parameter
2024-09-24 05:22:18 -07:00
Sergei Lebedev
8196c8bf36
Added support for % and select
to mgpu.FragmentedArray
...
PiperOrigin-RevId: 678200940
2024-09-24 05:19:25 -07:00
jax authors
6d35113686
Merge pull request #23861 from hawkinsp:warnings
...
PiperOrigin-RevId: 678026551
2024-09-23 18:58:27 -07:00
Peter Hawkins
a0e4448393
Remove warning filters from pyproject.toml, add local warning
...
suppressions.
We want to support running Bazel tests with PYTHONWARNINGS=error. In
preparation for that change, move warning suppressions from
pyproject.toml into the individual test cases that generate them, which
is a reasonable cleanup anyway.
2024-09-24 01:38:24 +00:00
jax authors
80cb821a79
Merge pull request #23862 from jakevdp:array-api-ci
...
PiperOrigin-RevId: 678019472
2024-09-23 18:32:08 -07:00
Jake VanderPlas
45af8742ca
trigger array API tests for all PRs.
...
We should have done this when we deprecated jax.experimental.array_api
2024-09-23 18:18:15 -07:00
jax authors
2ac1d0b8d0
Merge pull request #23741 from awshaichen:neuron
...
PiperOrigin-RevId: 678006682
2024-09-23 17:41:11 -07:00
Yash Katariya
a99ea73336
Use jax.make_array_from_process_local_data
API in distributed data loading doc
...
PiperOrigin-RevId: 677973689
2024-09-23 16:03:34 -07:00
Dongseong Hwang
e4091a6752
Fix another errata in block-sparse kernel tutorial.
...
PiperOrigin-RevId: 677952796
2024-09-23 15:04:29 -07:00
jax authors
1b3d8dc451
Merge pull request #23807 from kaixih:fix_cudnn_sdpa_bwd_batcher
...
PiperOrigin-RevId: 677937290
2024-09-23 14:20:30 -07:00
jax authors
46867dc495
Merge pull request #23739 from jakevdp:add-at
...
PiperOrigin-RevId: 677924623
2024-09-23 13:47:27 -07:00
jax authors
b0aeae4918
Merge pull request #23801 from jakevdp:meshgrid-doc
...
PiperOrigin-RevId: 677919342
2024-09-23 13:32:48 -07:00
jax authors
ae2c5958e0
Merge pull request #23852 from ROCm:ci_typename
...
PiperOrigin-RevId: 677916371
2024-09-23 13:24:27 -07:00
Jake VanderPlas
cc885ff875
Better docs for jnp.meshgrid
2024-09-23 12:43:14 -07:00
Ruturaj4
29a1cb766e
[ROCM] add missing typename keyword to work with gcc
2024-09-23 14:42:01 -05:00
jax authors
4fccd64c8b
Update XLA dependency to use revision
...
1162b7e30d
.
PiperOrigin-RevId: 677897482
2024-09-23 12:31:44 -07:00
jax authors
dc1ace5992
Re-enable tsan tests after fix.
...
PiperOrigin-RevId: 677895934
2024-09-23 12:26:30 -07:00
Chris Jones
712e638ca4
[pallas] Add support for unblocked
mode (without padding) in Triton lowering.
...
PiperOrigin-RevId: 677870258
2024-09-23 11:21:54 -07:00
Ayaka
93203c7574
[Pallas] Simplify sign and erf_inv tests
...
Removed the method to locally enabling x64 using:
```python
with contextlib.ExitStack() as stack:
if jnp.dtype(dtype).itemsize == 8:
stack.enter_context(config.enable_x64(True))
```
This is because we can determine whether a test is running in x64 environment by checking the value of `jax.config.x64_enabled`. There is no need to locally enabling x64.
PiperOrigin-RevId: 677865574
2024-09-23 11:11:09 -07:00
Christos Perivolaropoulos
3e19a28b09
[pallas:mosaic_gpu] Basic implementation of wgmma.
...
PiperOrigin-RevId: 677864187
2024-09-23 11:06:17 -07:00
kaixih
d29a757e30
fix bwd batcher for unsupported dbias
2024-09-23 17:43:25 +00:00
jax authors
8362ab7490
Merge pull request #23837 from rajasekharporeddy:testbranch3
...
PiperOrigin-RevId: 677843854
2024-09-23 10:16:27 -07:00
jax authors
63a890f2d8
Merge pull request #23834 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 677843049
2024-09-23 10:14:41 -07:00
Jake VanderPlas
3134ece9b7
ufuncs: improve jnp.add.at & jnp.multiply.at
2024-09-23 09:15:58 -07:00
Dongseong Hwang
91f16419bb
Fix errata in block-sparse kernel tutorial.
...
Correct M//blk_M to N//blk_N. It was ok because both values happen to be same.
In addition, grid order is (num_blocks, j) as 'num_blocks' replaces 'i'.
PiperOrigin-RevId: 677817478
2024-09-23 09:07:28 -07:00
rajasekharporeddy
e976dee4de
Improve docs for jax.numpy: square, sqrt and modf
2024-09-23 21:10:26 +05:30
jax authors
c05706b7a9
Merge pull request #23816 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 677807429
2024-09-23 08:37:15 -07:00
jax authors
6c52ddc97f
[Checkify] Add checks for shard_map.
...
PiperOrigin-RevId: 677798938
2024-09-23 08:11:22 -07:00
rajasekharporeddy
41eccd925d
Improve docs for jnp.logspace and jnp.geomspace
2024-09-23 20:09:12 +05:30
Sergei Lebedev
1256e18fd4
Added comparison operators to mgpu.FragmentedArray
...
PiperOrigin-RevId: 677788023
2024-09-23 07:37:53 -07:00
rajasekharporeddy
6a72c52292
Improve docs for jax.numpy: conjugate, conj, imag and real
2024-09-23 19:40:09 +05:30