23 Commits

Author SHA1 Message Date
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
Peter Hawkins
8f2f4b45fb Annotate several tests as thread-unsafe.
PiperOrigin-RevId: 714117130
2025-01-10 11:24:39 -08:00
jax authors
231967fdb5 [AutoPGLE] Explicitly ignore host callback pointers
Before this change users had to specify remove_custom_partitioning_ptr_from_cache_key config flag when using AutoPGLE.

PiperOrigin-RevId: 700289965
2024-11-26 04:06:15 -08:00
jax authors
bf0150bb22 [JAX] Ignore xla_gpu_experimental_autotune_cache_mode when calculating module hash.
PiperOrigin-RevId: 698789020
2024-11-21 08:21:32 -08:00
Keshav
7c660c4ea0 Squashed commit of the following:
commit 1abe9559d1ba7a6ec4e2081c52ebdf0eef6b5e56
Merge: 1e1cc3e07 1b2ba9d1c
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Sep 10 09:42:04 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 1e1cc3e0733cca77e2f1ee928f96edcf63f673cf
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Sep 10 09:37:22 2024 -0700

    added comment

commit 631c41fcbdbbac864fadd72c984b07801872f218
Merge: b93b52f27 ce3ea109a
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Aug 21 08:54:00 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit b93b52f27aacf7f58eba914a91810b5d0ac06316
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Aug 20 19:00:08 2024 -0700

    remove stray breakpoint

commit 9ee0842ea98557bcdca0ecfd9031a8ea5274e9a4
Merge: 799e359a5 be53ee10b
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Aug 7 18:09:19 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 799e359a522acd1a83dd7868a3a9278e189664f6
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Aug 7 17:31:27 2024 -0700

    added tests and minor changes

    fix

commit c973004493f633526b14a6b5acb3afe50d58c977
Merge: 5900969cc b3924da2a
Author: Keshav <keshavb@nvidia.com>
Date:   Thu Aug 1 11:28:59 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 5900969cc9178bf3629baa49c6a300446bf6d4a9
Author: Keshav <keshavb@nvidia.com>
Date:   Thu Aug 1 11:20:52 2024 -0700

    minor edits

commit a7cc85a1cb8ddd07b783cc538f25c56f5fb78543
Merge: 89b876270 091eba195
Author: Keshav <keshavb@nvidia.com>
Date:   Mon Jul 29 14:17:13 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit 89b876270bf5f16dc10c2f8700d69715752ca184
Author: Keshav <keshavb@nvidia.com>
Date:   Mon Jul 29 14:11:39 2024 -0700

    native IR traversal instead of string manipulation

commit 3b161a414d9579c50e1902047dbd45bac840a767
Author: Keshav <keshavb@nvidia.com>
Date:   Sun Jul 28 20:12:30 2024 -0700

    longer match string and string search optimization

commit 224ee59d2115ec43000105b97bd6e73c40777ab9
Merge: c7664aa61 6a7822a73
Author: Keshav <keshavb@nvidia.com>
Date:   Sun Jul 28 17:08:29 2024 -0700

    Merge remote-tracking branch 'upstream/main' into rm_custom_partitioning_pointer

commit c7664aa61fa9cec55fba9d5ee1d3ffb146a4c2b1
Author: Keshav <keshavb@nvidia.com>
Date:   Sun Jul 28 17:07:04 2024 -0700

    remove custom partitioning ptr from pre-compiled hlo during cache key computation

linter fixes

more linter fixes

more linter fixes

alternate imports
2024-09-10 17:30:08 -07:00
jax authors
0b28a4b168 Strip device_assignment on GPU platform.
This makes the hash invariant on a multi-process case.

PiperOrigin-RevId: 617093247
2024-03-19 01:50:39 -07:00
jax authors
596756f715 Enhance compilation cache key generation with a custom hook.
The custom hook is called every time the cache key is
generated. It can be programmed to add a custom string that
is then hashed as part of the cache key.

Testing: test workloads.
PiperOrigin-RevId: 610586945
2024-02-26 18:18:35 -08:00
Peter Hawkins
aad02dba7e Increase minimum jaxlib version to 0.4.20.
jaxlib 0.4.20 has xla_extension_version 210 and mlir_api_version 54.

PiperOrigin-RevId: 609094229
2024-02-21 12:58:57 -08:00
jax authors
32c99f627e Remove the old cache-key generation code.
We have switched to the new cache-key generation code and
it is stable. Clean up the old code.

