23755 Commits

Author SHA1 Message Date
Ruturaj4
29a1cb766e [ROCM] add missing typename keyword to work with gcc 2024-09-23 14:42:01 -05:00
jax authors
4fccd64c8b Update XLA dependency to use revision
1162b7e30d.

PiperOrigin-RevId: 677897482
2024-09-23 12:31:44 -07:00
jax authors
dc1ace5992 Re-enable tsan tests after fix.
PiperOrigin-RevId: 677895934
2024-09-23 12:26:30 -07:00
Chris Jones
712e638ca4 [pallas] Add support for unblocked mode (without padding) in Triton lowering.
PiperOrigin-RevId: 677870258
2024-09-23 11:21:54 -07:00
Ayaka
93203c7574 [Pallas] Simplify sign and erf_inv tests
Removed the method to locally enabling x64 using:

```python
with contextlib.ExitStack() as stack:
  if jnp.dtype(dtype).itemsize == 8:
    stack.enter_context(config.enable_x64(True))
```

This is because we can determine whether a test is running in x64 environment by checking the value of `jax.config.x64_enabled`. There is no need to locally enabling x64.

PiperOrigin-RevId: 677865574
2024-09-23 11:11:09 -07:00
Christos Perivolaropoulos
3e19a28b09 [pallas:mosaic_gpu] Basic implementation of wgmma.
PiperOrigin-RevId: 677864187
2024-09-23 11:06:17 -07:00
kaixih
d29a757e30 fix bwd batcher for unsupported dbias 2024-09-23 17:43:25 +00:00
jax authors
8362ab7490 Merge pull request #23837 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 677843854
2024-09-23 10:16:27 -07:00
jax authors
63a890f2d8 Merge pull request #23834 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 677843049
2024-09-23 10:14:41 -07:00
Jake VanderPlas
3134ece9b7 ufuncs: improve jnp.add.at & jnp.multiply.at 2024-09-23 09:15:58 -07:00
Dongseong Hwang
91f16419bb Fix errata in block-sparse kernel tutorial.
Correct M//blk_M to N//blk_N. It was ok because both values happen to be same.
In addition, grid order is (num_blocks, j) as 'num_blocks' replaces 'i'.

PiperOrigin-RevId: 677817478
2024-09-23 09:07:28 -07:00
rajasekharporeddy
e976dee4de Improve docs for jax.numpy: square, sqrt and modf 2024-09-23 21:10:26 +05:30
jax authors
c05706b7a9 Merge pull request #23816 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 677807429
2024-09-23 08:37:15 -07:00
jax authors
6c52ddc97f [Checkify] Add checks for shard_map.
PiperOrigin-RevId: 677798938
2024-09-23 08:11:22 -07:00
rajasekharporeddy
41eccd925d Improve docs for jnp.logspace and jnp.geomspace 2024-09-23 20:09:12 +05:30
Sergei Lebedev
1256e18fd4 Added comparison operators to mgpu.FragmentedArray
PiperOrigin-RevId: 677788023
2024-09-23 07:37:53 -07:00
rajasekharporeddy
6a72c52292 Improve docs for jax.numpy: conjugate, conj, imag and real 2024-09-23 19:40:09 +05:30
Sergei Lebedev
f311e81c02 Added is_signed to mgpu.FragmentedArray
The registers within a fragmented array always use signless types, and instead
the signedness is tracked on the fragmented arrays itself (i.e. in Python).

PiperOrigin-RevId: 677776009
2024-09-23 06:59:41 -07:00
jax authors
ba29d5a022 Merge pull request #23821 from jakevdp:jnp-doc-examples
PiperOrigin-RevId: 677770780
2024-09-23 06:41:28 -07:00
Sergei Lebedev
653f07a7e1 Updated Pallas Mosaic GPU lowering post Mosaic GPU restructuring
PiperOrigin-RevId: 677758519
2024-09-23 05:58:46 -07:00
Vadym Matsishevskyi
2199685437 Ignore scipy.stats._axis_nan_policy.SmallSampleWarning for LaxBackedScipyStatsTests.testMode
It is to fix our CI, the warning itself started occurring on scipy 1.14 due to this change https://github.com/scipy/scipy/pull/20694, which introduced SmallSampleWarning and started emitting it if the input is an empty array (the `a` variable in the randomized parametrized test LaxBackedScipyStatsTests.testMode sometimes happens to be an empty array).

