Roy Frostig
c18f7916a3
bump shard count for random_lax_test
...
PiperOrigin-RevId: 629495786
2024-04-30 12:35:17 -07:00
jax authors
5d9b5adb17
Return to the win-2019
image for the windows_ci
workflow.
...
Due to https://github.com/bazelbuild/bazel/issues/18592
PiperOrigin-RevId: 629459786
2024-04-30 10:40:27 -07:00
jax authors
e680c495a2
Try fixing the MSVC path for windows_ci
workflow once more.
...
PiperOrigin-RevId: 629436690
2024-04-30 09:26:27 -07:00
jax authors
eeca8d81b9
Fix example in mosaic tpu dialect layout.h
...
PiperOrigin-RevId: 629424833
2024-04-30 08:42:54 -07:00
Chris Jones
20a8e2a6ec
Allow replacing jaxpr debug_info
with None
.
...
The existing implementation of `Jaxpr.replace` would ignore the parameter `debug_info=None`.
PiperOrigin-RevId: 629421610
2024-04-30 08:31:39 -07:00
jax authors
1bb8b0fe8c
Point the windows_ci
workflow to the correct VC directory.
...
PiperOrigin-RevId: 629418384
2024-04-30 08:18:41 -07:00
Adam Paszke
9bc1449588
[Mosaic GPU] Fix the diagnostic dump infrastructure
...
PiperOrigin-RevId: 629405315
2024-04-30 07:29:46 -07:00
jax authors
bc0fa9b0e3
Use the windows-2022
image for running the windows_ci
workflow.
...
This should help reduce the number of MSVC-related errors.
PiperOrigin-RevId: 629404922
2024-04-30 07:24:43 -07:00
Adam Paszke
4051ac2a2f
[Mosaic GPU] Only call kernel initializer from inside a custom call
...
XLA:GPU custom call design is far from ideal, as there's apparently no way to figure
out the CUDA context that will be used to run an HLO module before the custom call is
first called. So, we can't preload the kernel onto the GPU, or else we'll get invalid
handle errors due to the load and launch happening in different CUDA contexts...
Also fix up build_wheel.py to match the rename of the runtime lib.
PiperOrigin-RevId: 629401858
2024-04-30 07:10:05 -07:00
jax authors
649e0521ff
Merge pull request #21001 from hawkinsp:warnings
...
PiperOrigin-RevId: 629397343
2024-04-30 06:48:19 -07:00
Peter Hawkins
7018e0b085
Fix warnings in CI from compilation_cache_test.
...
Whether the jitted function __eq__ is cached changes the number of warnings we expect.
2024-04-30 13:40:35 +00:00
Blake Hechtman
5b996f7680
[JAX:MOSAIC] Support transposes that are smaller than the transpose unit and infer native layout to avoid unsupported relayouts.
...
PiperOrigin-RevId: 629289267
2024-04-29 22:03:32 -07:00
Roy Frostig
69878c4924
remove Threefry GPU kernel
...
Cursory timing of `jit(lambda key: random.bits(key, (8, 128 * 128)))` suggests that this is a slight compile-time efficiency loss, taking roughly ~1.25x the time to compile compared to the removed kernel-based lowering. This seems worth the memory improvement, and one kernel fewer to maintain.
PiperOrigin-RevId: 629282330
2024-04-29 21:29:38 -07:00
jax authors
5c20751ca1
Update XLA dependency to use revision
...
1de531de59
.
PiperOrigin-RevId: 629263505
2024-04-29 19:54:30 -07:00
jax authors
343e18fcb6
Merge pull request #20985 from jakevdp:reshape-signature
...
PiperOrigin-RevId: 629199023
2024-04-29 15:10:42 -07:00
jax authors
52a0cfc35a
Merge pull request #20875 from jakevdp:no-typehints
...
PiperOrigin-RevId: 629188985
2024-04-29 14:34:00 -07:00
Jake VanderPlas
55e3be65b7
Remove dependency on sphinx_autodoc_typehints
2024-04-29 14:19:46 -07:00
Peter Hawkins
7fe2f5e1e7
Remove jax_triton as a BUILD dependency of pallas_test.py.
...
We don't need the external jax_triton package any more to lower Pallas code.
PiperOrigin-RevId: 629163749
2024-04-29 13:12:32 -07:00
Jake VanderPlas
89df08afe0
test: fix reshape signature test for NumPy 2.1
2024-04-29 11:48:10 -07:00
jax authors
b44e9bfe66
Merge pull request #20923 from pearu:pearu/asinh-2
...
PiperOrigin-RevId: 629130313
2024-04-29 11:31:59 -07:00
jax authors
1b5c49e752
Merge pull request #20822 from jakevdp:scipy-special-doc
...
PiperOrigin-RevId: 629129623
2024-04-29 11:27:23 -07:00
jax authors
94997bd548
Merge pull request #20799 from jakevdp:scipy-dct-doc
...
PiperOrigin-RevId: 629126171
2024-04-29 11:18:45 -07:00
jax authors
0e62c4cfcc
Merge pull request #20976 from gnecula:export_fix_symtab
...
PiperOrigin-RevId: 629125592
2024-04-29 11:14:04 -07:00
Justin Fu
5d2e8615af
Reverts 7844bac5d220b41253495cacf719f61905f46925
...
PiperOrigin-RevId: 629123629
2024-04-29 11:13:43 -07:00
jax authors
be179ef7f6
Merge pull request #20982 from jakevdp:issubdtype-cache
...
PiperOrigin-RevId: 629122960
2024-04-29 11:08:36 -07:00
jax authors
3de7c3f3f2
Merge pull request #20983 from google:dependabot/github_actions/actions/download-artifact-4.1.7
...
PiperOrigin-RevId: 629122570
2024-04-29 11:03:56 -07:00
Jake VanderPlas
74f1d8897c
DOC: add manual documentation to jax.scipy.special functions.
...
This lets us give more implementation-specific information, and
lets us avoid a needless dependency on scipy.
2024-04-29 10:58:07 -07:00
dependabot[bot]
bc354dfeaf
Bump actions/download-artifact from 4.1.6 to 4.1.7
...
Bumps [actions/download-artifact](https://github.com/actions/download-artifact ) from 4.1.6 to 4.1.7.
- [Release notes](https://github.com/actions/download-artifact/releases )
- [Commits](9c19ed7fe5...65a9edc588
)
---
updated-dependencies:
- dependency-name: actions/download-artifact
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
2024-04-29 17:16:47 +00:00
George Necula
c6e30de49e
Fix mlir.merge_mlir_modules to properly remember the inlined symbols
...
Previously, the `merge_mlir_modules` renamed the inlined symbols to
ensure they do not clash with the symbols in the destination module.
However, the inlined symbols were not inserted in the symbol table
so a conflict could arise later.
2024-04-29 19:56:41 +03:00
Jake VanderPlas
08d45b43f1
Cache (most) calls to dtypes.issubdtype
2024-04-29 09:49:40 -07:00
jax authors
275f565970
Merge pull request #20957 from jakevdp:transpose-doc
...
PiperOrigin-RevId: 629093069
2024-04-29 09:37:00 -07:00
Adam Paszke
8fd9c2f160
[Mosaic GPU] Add the flash attention example
...
PiperOrigin-RevId: 629092401
2024-04-29 09:31:30 -07:00
Jake VanderPlas
75921162ab
jax.scipy.fft: manually document functions to avoid scipy import
2024-04-29 09:23:43 -07:00
Jake VanderPlas
ba540ca735
Finalize deprecation of jnp.where keyword arguments
...
PiperOrigin-RevId: 629086639
2024-04-29 09:10:03 -07:00
jax authors
dfc17187ab
Merge pull request #20977 from hawkinsp:warnings
...
PiperOrigin-RevId: 629079071
2024-04-29 08:41:02 -07:00
Jake VanderPlas
b55e69fb62
DOC: improve docs of transpose & matrix_transpose
2024-04-29 08:17:50 -07:00
Peter Hawkins
6ae01247f0
Fix pytest failures from compilation cache test.
...
The names of the functions in the compilation cache tests changed, causing warnings emitted by that test to become errors.
2024-04-29 11:08:07 -04:00
Adam Paszke
32cb7c3f94
[Mosaic GPU] Stop using the MLIR CUDA runtime
...
This ports the remaining few functions we depended on to the Mosaic GPU runtime.
This has the additional benefit of avoiding the expensive driver calls to determine
maximum SMEM bounds that the MLIR runtime does at every kernel launch.
PiperOrigin-RevId: 629069842
2024-04-29 08:04:51 -07:00
jax authors
d92d9394ae
Merge pull request #20941 from jakevdp:where-doc
...
PiperOrigin-RevId: 629039271
2024-04-29 05:39:33 -07:00
Adam Paszke
97628420a3
[Mosaic GPU] Use the profiler to compute approximate matmul TFLOPs
...
PiperOrigin-RevId: 629036644
2024-04-29 05:26:07 -07:00
Adam Paszke
8741ab2f25
Fix imports in Mosaic GPU examples
...
PiperOrigin-RevId: 629003217
2024-04-29 02:28:56 -07:00
jax authors
47fdc7b08f
Update XLA dependency to use revision
...
2f5eac7ddd
.
PiperOrigin-RevId: 628762449
2024-04-27 20:33:55 -07:00
Yash Katariya
1956ff7d7b
Add specialize
on jax.jit
so that we can delete the duplicate code in jax.make_jaxpr
.
...
You can now do (in addition to make_jaxpr): `jax.jit(f).specialize(*args, **kwargs) -> stages.Specialized`
PiperOrigin-RevId: 628748620
2024-04-27 18:58:16 -07:00
Sergei Lebedev
06760511b2
Added elementwise_inline_asm to Pallas GPU
...
The new API generalizes approx_tanh, which was implemented via the
ElementwiseInlineAsmOp from the Triton IR.
PiperOrigin-RevId: 628703905
2024-04-27 12:07:06 -07:00
jax authors
cd6eeea9e3
Update XLA dependency to use revision
...
05386ac5da
.
PiperOrigin-RevId: 628591950
2024-04-26 21:17:26 -07:00
Yash Katariya
755f350910
Clean up some code in pxla.py that deals with jaxpr and avals. Lift the discharging of refs into a separate function and remove global_in_avals argument from lower_sharding_computation
...
PiperOrigin-RevId: 628564679
2024-04-26 18:29:18 -07:00
jax authors
d9b75350b7
Adds rewrite patterns for arith.{cmpi,select}
and tensor.splat
as sources to a vector.transfer_read op.
...
PiperOrigin-RevId: 628561147
2024-04-26 18:11:18 -07:00
Justin Fu
0b5f3f85be
Remove explicit pallas trace_start/trace_stop primitives. These are now automatically inserted with the usage of jax.named_scope.
...
PiperOrigin-RevId: 628553413
2024-04-26 17:34:04 -07:00
jax authors
60af2262fd
Merge pull request #20959 from dlwh:tensorstore_compute_args
...
PiperOrigin-RevId: 628521131
2024-04-26 15:09:06 -07:00
jax authors
b2654c08f9
Improve performance of SVD when batched by avoiding several cond() constructs.
...
This also simplifies the code by not special casing the code for all-zero inputs.
PiperOrigin-RevId: 628518807
2024-04-26 15:00:50 -07:00