mirror of https://github.com/ROCm/jax.git synced 2025-04-21 06:16:06 +00:00

69 Commits

Author SHA1 Message Date
Yang Chen
08d81e45d4 Use backend._get_all_devices() to validate devices.
PiperOrigin-RevId: 719367913
2025-01-24 11:09:16 -08:00
Peter Hawkins
b06779b177 Switch to a new thread-safe utility for catching warnings.
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
2025-01-09 11:58:34 -05:00
Peter Hawkins
a9926f0f01 Remove classic HLO lowering rule support from JAX.
(JAX uses StableHLO always, now, with the exception of one use case in jax2tf.)

PiperOrigin-RevId: 683205145
2024-10-07 09:06:20 -07:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
Yazhou Zu
341e63b60f add xla_bridge test guard on cloud tpu env
PiperOrigin-RevId: 640269835
2024-06-04 13:45:52 -07:00
Yazhou Zu
91d68b5564 Create a JAX config API to allow custom PJRT client create option settings. This allows a device platform's PJRT client to be aware of the calling (customer) ML framework.
PiperOrigin-RevId: 638009713
2024-05-28 13:43:06 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Peter Hawkins
e7eb2075b8 Change xla_bridge_test to expect a bytes FDO profile instead of a string in future jaxlib versions.
Change in preparation for nanobind migration.

PiperOrigin-RevId: 613329489
2024-03-06 13:50:36 -08:00
Jieying Luo
087f99a31c Support mocking number of GPUs in CUDA plugin.
Also move reading jax config value to be right before the client is created. Previously they were read before calling register_plugin, which happens during import and before any call of jax.config.update.

The decorator in mock_gpu_test was used incorrectly: jtu.run_on_devices creates the client before jax.config.update is called, which is not desired. Removing the decorator does not cause CPU/TPU tests to fail, because the mesh checks num_shard against the number of devices in the client and skips the test if they do not match.

generate_pjrt_gpu_plugin_options is only used in places that do not require compatibility, so the xla_client version does not need to be updated.

PiperOrigin-RevId: 611610915
2024-02-29 15:15:06 -08:00
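
The timing issue described in the commit above (a config value read at registration time misses later updates) can be illustrated with a minimal sketch. Everything here is a hypothetical stand-in: `_config`, `mock_num_gpus`, and the two factory helpers are not JAX APIs, only an illustration of eager versus deferred reads.

```python
# Hypothetical config store; this is NOT jax.config.
_config = {"mock_num_gpus": 0}


def make_factory_eager():
    # Value is captured when the factory is registered (e.g. at import time).
    n = _config["mock_num_gpus"]
    return lambda: {"num_devices": n}


def make_factory_lazy():
    # Value is read only when the client is actually created.
    return lambda: {"num_devices": _config["mock_num_gpus"]}


eager = make_factory_eager()
lazy = make_factory_lazy()

# Analogous to a config update that happens after import.
_config["mock_num_gpus"] = 8

eager_client = eager()  # still sees the value from registration time
lazy_client = lazy()    # sees the updated value
```

The lazy variant mirrors the idea of reading the config right before the client is created, so an update made after import is still honored.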
Jake VanderPlas
d2b4800723 tests: improve warnings-related tests
2023-11-30 10:35:24 -08:00
jax authors
fc8058a17d Restrict retrieving XLA-AutoFDO profile version to TPU workloads.
XLA-AutoFDO is supported only for TPUs, so requesting the latest
profile version for non-TPU workloads is unnecessary and can delay
the completion of initialization.

Testing: test workload.
PiperOrigin-RevId: 584148686
2023-11-20 15:52:03 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to .
2023-10-12 17:32:15 +01:00
Jieying Luo
b81a3e1fd7 Remove calling configure_library_path during jax import and get libtpu path from libtpu_module.get_library_path().
PiperOrigin-RevId: 572306461
2023-10-10 10:59:37 -07:00
Jieying Luo
cb51e37008 [PJRT C API] Adding Profiler C APIs and related framework changes.
C API changes:
- Profiler C APIs are added in profiler_c_api.h.
- Add a PJRT C API extension for the profiler C APIs in pjrt_c_api_profiler_extension.h.

Framework changes:
- Add a plugin_tracer that calls profiler C APIs.
- Add a pybind method xla_client.profiler.register_plugin_profiler to register plugin_tracer with the plugin's PJRT_Api*.
- Update xla_bridge.register_plugin to call register_plugin_profiler to register profiler for that plugin.

PiperOrigin-RevId: 572027222
2023-10-09 13:36:24 -07:00
Peter Hawkins
83f12c5ab2 Fix CI failures from https://github.com/google/jax/pull/17751
2023-09-26 17:52:37 -04:00
Peter Hawkins
210fab1aae Remove the "No GPU/TPU found" warning.
Instead, add a lightweight test for NVIDIA GPUs and Google TPUs. Warn
only if we suspect either is present but JAX is not using them.
2023-09-26 19:04:34 +00:00
Jieying Luo
c7f60fa6eb [PJRT C API] Implement framework side change for registering a custom call.
- Add a py extension to call the custom call C API.
- Change the implementation of register_custom_call_target to store handlers for the custom call targets and delay the registration until the handler for an XLA platform is registered.
- Change register_plugin to load the PJRT plugin when register_plugin is called (instead of when a client is created), and let it return the loaded PJRT_Api*.
- Delay calling discover_pjrt_plugins() and register_pjrt_plugin_factories_from_env() until the first time backends() is called.

PiperOrigin-RevId: 568265745
2023-09-25 10:52:29 -07:00
jax authors
c38f67043c Hash serialized CompileOptions for new cache key generation.
The original cache key generation hashes individual fields of
CompileOptions, ExecutableBuildOptions, and DebugOptions. This
is not future proof: when a field is added to any of these
structures, the corresponding hash needs to be added to the
cache key generation. The new cache key generation algorithm
hashes the serialized representation of CompileOptions.

Some DebugOptions do not affect the compilation result;
exclude them from the computation. If additional fields are
identified, they can be added; such additions will reduce
unnecessary cache misses.

Testing: revised unit test.
PiperOrigin-RevId: 561803875
2023-08-31 17:21:57 -07:00
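
As a rough illustration of the approach in the commit above, the sketch below hashes one serialized representation instead of hashing fields individually. It is a toy under stated assumptions: `ToyCompileOptions`, `IGNORED_DEBUG_FIELDS`, and the field names are hypothetical, and JSON stands in for the real serialization.

```python
import hashlib
import json
from dataclasses import asdict, dataclass, field


@dataclass
class ToyCompileOptions:  # hypothetical stand-in for CompileOptions
    num_replicas: int = 1
    num_partitions: int = 1
    debug: dict = field(default_factory=dict)


# Assumed for illustration: debug fields that do not change the compiled result.
IGNORED_DEBUG_FIELDS = {"dump_to"}


def cache_key(options: ToyCompileOptions) -> str:
    """Hash the serialized options rather than hashing field by field."""
    data = asdict(options)
    data["debug"] = {k: v for k, v in data["debug"].items()
                     if k not in IGNORED_DEBUG_FIELDS}
    serialized = json.dumps(data, sort_keys=True).encode("utf-8")
    return hashlib.sha256(serialized).hexdigest()


a = ToyCompileOptions(debug={"opt_level": 2, "dump_to": "/tmp/a"})
b = ToyCompileOptions(debug={"opt_level": 2, "dump_to": "/tmp/b"})
c = ToyCompileOptions(debug={"opt_level": 3})
```

A field added to the toy dataclass is picked up by the key automatically, which is the future-proofing the commit message describes; excluded fields (here `dump_to`) avoid unnecessary cache misses.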
jax authors
d1547ca45b Ensure that CompileOptions serializes deterministically.
CompileOptions has two serialization mechanisms: Py pickle and
SerializeAsString. Neither mechanism serializes deterministically.
Deterministic serialization (also called idempotent serialization
or in-order serialization) ensures that a given structure
serializes to the same string repeatedly. Both these mechanisms
serialize by first generating the proto and then serializing it.
There are three points to note:

1. Deterministic serialization will yield the same result
   even if proto map fields are in a different order. Thus
   map({"1": 1, "2": 2}) and map({"2": 2, "1": 1}) will
   serialize the same.

2. Deterministic serialization does not yield the same
   result for repeated fields that are out of order. Thus,
   for message Foo { repeated string s = 1; },
   Foo{s: "1", s: "2"} will not result in the same
   serialization as Foo{s: "2", s: "1"}.

3. Deterministic serialization applies only in the context
   of a given binary. It does not apply across releases.

Testing: the original serialization code with the new unit
test fails as expected while the revised code does not.
PiperOrigin-RevId: 559492626
2023-08-23 11:34:21 -07:00
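
The first two points in the commit above (map order does not matter, repeated-field order does) can be mimicked with a small analogy. This sketch uses JSON with sorted keys as a stand-in; it is not protobuf and does not reproduce the actual CompileOptions serialization.

```python
import json


def serialize(message: dict) -> str:
    """Serialize with sorted keys, a JSON stand-in for deterministic mode."""
    return json.dumps(message, sort_keys=True)


# Map-like fields: insertion order does not matter once keys are sorted.
m1 = {"1": 1, "2": 2}
m2 = {"2": 2, "1": 1}
maps_match = serialize(m1) == serialize(m2)

# Repeated-like fields: element order is part of the value.
r1 = {"s": ["1", "2"]}
r2 = {"s": ["2", "1"]}
repeated_match = serialize(r1) == serialize(r2)
```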
Jieying Luo
c7e8b81a74 [PJRT C API] Let framework explicitly check whether a plugin is initialized and initialize the plugin.
Before this change, PJRT_Plugin_Initialize was called in LoadPjrtPlugin, which is only used in the dynamic linking case. This change adds a bool and a method to check whether the plugin is initialized. The framework will explicitly check whether a plugin is initialized, and call InitializePjrtPlugin if it is not. This applies to both the static linking and dynamic linking cases.

PiperOrigin-RevId: 557268670
2023-08-15 15:24:14 -07:00
Peter Hawkins
a259df0d76 Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 557116160
2023-08-15 06:39:46 -07:00
jax authors
eb076c4c44 Explicitly set AutoFDO profile version in CompileOptions.
Set the AutoFDO profile version specified in --jax_xla_profile_version
if non-zero. Otherwise, expect that there is a function set in
get_latest_profile_version that will return a non-zero profile version
that should be used. If this function is not set or it returns 0,
set -1 instead to indicate that no attempt should be made to retrieve
an AutoFDO profile later on.

Testing: updated unit tests.
PiperOrigin-RevId: 555333728
2023-08-09 18:24:56 -07:00
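
The selection rule in the commit above reads as a three-way decision, sketched below. The function name `resolve_profile_version` and its signature are hypothetical; only the decision order (flag value, then latest-version callback, then -1) is taken from the commit message.

```python
from typing import Callable, Optional


def resolve_profile_version(
    flag_value: int,
    get_latest_profile_version: Optional[Callable[[], int]] = None,
) -> int:
    """Pick the profile version following the cases in the commit message."""
    if flag_value != 0:
        return flag_value  # an explicit non-zero flag value wins
    if get_latest_profile_version is not None:
        latest = get_latest_profile_version()
        if latest != 0:
            return latest
    return -1  # signal: do not try to retrieve a profile later on


from_flag = resolve_profile_version(5, lambda: 9)
from_callback = resolve_profile_version(0, lambda: 9)
callback_returns_zero = resolve_profile_version(0, lambda: 0)
no_callback = resolve_profile_version(0)
```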
Peter Hawkins
ca17b6c08f Move functions out of xla.py closer to their users.
Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility.

Remove an unused top_k translation rule as well.

PiperOrigin-RevId: 554946059
2023-08-08 14:40:42 -07:00
Yash Katariya
4ddf6a9a54 Bump minimum_jaxlib_version to 0.4.14. xla_extension_version is 174 and mlir_api_version is 54
PiperOrigin-RevId: 552816893
2023-08-01 08:53:28 -07:00
Tao Wang
6eb3096461 Enable to set fdo_profile through XLA python client.
PiperOrigin-RevId: 547303330
2023-07-11 14:47:39 -07:00
Jieying Luo
21588a30a9 [PJRT C API] Add related C type definitions for key value get/put callback, as well as conversion between C and cpp types.
This is similar to how send/receive callback are implemented.

Update make_c_api_client to take key value get/put callbacks generated from the distributed client, and options for node_id and num_nodes.

PiperOrigin-RevId: 543441403
2023-06-26 08:08:52 -07:00
Peter Hawkins
5d665a5898 Add a warning if a non-allowlisted plugin is used.
This is mostly to set user expectations: if we don't know a plugin passes the JAX test suite, we issue a warning.
2023-06-20 13:12:08 -04:00
Peter Hawkins
cc0bbdee8a Fix xla_bridge_test test failures under Windows.
2023-06-20 09:25:34 -04:00
Peter Hawkins
b12be3c35c Raise a RuntimeError if plugin initialization fails, rather than logging at info priority.
We don't want plugin initialization to fail silently.

To work around this, either remove the plugin, or disable it via the JAX_PLATFORMS variable.
2023-06-16 14:21:12 -04:00
Peter Hawkins
a05ffab786 Fix xla_bridge_test.py test failures.
We are splitting the plugins in the environment variable using os.pathsep; we should make sure to use that as the separator in the test.
2023-06-16 10:17:07 -04:00
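
The separator mismatch fixed in the commit above can be illustrated with a short sketch. The entry format here (comma-separated `name<os.pathsep>path` pairs) and the helper `parse_plugins` are assumptions made for illustration, not a copy of the code in xla_bridge.py.

```python
import os


def parse_plugins(value: str) -> dict:
    """Parse 'name<os.pathsep>path' entries separated by commas (assumed format)."""
    plugins = {}
    for entry in value.split(","):
        name, library_path = entry.split(os.pathsep, 1)
        plugins[name.strip()] = library_path
    return plugins


# Build the test input with os.pathsep so it matches the parser on every OS
# (':' on POSIX, ';' on Windows) instead of hard-coding one separator.
value = ",".join([
    os.pathsep.join(["name1", "path1"]),
    os.pathsep.join(["name2", "path2"]),
])
parsed = parse_plugins(value)
```

A test that hard-codes `:` would pass on Linux and macOS but fail on Windows, where `os.pathsep` is `;`.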
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Jieying Luo
b35c20ce5d Use xla_extension_version and remove some dead version check in xla_bridge_test.py.
Min jaxlib requires xla_extension_version >= 144.

PiperOrigin-RevId: 536810415
2023-05-31 13:50:07 -07:00
Jieying Luo
9da52e8905 [PJRT PLUGIN] Provide a register_plugin method that plugin can use to register their backend factory.
The plugin is expected to call jax._src.xla_bridge.register_plugin with its plugin_name, priority (defaults to 400), path to the .so file, and optional create options in its initialize() method.

The logic to register a plugin from ENV is kept to facilitate development with ENV.

PiperOrigin-RevId: 533280115
2023-05-18 16:13:02 -07:00
jax authors
0037ab6240 [PJRT C API] Check whether the PJRT_Api* for the device type already exists before calling dlopen and dlsym.
PiperOrigin-RevId: 531295150
2023-05-11 13:43:17 -07:00
Jieying Luo
b403c2a083 [PJRT C API] Add parsing PJRT client create options from json file.
PiperOrigin-RevId: 518418760
2023-03-21 16:57:34 -07:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Jieying Luo
668b82d529 [PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
Loading TPU PJRT plugin is moved to make_tpu_client.

This change is based on https://github.com/google/jax/pull/14011.

PiperOrigin-RevId: 508477737
2023-02-09 14:33:46 -08:00
jax authors
8ff293ab75 Fix xla_bridge_test on TPU
DETAILS:
Running xla_bridge_test on a TPU v2-8 raises the following error about an unknown backend 'tpu'. This change sets jax_platforms to "" to eliminate the error.
```
FAILED tests/xla_bridge_test.py::GetBackendTest::test_backend_init_error - RuntimeError: Unable to initialize backend 'tpu': Unknown backend 'tpu' (set JAX_PLATFORMS='' to automatically choose an available backend)
```

TESTED:
pass unit test on both CPU and TPU
PiperOrigin-RevId: 481758573
2022-10-17 15:44:06 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
78cb9f8492 Avoid more direct references to jax._src without imports.
Change in preparation for not exporting jax._src by default.

PiperOrigin-RevId: 469725340
2022-08-24 07:51:28 -07:00
Tom Hennigan
b308874880 Only register backend factories once when jax_platforms config value is set.
PiperOrigin-RevId: 450005138
2022-05-20 09:46:37 -07:00
Peter Hawkins
68e9e1c26d Consolidate more XLA-lowering logic between jit, pmap, and xmap.
Move remaining functions relating to building XLA HLO IR out of xla_bridge.py and into jax.interpreters.xla.

PiperOrigin-RevId: 413244450
2021-11-30 14:24:33 -08:00
Peter Hawkins
714e19a794 Remove xla_bridge.make_computation_builder().
This is a vestigial wrapper around xla_client.XlaBuilder whose purpose is long gone.

Also rename uses of XlaComputationBuilder to XlaBuilder. XlaComputationBuilder was an older name that is gone in most places.
2021-10-18 13:20:34 -04:00
Skye Wanderman-Milne
adcab940b7 Add --jax_platforms flag to replace --jax_platform_name.
The motivation for this change is to make it possible to avoid
initializing unused backends, which may have undesirable side effects
(e.g. GPU memory allocation, only one process can use a Cloud TPU at a
time). It also provides a new and more flexible mechanism for
configuring the default backend, in order to minimize the number of
configs we have.
2021-09-30 11:07:52 -07:00
Skye Wanderman-Milne
5400bf4e2b Refactor xla_bridge.get_backend().
This is mostly a non-functional change, except to fix a few odd edge cases (see unit tests changes).
2021-09-28 11:47:50 -07:00
Skye Wanderman-Milne
836926c47e Add xla_bridge_test.GetBackendTest
This is in preparation for refactoring the get_backend logic, see the TODOs in the test.
2021-09-28 07:17:08 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Yash Katariya
9120c5e370 Fix the flaky test
PiperOrigin-RevId: 389774706
2021-08-09 18:49:46 -07:00