Note: since we are still falling back to hashing devices +
platform is the PjRtTopologyDescription serialization has not
been implemented by a backend, we retain those for now.

Testing: test workload.
PiperOrigin-RevId: 590378036
2023-12-12 16:34:32 -08:00
Sharad Vikram
045a9ef1ef Disable all_gather_test on non-v5e TPUs
Also consolidate logic for selectively enabling TPU tests on TPU versions

PiperOrigin-RevId: 588597889
2023-12-06 17:47:17 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
jax authors
2bb2aa1112 Factor LIBTPU_INIT_ARGS into the compilation cache key.
Workloads that set the environment variable LIBTPU_INIT_ARGS
expect that the cache key will be invalidated if the value
of the variable changes between runs. Today, LIBTPU_INIT_ARGS
is not used in the cache key computation. The fix is to factor
it in similar to what is done with the XLA_FLAGS environment
variable.

Testing: new unit test; test workloads.
PiperOrigin-RevId: 582423420
2023-11-14 13:31:08 -08:00
jax authors
db07f40233 Fall-back to original device/backend hashing if topology-desc is unavailable.
The original cache-key generation algorithm hashed devices and backend as
part of generating the key. The new algorithm relies on serialized
PjRtTopologyDescription instead. Not all backends support serialized
PjRtTopologyDescription. Fall back to the original device/backend hashing
if the needed backend does not support it.

Testing: unit testing + test workloads.
PiperOrigin-RevId: 579039803
2023-11-02 18:43:48 -07:00
Skye Wanderman-Milne
58c86064f6 [PJRT:C] Implement PjRtCApiClient::GetTopologyDescription
PiperOrigin-RevId: 577249826
2023-10-27 11:03:04 -07:00
Sergei Lebedev
1079304259 MAINT Do not import the config object in JAX internals
The longer term goal here is to move away from having the config object as
part of the public API and migrate towards module-level functions instead.

Note that we can preserve the dynamic attribute lookup behavior of the
config object via a module-level `__getattr__`
2023-10-18 10:55:13 +01:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Peter Hawkins
69da839358 Remove test code that checks for the se_tpu runtime.
This runtime no longer exists.

PiperOrigin-RevId: 568242078
2023-09-25 09:30:07 -07:00
jax authors
23e4f0b471 Hash serialized topology description for new cache key generation.
The original cache key generation hashes devices and backend. This
is not future proof: it does not work for accelerators other than
TPUs. Change this to use the serialized version of
PjRtTopologyDescription which is supported for all accelerators.

Note:
. CPU and PjRt C API not supported as yet.
. Stream Executor will not be supported.

Testing: revised unit test.
PiperOrigin-RevId: 564461564
2023-09-11 12:08:26 -07:00
jax authors
c38f67043c Hash serialized CompileOptions for new cache key generation.
The original cache key generation hashes individual fields of
CompileOptions, ExecutableBuildOptions, and DebugOptions. This
is not future proof: when a field is added to any of these
structures, the corresponding hash needs to be added to the
cache key generation. The new cache key generation algorithm
hashes the serialized representation of CompileOptions.

Some DebugOptions do not affect the compilation result;
exclude them from the computation. If additional fields are
identified, they can be added; such additions will reduce
unnecessary cache misses.

Testing: revised unit test.
PiperOrigin-RevId: 561803875
2023-08-31 17:21:57 -07:00
Peter Hawkins
a259df0d76 Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 557116160
2023-08-15 06:39:46 -07:00
Adam Paszke
0228bf7d3c Fix MSAN errors in cache_key_test
The device_assignment array was never initialized, causing MSAN errors.
Replacing it with np.arange fixes the issue.

PiperOrigin-RevId: 553469463
2023-08-03 07:28:32 -07:00
Yash Katariya
4ddf6a9a54 Bump minimum_jaxlib_version to 0.4.14. xla_extension_version is 174 and mlir_api_version is 54
PiperOrigin-RevId: 552816893
2023-08-01 08:53:28 -07:00
jax authors
3b28d4e180 Refactor Jax compilation cache key generation.
This is in preparation for introducing a more robust key-generation
algorithm.

This refactoring does not introduce any change in behavior.

Testing: refactored unit tests and test workload.
PiperOrigin-RevId: 551744892
2023-07-27 23:01:00 -07:00