19 Commits

Author SHA1 Message Date
George Necula
dd0447a7c6 [aot] Add support for as_text(debug_info=True).
This exposes an easier way to get StableHLO and HLO
with more debugging information (source locations
for StableHLO and metadata for HLO).
2025-01-10 07:59:56 +02:00
Peter Hawkins
7776982a8d Bump xla_extension_version after jaxlib release.
The new minimum version is 301.
2024-12-18 08:07:19 -05:00
Henning Becker
cb6881d9e8 Reverts bdadc53ebcd40a5091d66d2586deba82fe5e01ca
PiperOrigin-RevId: 704758075
2024-12-10 10:25:27 -08:00
Peter Hawkins
bdadc53ebc Disable JaxAotTest.test_topology_pjit_serialize on GPU, which fails in CI.
PiperOrigin-RevId: 702759889
2024-12-04 09:51:23 -08:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jieying Luo
e0bba8ff39 [PJRT C API] Supports plugins to register a method to create topology.
- Add a topology factory registration to xla_bridge.py.
- Move discovery and registration of plugins to the first time backends() or make_pjrt_topology() is called.

PiperOrigin-RevId: 609544983
2024-02-22 16:55:34 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Peter Hawkins
69da839358 Remove test code that checks for the se_tpu runtime.
This runtime no longer exists.

PiperOrigin-RevId: 568242078
2023-09-25 09:30:07 -07:00
Jake Hall
f59a4163fa Test changes for out-of-tree backend. 2023-09-14 12:18:37 +01:00
Parker Schuh
614bbcc626 Add internal jaxlib function for fetching the topology from
a set of devices. We may want to make this topology serializable
or usable as a cache key.

PiperOrigin-RevId: 552931150
2023-08-01 14:54:08 -07:00
Peter Hawkins
76cda0ae07 Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
2023-07-27 12:15:58 -07:00
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Parker Schuh
11b34a90fd Skip stream-executor for aot_test.py.
PiperOrigin-RevId: 531248352
2023-05-11 10:51:32 -07:00
Parker Schuh
261ff9e9ed Stop passing CompileOptions when deserializing.
PiperOrigin-RevId: 531034200
2023-05-10 16:22:54 -07:00
pizzud
40d730be49 aot_test: Stop forcing XLA to assume a certain number of devices.
Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.

PiperOrigin-RevId: 529437419
2023-05-04 09:53:26 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Parker Schuh
4750ce7c87 Add an experimental API that allows compiling AOT for TPUs.
PiperOrigin-RevId: 525536075
2023-04-19 13:33:59 -07:00
Parker Schuh
97c70f2171 initiate an experimental topologies module
Start off with two functions: one for retrieving the attached topology, and the other for producing a mesh from the topology (modeling how `mesh_utils` might be adapted).

Use as:
```
    topo = jax.topologies.get_attached_topology() // Discovers local devices.
    mesh = jax.topologies.make_mesh(topo, mesh_shape, axis_names) # see mesh_utils.create_device_mesh.
```

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 524909149
2023-04-17 11:54:33 -07:00
Parker Schuh
c2b15a1eb8 Break out aot_test from array_test (for serialization and other aot APIs).
PiperOrigin-RevId: 521568985
2023-04-03 14:47:53 -07:00