20467 Commits

Author SHA1 Message Date
Roy Frostig
c18f7916a3 bump shard count for random_lax_test
PiperOrigin-RevId: 629495786
2024-04-30 12:35:17 -07:00
jax authors
5d9b5adb17 Return to the win-2019 image for the windows_ci workflow.
Due to https://github.com/bazelbuild/bazel/issues/18592

PiperOrigin-RevId: 629459786
2024-04-30 10:40:27 -07:00
jax authors
e680c495a2 Try fixing the MSVC path for windows_ci workflow once more.
PiperOrigin-RevId: 629436690
2024-04-30 09:26:27 -07:00
jax authors
eeca8d81b9 Fix example in mosaic tpu dialect layout.h
PiperOrigin-RevId: 629424833
2024-04-30 08:42:54 -07:00
Chris Jones
20a8e2a6ec Allow replacing jaxpr debug_info with None.
The existing implementation of `Jaxpr.replace` would ignore the parameter `debug_info=None`.

PiperOrigin-RevId: 629421610
2024-04-30 08:31:39 -07:00
jax authors
1bb8b0fe8c Point the windows_ci workflow to the correct VC directory.
PiperOrigin-RevId: 629418384
2024-04-30 08:18:41 -07:00
Adam Paszke
9bc1449588 [Mosaic GPU] Fix the diagnostic dump infrastructure
PiperOrigin-RevId: 629405315
2024-04-30 07:29:46 -07:00
jax authors
bc0fa9b0e3 Use the windows-2022 image for running the windows_ci workflow.
This should help reduce the number of MSVC-related errors.

PiperOrigin-RevId: 629404922
2024-04-30 07:24:43 -07:00
Adam Paszke
4051ac2a2f [Mosaic GPU] Only call kernel initializer from inside a custom call
XLA:GPU custom call design is far from ideal, as there's apparently no way to figure
out the CUDA context that will be used to run an HLO module before the custom call is
first called. So, we can't preload the kernel onto the GPU, or else we'll get invalid
handle errors due to the load and launch happening in different CUDA contexts...

Also fix up build_wheel.py to match the rename of the runtime lib.

PiperOrigin-RevId: 629401858
2024-04-30 07:10:05 -07:00
jax authors
649e0521ff Merge pull request #21001 from hawkinsp:warnings
PiperOrigin-RevId: 629397343
2024-04-30 06:48:19 -07:00
Peter Hawkins
7018e0b085 Fix warnings in CI from compilation_cache_test.
Whether the jitted function __eq__ is cached changes the number of warnings we expect.
2024-04-30 13:40:35 +00:00
Blake Hechtman
5b996f7680 [JAX:MOSAIC] Support transposes that are smaller than the transpose unit and infer native layout to avoid unsupported relayouts.
PiperOrigin-RevId: 629289267
2024-04-29 22:03:32 -07:00
Roy Frostig
69878c4924 remove Threefry GPU kernel
Cursory timing of `jit(lambda key: random.bits(key, (8, 128 * 128)))` suggests that this is a slight compile-time efficiency loss, taking roughly ~1.25x the time to compile compared to the removed kernel-based lowering. This seems worth the memory improvement, and one kernel fewer to maintain.

PiperOrigin-RevId: 629282330
2024-04-29 21:29:38 -07:00
jax authors
5c20751ca1 Update XLA dependency to use revision
1de531de59.

PiperOrigin-RevId: 629263505
2024-04-29 19:54:30 -07:00
jax authors
343e18fcb6 Merge pull request #20985 from jakevdp:reshape-signature
PiperOrigin-RevId: 629199023
2024-04-29 15:10:42 -07:00
jax authors
52a0cfc35a Merge pull request #20875 from jakevdp:no-typehints
PiperOrigin-RevId: 629188985
2024-04-29 14:34:00 -07:00
Jake VanderPlas
55e3be65b7 Remove dependency on sphinx_autodoc_typehints 2024-04-29 14:19:46 -07:00
Peter Hawkins
7fe2f5e1e7 Remove jax_triton as a BUILD dependency of pallas_test.py.
We don't need the external jax_triton package any more to lower Pallas code.

PiperOrigin-RevId: 629163749
2024-04-29 13:12:32 -07:00
Jake VanderPlas
89df08afe0 test: fix reshape signature test for NumPy 2.1 2024-04-29 11:48:10 -07:00
jax authors
b44e9bfe66 Merge pull request #20923 from pearu:pearu/asinh-2
PiperOrigin-RevId: 629130313
2024-04-29 11:31:59 -07:00
jax authors
1b5c49e752 Merge pull request #20822 from jakevdp:scipy-special-doc
PiperOrigin-RevId: 629129623
2024-04-29 11:27:23 -07:00
jax authors
94997bd548 Merge pull request #20799 from jakevdp:scipy-dct-doc
PiperOrigin-RevId: 629126171
2024-04-29 11:18:45 -07:00
jax authors
0e62c4cfcc Merge pull request #20976 from gnecula:export_fix_symtab
PiperOrigin-RevId: 629125592
2024-04-29 11:14:04 -07:00
Justin Fu
5d2e8615af Reverts 7844bac5d220b41253495cacf719f61905f46925
PiperOrigin-RevId: 629123629
2024-04-29 11:13:43 -07:00
jax authors
be179ef7f6 Merge pull request #20982 from jakevdp:issubdtype-cache
PiperOrigin-RevId: 629122960
2024-04-29 11:08:36 -07:00
jax authors
3de7c3f3f2 Merge pull request #20983 from google:dependabot/github_actions/actions/download-artifact-4.1.7
PiperOrigin-RevId: 629122570
2024-04-29 11:03:56 -07:00
Jake VanderPlas
74f1d8897c DOC: add manual documentation to jax.scipy.special functions.
This lets us give more implementation-specific information, and
lets us avoid a needless dependency on scipy.
2024-04-29 10:58:07 -07:00
dependabot[bot]
bc354dfeaf
Bump actions/download-artifact from 4.1.6 to 4.1.7
Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 4.1.6 to 4.1.7.
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](9c19ed7fe5...65a9edc588)