Note, the actual ignored warning is RungimeWarning (the superclass of SmallSampleWarning) to make it backward compatible (scipy.stats._axis_nan_policy.SmallSampleWarning does not exist in scipy prior 1.14, not to mention it being under private declared in a private (_axis_nan_policy) namespace.

PiperOrigin-RevId: 677629866
2024-09-22 22:26:33 -07:00
Ayaka
b6fe793909 [Pallas] Skip atomic_cas and atomic_counter tests on GPU in 64-bit mode
These tests are failing on GPU in 64-bit mode.

This fixes test failures introduced by https://github.com/jax-ml/jax/pull/23798

PiperOrigin-RevId: 677583606
2024-09-22 18:55:39 -07:00
Christos Perivolaropoulos
48c29f62e1 [pallas:mosaic_gpu] Fragmented array debug printing.
PiperOrigin-RevId: 677537364
2024-09-22 14:30:53 -07:00
jax authors
02994d6bbb Update XLA dependency to use revision
2101ae888f.

PiperOrigin-RevId: 677526024
2024-09-22 13:26:59 -07:00
Frederic Bastien
a159c0f417 Document jax.checkpoint policies. 2024-09-22 16:05:20 -04:00
jax authors
3c3bbb8ab6 Merge pull request #23165 from 8bitmp3:jax-docs-advanced-tutorials
PiperOrigin-RevId: 677507157
2024-09-22 11:37:52 -07:00
jax authors
ba74490e6f Update XLA dependency to use revision
e0eff72204.

PiperOrigin-RevId: 677269125
2024-09-21 13:51:50 -07:00
jax authors
bceceabae0 Merge pull request #23812 from mattjj:custom-primal-tangent-dtype-helper
PiperOrigin-RevId: 677269012
2024-09-21 13:50:55 -07:00
Matthew Johnson
43cc70b7a1 add jax.experimental.primal_tangent_dtype helper
useful for constructing new dtypes which have a distinct tangent type (e.g. for
quantization)
2024-09-21 20:35:20 +00:00
Yash Katariya
a2b39192d2 Make make_array_from_process_local_data go via device_put if there is only 1 process.
PiperOrigin-RevId: 677232996
2024-09-21 10:23:20 -07:00
Jake VanderPlas
aa551e66c5 Test that jax.numpy docstrings include examples 2024-09-21 07:39:17 -07:00
Ayaka
d63afd8438 [Pallas GPU] Enable Pallas OpsExtraTest in 64-bit mode
This is a follow-up of https://github.com/jax-ml/jax/pull/23747, which enables Pallas `OpsTest` in 64-bit mode.

In order to enable Pallas `OpsExtraTest` in 64-bit mode, some of the code in the tests need to be modified. There are three kinds of modifications:

1. Most of the modifications are just changing `jnp.int32` to `intx` and `jnp.float32` to `floatx`, which uses the same approach as the previous PR https://github.com/jax-ml/jax/pull/23747. `intx` and `floatx` are conventions used in Pallas tests to refer to 64-bit types in 64-bit mode and their 32-bit counterparts in 32-bit mode.
2. For the test `test_array_reduce`, the original code uses a simple approach to determine `out_dtype` from `dtype`, which no longer works in 64-bit mode. Therefore, I modified the code to deduct `out_dtype` by executing the operation on a single element first.
3. For the test `test_masked_load_store`, the `idx` variable is expected to be an `int32` array, which is calculated based on `pl.program_id()` and `block_size`. In 64-bit mode, the computation will give out an `int64` array instead. Since `pl.program_id()` always returns an `int32` result, I modified the computation to produce `int32` result. I also modified the `pl.program_id()` docstring to document the behaviour that `pl.program_id()` always returns an `int32` result.

PiperOrigin-RevId: 677007613
2024-09-20 16:18:31 -07:00
8bitmp3
0cf040c9a1 Add/update JAX Advanced Tutorials docs, ToC structure 2024-09-20 23:06:54 +00:00
Jevin Jiang
6b93b35842 [Mosaic:TPU] Efficient relayout with internal scratch
We should support all different retilings (x*packing1, 128) <-> (y*packing2, 128) with any dtype in this cl at this moment. The efficient relayout with scratch brings significant improvements on current retiling in <= TPUv4 and retiling with (packing, 128) in TPUv5. All missing retiling supports are added in this cl, including increase sublane retiling and packed type retiling.

PiperOrigin-RevId: 676982957
2024-09-20 15:00:58 -07:00
jax authors
a533635898 Update XLA dependency to use revision
44d14566fc.

PiperOrigin-RevId: 676967851
2024-09-20 14:17:51 -07:00
jax authors
9465d427c0 Merge pull request #22302 from yhtang:add-k8s-initialize
PiperOrigin-RevId: 676962862
2024-09-20 14:03:50 -07:00
jax authors
ca97af9d43 Change the default implementation of GeLU to a numerically stable formulation.
The old formulation explicitly computed (1 + erf(x/sqrt(2))), which can be extremely inaccurate for negative x due to cancellation.

PiperOrigin-RevId: 676944344
2024-09-20 13:06:31 -07:00
jax authors
1b3488001b Merge pull request #23734 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 676941019
2024-09-20 12:55:41 -07:00
Keshav
8770fb283b set default value to True 2024-09-20 11:48:41 -07:00
rajasekharporeddy
6a5553d6be Improve docs for jax.numpy: remainder, mod and fmod 2024-09-21 00:09:42 +05:30
Parker Schuh
1acf9567aa Add get_replication to shard_map.py for verifying if an array is replicated.
PiperOrigin-RevId: 676910872
2024-09-20 11:25:15 -07:00
kaixih
b7e26ba3ee fix dbias in bwd_batcher 2024-09-20 18:07:55 +00:00
jax authors
82b0e0e0fb Merge pull request #23788 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 676891040
2024-09-20 10:30:10 -07:00
Yu-Hang Tang
c88c3aecae add k8s cluster environment 2024-09-20 17:26:53 +00:00
jax authors
e2cdb796f9 Merge pull request #23802 from hawkinsp:dumps
PiperOrigin-RevId: 676889415
2024-09-20 10:25:25 -07:00
jax authors
419a0c498a Merge pull request #23790 from rajasekharporeddy:testbranch3
PiperOrigin-RevId: 676889232
2024-09-20 10:24:10 -07:00
jax authors
629be0b701 Tighten test tolerances after the underlying issue causing nondeterministic results for _nrm2 in Eigen BLAS was fixed in https://gitlab.com/libeigen/eigen/-/merge_requests/1667 -> cl/663346025
PiperOrigin-RevId: 676881791
2024-09-20 10:03:46 -07:00
rajasekharporeddy
0c87a23a26 Improve docs for jax.numpy: deg2rad, rad2deg, degrees, radians 2024-09-20 22:22:17 +05:30
rajasekharporeddy
81e50118cf Better doc for jax.numpy.i0 2024-09-20 22:19:31 +05:30
jax authors
886aa944fa Merge pull request #23707 from jakevdp:stop-gradient-doc
PiperOrigin-RevId: 676876785
2024-09-20 09:48:08 -07:00