Some GPU tests in //tests/logging_test fail on Linux machines with an NVIDIA driver, but only when we use hermetic CUDA (i.e., CUDA is not installed on the machine itself).
Reason: the method `tsl::Env::Default()->GetExecutablePath()` doesn't work properly when Python is invoked with the command flag (`-c`). As a result, the subprocess couldn't get the path to the logging_test.py file and convert it to the runtime directory where the hermetic CUDA libraries are placed.
Solution: save the Python program to a file in the runtime directory, then run the script from that file.
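A hedged sketch of the workaround (the file name and payload below are illustrative, not the test's actual code): instead of passing the program via `-c`, write it to a file alongside the test and execute that file, so the child process can resolve its executable path.

```python
import os
import subprocess
import sys

program = 'print("hello from the subprocess")\n'  # hypothetical payload
script_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "subprocess_program.py")
with open(script_path, "w") as f:
    f.write(program)

# Running a file (rather than `python -c "<program>"`) lets the subprocess
# recover its own path and, from it, the runtime directory.
subprocess.run([sys.executable, script_path], check=True)
```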
PiperOrigin-RevId: 738152663
A couple of dummy transform inference rules needed to be added to contend with parts of the lowering that do not yet use the dialect, along with a transform inference rule for `memref.view`.
PiperOrigin-RevId: 738089782
XLA:GPU recently changed its endianness to little endian to better match LLVM
and the rest of the CUDA ecosystem, so we can lift the earlier restrictions.
PiperOrigin-RevId: 737934373
This reduces the chances of overflowing a 32-bit integer when computing tile indices.
Also add a unit test that reproduces the overflow with the previous implementation of `blocked_fold_in`.
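For illustration, a minimal sketch of the failure mode being guarded against (this is not the `blocked_fold_in` code itself, just the kind of index arithmetic that wraps in 32 bits):

```python
import jax
jax.config.update("jax_enable_x64", True)  # allow 64-bit integer arithmetic
import jax.numpy as jnp

# Flattening tile coordinates as row * stride silently wraps once the
# product exceeds 2**31 - 1 when computed in int32.
row, stride = jnp.int32(70_000), jnp.int32(70_000)
print(row * stride)                    # wraps: 4_900_000_000 mod 2**32
print(row.astype(jnp.int64) * stride)  # 4900000000, computed in 64 bits
```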
PiperOrigin-RevId: 737778853
There's no reason why two custom vmappable types cannot share the same spec_type. However, spec_types was a set, which could cause bugs and exceptions.
Suppose I register two vmappable data_types sharing the same spec_type, and then unregister one of the two. The spec_type is then no longer in the set, even though the second data_type still needs it. Moreover, an exception is raised if I try to unregister both vmappable types (the second call to spec_types.remove raises a KeyError).
When unregistering a data type, instead of removing its spec_type from the set, we regenerate the set from the remaining vmappable types.
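A minimal sketch of this bookkeeping, with hypothetical names rather than JAX's actual internals:

```python
vmappables: dict[type, type] = {}  # data_type -> spec_type
spec_types: set[type] = set()

def register_vmappable(data_type, spec_type):
    vmappables[data_type] = spec_type
    spec_types.add(spec_type)

def unregister_vmappable(data_type):
    del vmappables[data_type]
    # Regenerate the set instead of calling spec_types.remove(...): a shared
    # spec_type must survive while any registered data_type still uses it,
    # and a second .remove() of the same spec_type would raise KeyError.
    spec_types.clear()
    spec_types.update(vmappables.values())
```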
PiperOrigin-RevId: 737280270
This is an exact port of the current Python implementation to C++ for speed.
I am careful here not to alter the topological order we return in any way, although we may do so in a future change.
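For intuition, a hedged sketch of a deterministic topological sort in this spirit (not the actual implementation being ported):

```python
def toposort(out_nodes):
    """Emit each node after all of its dependencies, in a stable order."""
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.parents:  # hypothetical attribute
            visit(parent)
        order.append(node)

    # Visiting nodes and parents in their stored order makes the output a
    # pure function of the input -- the property an exact port must preserve.
    for node in out_nodes:
        visit(node)
    return order
```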
PiperOrigin-RevId: 737014989
* Account test totals correctly for dashboards
* Add blurb to the dev guide on skipping tests
* Remove extra newline
* Default to 0 if "skipped" isn't found (see the sketch below)
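A hedged sketch of that last fix, assuming a result-dictionary layout like the one below (the field names are hypothetical):

```python
def total_tests(results: dict) -> int:
    # Older result records may lack a "skipped" key entirely, so treat a
    # missing entry as zero when computing totals for the dashboards.
    return (results.get("passed", 0)
            + results.get("failed", 0)
            + results.get("skipped", 0))
```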
Co-authored-by: Mathew Odden <1471252+mrodden@users.noreply.github.com>
I did not update `jax.dlpack.SUPPORTED_DTYPES` because neither NumPy nor
TensorFlow currently support importing DLPack arrays with the new dtypes.
PiperOrigin-RevId: 736882904
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.
I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.
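A hedged sketch of what such a helper looks like (the real name and signature may differ):

```python
def for_each(f, *args):
    """Apply f across zipped argument lists purely for its side effects."""
    for xs in zip(*args):
        f(*xs)

# Usage: no throwaway list is built, unlike list(map(print, ...)).
for_each(print, ["a", "b"], [1, 2])  # prints "a 1" then "b 2"
```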
PiperOrigin-RevId: 736859418
Two PRs that support this feature have been submitted to stablehlo and openxla.
Now we can do the last step: enable it in JAX.
PiperOrigin-RevId: 736799241
For now, propagate transforms for `wgmma`. We do not handle `transpose` for
either operand yet.
The pass isn't called anywhere yet.
PiperOrigin-RevId: 736758754
Previously it was:
`ValueError: A single NamedSharding spec specification can map every mesh axis to at most one positional dimension, but PartitionSpec('x', 'x') has duplicate entries for x`
Now it is:
`TypeError: dot_general operation with inputs: i64[8@x,2], i64[2,8@x] produces an illegally sharded result: i64[8@x,8@x]`
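A hedged repro sketch under JAX's explicit-sharding mode; the API names below (`jax.make_mesh`, `jax.sharding.use_mesh`, `AxisType.Explicit`) assume a recent JAX release, and an 8-way `x` axis needs eight devices (e.g. `XLA_FLAGS=--xla_force_host_platform_device_count=8` on CPU):

```python
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

jax.config.update("jax_enable_x64", True)  # for the i64 inputs in the message

mesh = jax.make_mesh((8,), ("x",), axis_types=(AxisType.Explicit,))
with jax.sharding.use_mesh(mesh):
    a = jax.device_put(jnp.zeros((8, 2), jnp.int64), NamedSharding(mesh, P("x", None)))
    b = jax.device_put(jnp.zeros((2, 8), jnp.int64), NamedSharding(mesh, P(None, "x")))
    a @ b  # both output dims would be sharded over 'x' -> the TypeError above
```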
PiperOrigin-RevId: 736657644