22614 Commits

Author SHA1 Message Date
jax authors
550607a45d Merge pull request #23197 from jakevdp:quantile-docs
PiperOrigin-RevId: 667602295
2024-08-26 09:02:31 -07:00
jax authors
e3e0860184 Merge pull request #23228 from froystig:scanlen
PiperOrigin-RevId: 667251807
2024-08-24 22:51:45 -07:00
Roy Frostig
a9b41e9fe7 improve scan error message on non-concrete length argument
Specifically, make it speak concretely about the `length` argument.
2024-08-24 22:30:33 -07:00
jax authors
e9143623e0 Update XLA dependency to use revision
4bfb5c82a4.

PiperOrigin-RevId: 667177243
2024-08-24 14:34:48 -07:00
Justin Fu
7253b9ac8b [Pallas] Fix pallas interpret mode DMA test failures.
PiperOrigin-RevId: 666953373
2024-08-23 16:07:53 -07:00
jax authors
0c505b79b4 Merge pull request #23222 from mattjj:rafi
PiperOrigin-RevId: 666950244
2024-08-23 15:57:28 -07:00
jax authors
a2a351f88b Fix pallas int4->int8 conversion
PiperOrigin-RevId: 666939965
2024-08-23 15:19:02 -07:00
jax authors
6a5ca0bb52 Update XLA dependency to use revision
9738684ff8.

PiperOrigin-RevId: 666937202
2024-08-23 15:10:36 -07:00
Colin Gaffney
276c87eba0 Add a more helpful error message in create_hybrid_device_mesh for missing attribute `process_index` or `slice_index`.
PiperOrigin-RevId: 666928476
2024-08-23 14:42:48 -07:00
Matthew Johnson
670a648b7b add experimental jax.no_tracing context manager
2024-08-23 21:21:55 +00:00
Jake VanderPlas
9090b8a4f9 Better docs for jnp quantile & percentile
2024-08-23 13:38:20 -07:00
jax authors
c6c701e6a7 Merge pull request #23196 from jakevdp:register-deprecations
PiperOrigin-RevId: 666900363
2024-08-23 13:16:09 -07:00
jax authors
20d13abfa0 Update XLA dependency to use revision
b0d313b58e.

PiperOrigin-RevId: 666868666
2024-08-23 11:38:33 -07:00
jax authors
279977c61d Refactor hermetic CUDA flags and update --config=cuda to add CUDA dependencies both for bazel build and bazel test phases.
Add `--@local_config_cuda//cuda:override_include_cuda_libs` to override settings for TF wheel.

Forbid building TF wheel with `--@local_config_cuda//cuda:include_cuda_libs=true`

PiperOrigin-RevId: 666848518
2024-08-23 10:44:32 -07:00
Adam Paszke
be59f6ec47 [Mosaic GPU] Support tiled stores of arrays with fewer columns than swizzling
PiperOrigin-RevId: 666798285
2024-08-23 08:06:25 -07:00
Bart Chrzaszcz
71b7e78916 Add jax_test configs for shardy and enable it for pjit_test.py and fix any tests.
Tests fixed include:

- `test_globally_sharded_key_array_8x4_multi_device`
  - Issue was in `replicate_trailing_dims` where an `xc.OpSharding` was always created. Fixed by creating an equivalent SDY sharding.
- `test_aot_out_info`
  - Issue was that there was no mesh, since there weren't any NamedShardings. Fixed by not asserting that a mesh tuple exists in `lower_jaxpr_to_module` when adding the sdy MeshOp (there won't be any propagation).
- `test_concurrent_pjit`
  - In Shardy, if there was a tensor dimension of size 0, we'd emit a verification error if the dimension was sharded on an axis. But if the axis is of size 1, then JAX says this is okay. So have Shardy assume the same.
- `test_globally_sharded_key_array_result_8x4_single_device`
  - This test adds a WSC when no `mesh_shape_tuple` exists (`"sdy.sharding_constraint"(%8) <{sharding = #sdy.sharding<@mesh, [{?}, {?}, {}]>}>`), so we should create a mesh named `mesh` with a single device id in case it doesn't exist.
- `testLowerCostAnalysis`
  - This calls into `mlir_module_to_xla_computation` which calls its own MLIR parsing function in `//third_party/tensorflow/compiler/xla/python/mlir.cc`. Needed to register the SDY dialect in it.
- `testShardingConstraintWithArray`
  - This calls `.compiler_ir(dialect="hlo")` which calls `PyMlirModuleToXlaComputation` which converts the MLIR to HLO, but the Sdy dialect is still inside. Export it before converting it to HLO.

PiperOrigin-RevId: 666777167
2024-08-23 06:51:13 -07:00
Adam Paszke
f54e220430 [Mosaic GPU] Add support for short n dimension in WGMMA
PiperOrigin-RevId: 666766079
2024-08-23 06:08:37 -07:00
Adam Paszke
c76787571b [Mosaic GPU] Expose wait_parity on collective barrier
PiperOrigin-RevId: 666761011
2024-08-23 05:49:06 -07:00
Paweł Paruzel
c430b0c5e3 Activate QR Factorization to XLA's FFI
PiperOrigin-RevId: 666722604
2024-08-23 03:21:43 -07:00
Justin Fu
07767e81a0 [Pallas] Add support for casting to/from unsigned integer types.
PiperOrigin-RevId: 666663406
2024-08-22 23:57:17 -07:00
jax authors
b6f2840f2a Update XLA dependency to use revision
cc369fd42a.

PiperOrigin-RevId: 666518740
2024-08-22 15:45:32 -07:00
Adam Paszke
74d96eeb83 [Pallas TPU] Raise a clear error when trying to load/store to a non-SMEM/non-VMEM buffer
PiperOrigin-RevId: 666506411
2024-08-22 15:13:26 -07:00
Tongfei Guo
5da27432e1 [XLA:SPMD] Check that gather/scatter partitioning in the index-parallel case has matching index-parallel dimensions for operand and indices.
PiperOrigin-RevId: 666469705
2024-08-22 13:32:33 -07:00
Kevin Gleason
d72104de59 Use StableHLO filegroup for python APIs in jaxlib MLIR build.
PiperOrigin-RevId: 666450684
2024-08-22 12:36:39 -07:00
jax authors
b17c66450b Merge pull request #23183 from selamw1:roll_docstring
PiperOrigin-RevId: 666428485
2024-08-22 11:33:07 -07:00
selamw1
ef8532bff5 roll_docstring_added
roll_docstring_added

see_also_doc_fixed

examples_adjusted
2024-08-22 10:51:30 -07:00
Adam Paszke
498ddd50ef [Mosaic TPU] Allow overriding memory space assignment of kernel outputs
PiperOrigin-RevId: 666400770
2024-08-22 10:23:34 -07:00
jax authors
a247058327 Merge pull request #23146 from ROCm:rocm62-fixes
PiperOrigin-RevId: 666394394
2024-08-22 10:12:16 -07:00
jax authors
b9e6eb59be Merge pull request #22516 from kaixih:support_variable_seqlen
PiperOrigin-RevId: 666394369
2024-08-22 10:08:08 -07:00
Jake VanderPlas
2c221f2d5a Register several jax.numpy argument name deprecations
2024-08-22 09:41:53 -07:00
Mathew Odden
5c2ffa893f * Add conditional docker interactive mode
Interactive causes bazel to output more
useful info when running locally.

* Fix issue with rocm el8 repo urls

Work around quirk with rocm version
when it ends with 0

* Fix package name conflict

Ubu22 and higher have a package name conflict
between the debian versions and the AMD provided
versions.

* [ROCm] Use clang env
2024-08-22 10:08:41 -05:00
Adam Paszke
9c3f2dcefc [Mosaic GPU] Make CUDA context part of the hash key + replace kernel id with a SHA256 digest
XLA runtime creates a context per device, so we need to make sure that a kernel is loaded
separately on each device.

PiperOrigin-RevId: 666353098
2024-08-22 08:06:37 -07:00
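
The caching idea in commit 9c3f2dcefc above can be illustrated with a minimal sketch. This is not the Mosaic GPU implementation: the `KernelCache` class, the `fake_load` loader, and the string context names are hypothetical stand-ins. It only shows why keying on both the context and a SHA256 digest of the kernel contents loads a kernel once per context.

```python
import hashlib


class KernelCache:
    """Toy cache holding one loaded kernel per (context, kernel digest) pair."""

    def __init__(self, load_fn):
        # load_fn is a hypothetical loader: (context, kernel_bytes) -> handle.
        self._load_fn = load_fn
        self._entries = {}

    def get(self, context, kernel_bytes):
        # Identify the kernel by a digest of its contents, not a counter-style id.
        digest = hashlib.sha256(kernel_bytes).hexdigest()
        # The context is part of the key, so each context gets its own load.
        key = (context, digest)
        if key not in self._entries:
            self._entries[key] = self._load_fn(context, kernel_bytes)
        return self._entries[key]


loads = []  # records which contexts triggered an actual load


def fake_load(context, kernel_bytes):
    loads.append(context)
    return f"handle-{context}-{len(loads)}"


cache = KernelCache(fake_load)
a = cache.get("ctx0", b"kernel body")
b = cache.get("ctx0", b"kernel body")  # cache hit: same context, same digest
c = cache.get("ctx1", b"kernel body")  # different context: loaded again
```

With a key that omitted the context, the third call would have returned the handle loaded for `ctx0`, which is the situation the commit message describes avoiding.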
Dan Foreman-Mackey
b56ed8eedd Port GPU kernel for Householder transformation to FFI.
PiperOrigin-RevId: 666305682
2024-08-22 05:23:09 -07:00
Adam Paszke
0b4f64e002 [Mosaic GPU] Allow tile sizes to exceed dimension size
Otherwise, the dimension size still needs to be a multiple of tiling.

PiperOrigin-RevId: 666298624
2024-08-22 04:59:11 -07:00
Paweł Paruzel
4786930a4c Determine LAPACK workspace during Eigenvalue Kernels runtime
PiperOrigin-RevId: 666285759
2024-08-22 04:09:34 -07:00
Paweł Paruzel
a72d46c549 Ignore LAPACK info parameter for QR Factorization
The assumption is that QR Factorization will never fail from LAPACK's side because all necessary verification is happening right before the call.

PiperOrigin-RevId: 666241215
2024-08-22 01:38:38 -07:00
Krishna Haridasan
3713b966c2 Fix a potential segfault in triton kernel call caching
It is possible that a null pointer is inserted into the cache and not updated with a valid kernel call
in case there is an error later during initialization. This change updates the cache to store either
an error or a valid kernel call.

PiperOrigin-RevId: 666161091
2024-08-21 20:45:35 -07:00
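
The fix described in commit 3713b966c2 above, storing either an error or a valid kernel call in the cache, can be sketched as follows. This is an illustration under assumptions, not the jaxlib code: `CallCache`, `get_or_create`, and the builder functions are hypothetical names.

```python
class CallCache:
    """Toy cache whose entries are either ("ok", value) or ("error", exception)."""

    def __init__(self):
        self._entries = {}

    def get_or_create(self, key, build):
        if key not in self._entries:
            try:
                self._entries[key] = ("ok", build())
            except Exception as exc:
                # Record the failure itself instead of leaving an empty
                # placeholder that a later lookup could mistake for a value.
                self._entries[key] = ("error", exc)
        status, payload = self._entries[key]
        if status == "error":
            raise payload
        return payload


attempts = []  # counts how many times the failing builder actually ran


def failing_build():
    attempts.append(1)
    raise RuntimeError("initialization failed")


cache = CallCache()
last_error = None
for _ in range(2):
    try:
        cache.get_or_create("k", failing_build)
    except RuntimeError as exc:
        last_error = str(exc)

ok = cache.get_or_create("k2", lambda: "kernel-call")
```

Every lookup of a failed key re-raises the stored error, so no caller ever receives an uninitialized entry.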
jax authors
810a91968a Update XLA dependency to use revision
6cdbb866c6.

PiperOrigin-RevId: 666067632
2024-08-21 15:52:29 -07:00
Yash Katariya
08fc5c0243 Reverts abd442b12a967f1738691b145c557df5df555dcc
PiperOrigin-RevId: 666026942
2024-08-21 14:11:34 -07:00
Benjamin Kramer
0105254ab1 Unbreak Mosaic after 42944da5ba
PiperOrigin-RevId: 665973530
2024-08-21 11:59:09 -07:00
kaixih
558000df7c Support variable sequence lengths
2024-08-21 18:25:55 +00:00
Justin Fu
ce2306bbc1 [Pallas] Add interpret mode rules for semaphores (local signal, wait, read, DMAs).
PiperOrigin-RevId: 665953666
2024-08-21 11:11:11 -07:00
Yash Katariya
abd442b12a Reverts 1e3c079821f5b4811dff37235f1e776eef1b14e4
PiperOrigin-RevId: 665947283
2024-08-21 10:56:51 -07:00
jax authors
6a2a96c3b8 Merge pull request #23166 from ROCm:ci_bazel_build
PiperOrigin-RevId: 665935164
2024-08-21 10:27:51 -07:00
Zhuo Peng
9d1cc33e39 Relaxed the assertion for is_same_structure in jax2tf.call_tf so that tf_fun may mutate the structure of its input parameters.
PiperOrigin-RevId: 665919824
2024-08-21 09:49:39 -07:00
Adam Paszke
d3fd262c9c [Mosaic GPU] Replace block barriers with warpgroup barriers
Block barriers don't work in warp-specialized kernels.
Also, expose the `when` syntax sugar.

PiperOrigin-RevId: 665916133
2024-08-21 09:40:05 -07:00
jax authors
8c7e798bd2 Fix MSAN use-of-uninitialized-value failure in array_test
PiperOrigin-RevId: 665902448
2024-08-21 09:03:24 -07:00
Adam Paszke
ce3ea109a4 [Mosaic GPU] Add a fast type conversion from s8 vectors to bf16 vectors
Regular conversion instructions have a ridiculously low throughput on Hopper,
so replacing them with some bit tricks yields a much faster implementation.

Co-authored-by: Benjamin Chetioui <bchetioui@google.com>
PiperOrigin-RevId: 665893696
2024-08-21 08:39:24 -07:00
Dan Foreman-Mackey
d49d070f0e Skip shape polymorphism tests that are incompatible with released jaxlib version.
PiperOrigin-RevId: 665893050
2024-08-21 08:35:35 -07:00
Yash Katariya
1e3c079821 Strip primitive params from location info because the amount of metadata included leads to a huge HLO size increase and causes compilation cache misses in some other settings too.
PiperOrigin-RevId: 665879688
2024-08-21 07:55:46 -07:00