The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.
This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
Also move the reading of jax config values to right before the client is created. Previously they were read before calling register_plugin, which happens during import and therefore before any call to jax.config.update.
The decorator in mock_gpu_test was used incorrectly: jtu.run_on_devices creates the client before jax.config.update is called, which is not desired. Removing the decorator will not cause CPU/TPU tests to fail, because the mesh checks num_shard against the number of devices in the client and skips the test if they do not match.
generate_pjrt_gpu_plugin_options is only used in places that do not require compatibility, so there is no need to update the xla_client version.
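The ordering problem can be illustrated with a small sketch. Everything here is a stand-in: `config` is a plain dict playing the role of jax.config, and `register_plugin_eager`, `register_plugin_lazy`, `make_client`, and the `num_devices` option are hypothetical names, not the real xla_bridge API.

```python
from typing import Callable, Dict

# Hypothetical stand-in for jax.config; users may update it after import.
config: Dict[str, int] = {"num_devices": 0}

_factories: Dict[str, Callable[[], dict]] = {}


def register_plugin_eager(name: str) -> None:
    # Old behavior: the option is read at registration (import) time, so any
    # later config update is silently missed.
    options = {"num_devices": config["num_devices"]}
    _factories[name] = lambda: {"name": name, "options": options}


def register_plugin_lazy(name: str) -> None:
    # New behavior: the option is read inside the factory, i.e. right before
    # the client is created.
    def factory() -> dict:
        options = {"num_devices": config["num_devices"]}
        return {"name": name, "options": options}

    _factories[name] = factory


def make_client(name: str) -> dict:
    return _factories[name]()
```

With both variants registered at "import" time and the config updated afterwards, only the lazy variant observes the new value.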
PiperOrigin-RevId: 611610915
XLA-AutoFDO is supported only for TPUs, so requesting the latest
profile version for non-TPU workloads is unnecessary and can delay
the completion of initialization.
Testing: test workload.
PiperOrigin-RevId: 584148686
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow-up to #18008.
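For illustration only, a statically-typed state object could look roughly like the sketch below. `Flag`, `define_int_flag`, `update`, and the flag name `example_timeout_secs` are hypothetical and are not the names used in jax.

```python
from dataclasses import dataclass
from typing import Any, Dict, Generic, TypeVar

T = TypeVar("T")

# Hypothetical backing store, standing in for the dynamic config namespace.
_values: Dict[str, Any] = {}


@dataclass(frozen=True)
class Flag(Generic[T]):
    """A typed handle to one config value; `.value` has a static type."""
    name: str
    default: T

    @property
    def value(self) -> T:
        return _values.get(self.name, self.default)


def define_int_flag(name: str, default: int) -> Flag[int]:
    return Flag(name, default)


def update(name: str, value: Any) -> None:
    _values[name] = value


# A type checker sees Flag[int] here, whereas a dynamic attribute lookup on a
# config object is opaque to it.
EXAMPLE_TIMEOUT_SECS = define_int_flag("example_timeout_secs", 30)
```

Call sites then read `EXAMPLE_TIMEOUT_SECS.value` instead of looking the name up dynamically.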
C API changes:
- Profiler C APIs are added in profiler_c_api.h.
- Add a PJRT C API extension for the profiler C APIs in pjrt_c_api_profiler_extension.h.
Framework changes:
- Add a plugin_tracer that calls profiler C APIs.
- Add a pybind method xla_client.profiler.register_plugin_profiler to register plugin_tracer with the plugin's PJRT_Api*.
- Update xla_bridge.register_plugin to call register_plugin_profiler to register profiler for that plugin.
PiperOrigin-RevId: 572027222
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delay the registration until the handler for an XLA platform is registered.
- Change register_plugin to load the PJRT plugin when register_plugin is called (instead of when a client is created), and let it return the loaded PJRT_Api*.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.
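The deferred registration described above can be sketched as follows. This is a simplified stand-in, not the xla_client implementation: the signatures are assumptions, and the handler here is just a Python callable taking a name and a target.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

# Targets waiting for a platform handler, keyed by platform name.
_pending: Dict[str, List[Tuple[str, object]]] = defaultdict(list)
# Handlers that know how to register a target for a given platform.
_handlers: Dict[str, Callable[[str, object], None]] = {}


def register_custom_call_target(name: str, fn: object, platform: str) -> None:
    handler = _handlers.get(platform)
    if handler is not None:
        # The platform handler already exists: register immediately.
        handler(name, fn)
    else:
        # Otherwise remember the target until the handler shows up.
        _pending[platform].append((name, fn))


def register_custom_call_handler(
    platform: str, handler: Callable[[str, object], None]
) -> None:
    _handlers[platform] = handler
    # Flush every target that was queued for this platform.
    for name, fn in _pending.pop(platform, []):
        handler(name, fn)
```

Targets registered before the handler are queued and flushed once it arrives; targets registered afterwards go through directly.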
PiperOrigin-RevId: 568265745
The original cache key generation hashes individual fields of
CompileOptions, ExecutableBuildOptions, and DebugOptions. This
is not future proof: when a field is added to any of these
structures, the corresponding hash needs to be added to the
cache key generation. The new cache key generation algorithm
hashes the serialized representation of CompileOptions.
Some DebugOptions do not affect the compilation result;
exclude them from the computation. If additional fields are
identified, they can be added; such additions will reduce
unnecessary cache misses.
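A minimal sketch of the approach, using a plain dict and JSON as stand-ins for CompileOptions and its proto serialization. The option names in `_IGNORED_DEBUG_OPTIONS` and the dict layout are hypothetical and chosen only for illustration.

```python
import hashlib
import json

# Hypothetical debug options assumed not to affect the compilation result.
_IGNORED_DEBUG_OPTIONS = frozenset({"dump_to", "log_hlo_text"})


def cache_key(compile_options: dict) -> str:
    """Hashes the whole serialized options, minus the ignored debug options."""
    opts = dict(compile_options)  # shallow copy; the caller's dict is untouched
    opts["debug_options"] = {
        k: v
        for k, v in opts.get("debug_options", {}).items()
        if k not in _IGNORED_DEBUG_OPTIONS
    }
    # Serializing everything means a newly added field is hashed automatically,
    # without touching this function.
    serialized = json.dumps(opts, sort_keys=True).encode("utf-8")
    return hashlib.sha256(serialized).hexdigest()
```

Adding a field to the options changes the key with no code change, while differences confined to the ignored options do not cause a cache miss.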
Testing: revised unit test.
PiperOrigin-RevId: 561803875
CompileOptions has two serialization mechanisms: Py pickle and
SerializeAsString. Neither mechanism serializes deterministically.
Deterministic serialization (also called idempotent serialization
or in-order serialization) ensures that a given structure
serializes to the same string repeatedly. Both these mechanisms
serialize by first generating the proto and then serializing it.
There are three points to note:
- Deterministic serialization will yield the same result
  even if proto map fields are in a different order. Thus
  map({"1": 1, "2": 2}) and map({"2": 2, "1": 1}) will
  serialize the same.
- Deterministic serialization does not yield the same
  result for repeated fields that are out of order. Thus,
  for message Foo { repeated string s = 1; },
  Foo{s: "1", s: "2"} will not result in the same
  serialization as Foo{s: "2", s: "1"}.
- Deterministic serialization applies only in the context
  of a given binary. It does not apply across releases.
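The first two points can be demonstrated with an analogy in plain Python. This uses JSON with sorted keys rather than protobuf, so it only mirrors the map-versus-repeated-field behavior and says nothing about the actual proto wire format; `serialize` is a hypothetical helper.

```python
import json


def serialize(obj) -> str:
    # sort_keys plays the role of deterministic map ordering; list order is
    # preserved, just as the order of a repeated field is.
    return json.dumps(obj, sort_keys=True)


map_a = {"1": 1, "2": 2}
map_b = {"2": 2, "1": 1}
repeated_a = {"s": ["1", "2"]}
repeated_b = {"s": ["2", "1"]}

maps_match = serialize(map_a) == serialize(map_b)
repeated_match = serialize(repeated_a) == serialize(repeated_b)
# Without key sorting, insertion order leaks into the output.
unsorted_match = json.dumps(map_a) == json.dumps(map_b)
```

Here `maps_match` is true, while `repeated_match` and `unsorted_match` are false.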
Testing: the original serialization code with the new unit
test fails as expected while the revised code does not.
PiperOrigin-RevId: 559492626
Before this change, PJRT_Plugin_Initialize was called in LoadPjrtPlugin, which is only used in the dynamic linking case. This change adds a bool and a method to check whether the plugin is initialized. The framework will explicitly check whether a plugin is initialized, and call InitializePjrtPlugin if it is not. This applies to both the static linking and dynamic linking cases.
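The check-then-initialize flow can be sketched in Python as below. This is an illustration of the pattern only: the registry and the function names `register`, `is_initialized`, `initialize`, and `ensure_initialized` are hypothetical, and the real logic lives in the C++ plugin code.

```python
from typing import Callable, Dict

# Hypothetical registry: plugin name -> initializer, plus an initialized flag.
_initializers: Dict[str, Callable[[], None]] = {}
_initialized: Dict[str, bool] = {}


def register(name: str, initializer: Callable[[], None]) -> None:
    # Registration covers both statically and dynamically linked plugins and
    # does not initialize anything by itself.
    _initializers[name] = initializer
    _initialized[name] = False


def is_initialized(name: str) -> bool:
    return _initialized[name]


def initialize(name: str) -> None:
    _initializers[name]()
    _initialized[name] = True


def ensure_initialized(name: str) -> None:
    # The framework-side check: initialize only if nobody has done so yet.
    if not is_initialized(name):
        initialize(name)
```

Calling `ensure_initialized` repeatedly runs the initializer at most once, regardless of how the plugin was linked.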
PiperOrigin-RevId: 557268670
Set the AutoFDO profile version specified in --jax_xla_profile_version
if non-zero. Otherwise, expect that there is a function set in
get_latest_profile_version that will return a non-zero profile version
that should be used. If this function is not set or it returns 0,
set -1 instead to indicate that no attempt should be made to retrieve
an AutoFDO profile later on.
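The selection rule above can be written as a small function. This is a sketch: `resolve_profile_version` and its parameters are hypothetical names, with `flag_value` standing for --jax_xla_profile_version and `get_latest` for the optional function set in get_latest_profile_version.

```python
from typing import Callable, Optional


def resolve_profile_version(
    flag_value: int, get_latest: Optional[Callable[[], int]]
) -> int:
    """Returns the AutoFDO profile version to set, or -1 for "do not fetch"."""
    if flag_value != 0:
        # An explicitly requested version always wins.
        return flag_value
    latest = get_latest() if get_latest is not None else 0
    # No function set, or it returned 0: signal that no retrieval should be
    # attempted later on.
    return latest if latest != 0 else -1
```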
Testing: updated unit tests.
PiperOrigin-RevId: 555333728
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.
Remove an unused top_k translation rule as well.
PiperOrigin-RevId: 554946059
This is similar to how the send/receive callbacks are implemented.
Update make_c_api_client to take key-value get/put callbacks generated from the distributed client, and options for node_id and num_nodes.
PiperOrigin-RevId: 543441403
The plugin is expected to call jax._src.xla_bridge.register_plugin with its plugin_name, priority (defaults to 400), path to the .so file, and optional create options in its initialize() method.
The logic to register a plugin from ENV is kept to facilitate development with ENV.
PiperOrigin-RevId: 533280115
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
DETAILS:
When xla_bridge_test is run on a TPU v2-8, it raises the following error about an unknown backend 'tpu'. This change sets jax_platforms to "" to eliminate the error.
```
FAILED tests/xla_bridge_test.py::GetBackendTest::test_backend_init_error - RuntimeError: Unable to initialize backend 'tpu': Unknown backend 'tpu' (set JAX_PLATFORMS='' to automatically choose an available backend)
```
TESTED:
Unit tests pass on both CPU and TPU.
PiperOrigin-RevId: 481758573
This is a vestigial wrapper around xla_client.XlaBuilder whose purpose is long gone.
Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
The motivation for this change is to make it possible to avoid
initializing unused backends, which may have undesirable side effects
(e.g. GPU memory allocation, only one process can use a Cloud TPU at a
time). It also provides a new and more flexible mechanism for
configuring the default backend, in order to minimize the number of
configs we have.