20322 Commits

Author SHA1 Message Date
Jake VanderPlas
1ea205be1c softmax: deprecate initial argument & always set to -inf internally 2024-04-10 10:23:21 -07:00
Roy Frostig
65034b3da4 add and populate jax.extend.core.primitives 2024-04-10 09:27:42 -07:00
jax authors
ff12b2ad63 Merge pull request #20685 from olupton:no-git-describe
PiperOrigin-RevId: 623515694
2024-04-10 09:20:46 -07:00
Pearu Peterson
fc04ba983c Workaround mpmath 1.3 bugs in tan and tanh evaluation at infinities 2024-04-10 18:26:07 +03:00
Jake VanderPlas
c7b7b01e63 remove test of deprecated jax.random.shuffle API
PiperOrigin-RevId: 623499655
2024-04-10 08:19:12 -07:00
Olli Lupton
b27b77446f _version_from_git_tree: avoid git describe 2024-04-10 12:47:47 +00:00
Sharad Vikram
abfbb0ae2b Add dynamic grid support to emit_pipeline
PiperOrigin-RevId: 623393190
2024-04-09 23:55:51 -07:00
Yash Katariya
0d8eb45c20 Remove the sharding and layout checks for non-DCE'd arguments during AOT safe call.
This is because the tracing, lowering and compilation caches do not register a miss if sharding/layout of a DCE'd arg changes when it's passed again to a jitted function.

This is not true for avals so that check still exists.

PiperOrigin-RevId: 623375760
2024-04-09 22:12:05 -07:00
jax authors
987e4b4198 Update XLA dependency to use revision
380e14ba78.

PiperOrigin-RevId: 623370886
2024-04-09 21:43:32 -07:00
Sai-Suraj-27
60cd5af67a Made the error messages when raising TypeError better. 2024-04-10 09:27:47 +05:30
Yash Katariya
69bf3b866c Don't do layout checks during compiled safe call on DCE'd args.
PiperOrigin-RevId: 623347380
2024-04-09 19:34:25 -07:00
jax authors
c09a45a1a0 Merge pull request #20673 from Micky774:sparse_test
PiperOrigin-RevId: 623326571
2024-04-09 17:29:38 -07:00
jax authors
c246a97ea8 Merge pull request #20644 from jakevdp:complex-astype
PiperOrigin-RevId: 623323549
2024-04-09 17:15:03 -07:00
Jevin Jiang
763c6ff9cc [Pallas] Fix typo in semaphore_wait error messages.
PiperOrigin-RevId: 623321130
2024-04-09 17:04:10 -07:00
Meekail Zain
a7737ca618 Clean up sparse test run conditions 2024-04-09 23:16:12 +00:00
Jake VanderPlas
e07325a672 Make complex_arr.astype(bool) follow NumPy's semantics 2024-04-09 16:15:59 -07:00
Rebecca Chen
d967c33915 Silence some pytype errors.
PiperOrigin-RevId: 623308993
2024-04-09 16:12:00 -07:00
Yue Sheng
f1ae6232e9 Fix token management for ordered side-effects.
Right now, when there are multiple devices, we shall get a output token from each device, but we only keep the token from `device_0` and replicate it across devices to get input tokens for next function call with ordered side-effects. This is fine on TPU/GPU, as they are essentially executed in sequence. But on CPU, they could run in parallel, so we need to make sure the dependency is set correctly.

PiperOrigin-RevId: 623296894
2024-04-09 15:25:21 -07:00
Henning Becker
9809aa1929 Move CUDA specific functions from asm_compiler to cuda_asm_compiler target
This avoids:
- a forward declaration of `GpuContext`
- the `:asm_compiler_header` header only target

The moved code is unchanged - I just move it from one
file to another and fix up includes and dependencies.

Note that this is adding just another `#ifdef` to the redzone allocator code. I will clean this up in a subsequent change.

PiperOrigin-RevId: 623285804
2024-04-09 14:43:41 -07:00
Sergei Lebedev
a205c9120a pallas_call now has only one way to pass compiler_params=
Previously, it was possible to do

    pallas_call(..., foo=42)

and also

    pallas_call(..., compiler_params=dict(foo=42))

PiperOrigin-RevId: 623277572
2024-04-09 14:23:20 -07:00
Sergei Lebedev
008f87d7a3 The compiler_params= argument of pl.pallas_call on GPU now uses "triton" to refer to Triton-specific parameters, instead of the repetitive "triton_params"
PiperOrigin-RevId: 623275152
2024-04-09 14:13:59 -07:00
jax authors
b865c5b1f9 [Pallas TPU] Convert pattern_match_while_to_fori_loop to return (Jaxpr, str) rather than throw exceptions.
Currently, pattern_match_while_to_fori_loop attempts to convert a while_loop jaxpr into a type of fori_loop which Pallas can lower.
To do so, it validates the conditions which would block the jaxpr from being lowered successfully. Because Pallas presently only supports "fori convertable" loops, this matching code also throws Exceptions when the supported conditions are violated.

In the near future, we aim to have support for more ordinary while loops -- but we still would like to perform this match-and-convert procedure when possible.
To facilitate that, this updates the error handling in pattern_match_while_to_fori_loop to simply return errors when hit, so the calling code can determine if they should be thrown.