---
updated-dependencies:
- dependency-name: actions/download-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-04-29 17:16:47 +00:00
George Necula
c6e30de49e Fix mlir.merge_mlir_modules to properly remember the inlined symbols
Previously, the `merge_mlir_modules` renamed the inlined symbols to
ensure they do not clash with the symbols in the destination module.
However, the inlined symbols were not inserted in the symbol table
so a conflict could arise later.
2024-04-29 19:56:41 +03:00
Jake VanderPlas
08d45b43f1 Cache (most) calls to dtypes.issubdtype 2024-04-29 09:49:40 -07:00
jax authors
275f565970 Merge pull request #20957 from jakevdp:transpose-doc
PiperOrigin-RevId: 629093069
2024-04-29 09:37:00 -07:00
Adam Paszke
8fd9c2f160 [Mosaic GPU] Add the flash attention example
PiperOrigin-RevId: 629092401
2024-04-29 09:31:30 -07:00
Jake VanderPlas
75921162ab jax.scipy.fft: manually document functions to avoid scipy import 2024-04-29 09:23:43 -07:00
Jake VanderPlas
ba540ca735 Finalize deprecation of jnp.where keyword arguments
PiperOrigin-RevId: 629086639
2024-04-29 09:10:03 -07:00
jax authors
dfc17187ab Merge pull request #20977 from hawkinsp:warnings
PiperOrigin-RevId: 629079071
2024-04-29 08:41:02 -07:00
Jake VanderPlas
b55e69fb62 DOC: improve docs of transpose & matrix_transpose 2024-04-29 08:17:50 -07:00
Peter Hawkins
6ae01247f0 Fix pytest failures from compilation cache test.
The names of the functions in the compilation cache tests changed, causing warnings emitted by that test to become errors.
2024-04-29 11:08:07 -04:00
Adam Paszke
32cb7c3f94 [Mosaic GPU] Stop using the MLIR CUDA runtime
This ports the remaining few functions we depended on to the Mosaic GPU runtime.
This has the additional benefit of avoiding the expensive driver calls to determine
maximum SMEM bounds that the MLIR runtime does at every kernel launch.

PiperOrigin-RevId: 629069842
2024-04-29 08:04:51 -07:00
jax authors
d92d9394ae Merge pull request #20941 from jakevdp:where-doc
PiperOrigin-RevId: 629039271
2024-04-29 05:39:33 -07:00
Adam Paszke
97628420a3 [Mosaic GPU] Use the profiler to compute approximate matmul TFLOPs
PiperOrigin-RevId: 629036644
2024-04-29 05:26:07 -07:00
Adam Paszke
8741ab2f25 Fix imports in Mosaic GPU examples
PiperOrigin-RevId: 629003217
2024-04-29 02:28:56 -07:00
jax authors
47fdc7b08f Update XLA dependency to use revision
2f5eac7ddd.

PiperOrigin-RevId: 628762449
2024-04-27 20:33:55 -07:00
Yash Katariya
1956ff7d7b Add specialize on jax.jit so that we can delete the duplicate code in jax.make_jaxpr.
You can now do (in addition to make_jaxpr): `jax.jit(f).specialize(*args, **kwargs) -> stages.Specialized`

PiperOrigin-RevId: 628748620
2024-04-27 18:58:16 -07:00
Sergei Lebedev
06760511b2 Added elementwise_inline_asm to Pallas GPU
The new API generalizes approx_tanh, which was implemented via the
ElementwiseInlineAsmOp from the Triton IR.

PiperOrigin-RevId: 628703905
2024-04-27 12:07:06 -07:00
jax authors
cd6eeea9e3 Update XLA dependency to use revision
05386ac5da.

PiperOrigin-RevId: 628591950
2024-04-26 21:17:26 -07:00
Yash Katariya
755f350910 Clean up some code in pxla.py that deals with jaxpr and avals. Lift the discharging of refs into a separate function and remove global_in_avals argument from lower_sharding_computation
PiperOrigin-RevId: 628564679
2024-04-26 18:29:18 -07:00
jax authors
d9b75350b7 Adds rewrite patterns for arith.{cmpi,select} and tensor.splat as sources to a vector.transfer_read op.
PiperOrigin-RevId: 628561147
2024-04-26 18:11:18 -07:00
Justin Fu
0b5f3f85be Remove explicit pallas trace_start/trace_stop primitives. These are now automatically inserted with the usage of jax.named_scope.
PiperOrigin-RevId: 628553413
2024-04-26 17:34:04 -07:00
jax authors
60af2262fd Merge pull request #20959 from dlwh:tensorstore_compute_args
PiperOrigin-RevId: 628521131
2024-04-26 15:09:06 -07:00
jax authors
b2654c08f9 Improve performance of SVD when batched by avoiding several cond() constructs.
This also simplifies the code by not special casing the code for all-zero inputs.

PiperOrigin-RevId: 628518807
2024-04-26 15:00:50 -07:00