PiperOrigin-RevId: 623274837
2024-04-09 14:04:25 -07:00
jax authors
967c38d53d Merge pull request #20666 from curlup:main
PiperOrigin-RevId: 623250005
2024-04-09 12:45:48 -07:00
David Dunleavy
cd2b91c398 Update references to TSL config_settings to their new home in XLA
PiperOrigin-RevId: 623249851
2024-04-09 12:36:10 -07:00
jax authors
afb775c168 Jax persistent compilation cache user guide.
This user guide covers using the cache on local filesystems
and Google Cloud.

PiperOrigin-RevId: 623236335
2024-04-09 11:48:33 -07:00
jax authors
828e60cc03 Update XLA dependency to use revision
7deb1d6a6c.

PiperOrigin-RevId: 623225065
2024-04-09 11:13:51 -07:00
Pavel T
44b47035ae
better unsupported indexing handling in lax_numpy.py 2024-04-09 14:09:35 -04:00
David Dunleavy
d18323f3c4 Move tsl/BUILD, tsl.bzl, and tsl.default.bzl to XLA
PiperOrigin-RevId: 623215553
2024-04-09 10:47:06 -07:00
jax authors
675fdba2d6 Merge pull request #20665 from jakevdp:fix-clip-test
PiperOrigin-RevId: 623213147
2024-04-09 10:37:33 -07:00
jax authors
28b81bef5f [Pallas TPU] Pallas while loop -> fori test.
PiperOrigin-RevId: 623204164
2024-04-09 10:11:06 -07:00
Jake VanderPlas
f6851397ac test: fix testClipStaticBounds complex warning 2024-04-09 09:13:26 -07:00
jax authors
77db7a60ed Merge pull request #20637 from jakevdp:array-api-scalar
PiperOrigin-RevId: 623184036
2024-04-09 09:04:12 -07:00
jax authors
f5cc272615 Merge pull request #20646 from ROCm:rcom-ci-tsl-path-fix
PiperOrigin-RevId: 623129753
2024-04-09 05:09:11 -07:00
jax authors
aad8fe6a60 Merge pull request #20661 from superbobry:main
PiperOrigin-RevId: 623129749
2024-04-09 04:58:48 -07:00
Sergei Lebedev
bad0f1431e Do not run Pallas GPU tests on Windows
Triton does not support Windows, so all bets are off.
2024-04-09 12:44:33 +01:00
Matteo Hessel
0b602c5c4d Add sparse_sigmoid to jax.nn
PiperOrigin-RevId: 623108517
2024-04-09 03:10:04 -07:00
jax authors
4d4151db8e Update XLA dependency to use revision
79eccb4fb4.

PiperOrigin-RevId: 623048502
2024-04-08 21:55:22 -07:00
Yash Katariya
d1b1d0b019 Reverts a1c8207caea8bbc323bbcfb7735768822a59f5ce
PiperOrigin-RevId: 623045488
2024-04-08 21:35:02 -07:00
Yash Katariya
a1c8207cae Add kwargs support to in_shardings argument of jax.jit.
Currently, we only support this case:

* If kwargs are specified, then all in_shardings should be specified as dict matching the kwargs. args and kwargs mixture is not allowed. Either everything are kwargs or args hence in_shardings is a dict or specified positionally.

Example:

```
@partial(jax.jit, in_shardings=dict(y=s2, x=s1))
def f(x, y):
  return x * 2, y * 2

f(x=arr, y=arr2)
```

Fixes https://github.com/google/jax/issues/17400

PiperOrigin-RevId: 623018032
2024-04-08 19:19:56 -07:00
Jake VanderPlas
1b3aea8205 Finalize the deprecation of the arr.device() method
The method has been emitting an DeprecationWarning since JAX v0.4.21, released December 2023. Existing uses can be replaced with `arr.devices()` or `arr.sharding`, depending on the context.

PiperOrigin-RevId: 623015500
2024-04-08 19:04:15 -07:00
Ruturaj4
97bf2d2bb8 [ROCm]: fix tsl path 2024-04-08 19:58:41 -05:00
jax authors
af3dcd2b12 Merge pull request #20649 from jakevdp:fix-spelling
PiperOrigin-RevId: 622967850
2024-04-08 15:27:26 -07:00
Jake VanderPlas
5115b89538 Fix typos in comments 2024-04-08 15:16:39 -07:00
jax authors
7b486f4381 Merge pull request #20643 from carlosgmartin:softmax_initial
PiperOrigin-RevId: 622949028
2024-04-08 14:25:05 -07:00
jax authors
0dd240854a Merge pull request #20648 from jakevdp:ruff-config
PiperOrigin-RevId: 622948780
2024-04-08 14:15:31 -07:00
Jake VanderPlas
d33144e298 CI: avoid deprecated ruff configurations 2024-04-08 14:05:22 -07:00
jax authors
374dd5e55b Make layout on Array a property instead of a cached_property.
PiperOrigin-RevId: 622936539
2024-04-08 13:31:20 -07:00
jax authors
4e889ce83b Merge pull request #20629 from Sai-Suraj-27:prefer_TypeErorr
PiperOrigin-RevId: 622930542
2024-04-08 13:09:22 -07:00
carlosgmartin
9c347b9be1 Let initial=-jnp.inf by default for nn.softmax and nn.log_softmax. 2024-04-08 15:47:29 -04:00
jax authors
6ca46723d7 Merge pull request #20578 from carlosgmartin:logsumexp_where
PiperOrigin-RevId: 622920357
2024-04-08 12:37:38 -07:00