8 Commits

Author SHA1 Message Date
Yash Katariya
ae9d1498e5 Bump minimum jaxlib version to 0.4.11. xla_extension_version is 158 and mlir_api_version is 49. It will subsume https://github.com/google/jax/pull/16161#issuecomment-1564977332
PiperOrigin-RevId: 537047525
2023-06-01 09:42:55 -07:00
Parker Schuh
11b34a90fd Skip stream-executor for aot_test.py.
PiperOrigin-RevId: 531248352
2023-05-11 10:51:32 -07:00
Parker Schuh
261ff9e9ed Stop passing CompileOptions when deserializing.
PiperOrigin-RevId: 531034200
2023-05-10 16:22:54 -07:00
pizzud
40d730be49 aot_test: Stop forcing XLA to assume a certain number of devices.
Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.

PiperOrigin-RevId: 529437419
2023-05-04 09:53:26 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Parker Schuh
4750ce7c87 Add an experimental API that allows compiling AOT for TPUs.
PiperOrigin-RevId: 525536075
2023-04-19 13:33:59 -07:00
Parker Schuh
97c70f2171 initiate an experimental topologies module
Start off with two functions: one for retrieving the attached topology, and the other for producing a mesh from the topology (modeling how `mesh_utils` might be adapted).

Use as:
```
    topo = jax.topologies.get_attached_topology() // Discovers local devices.
    mesh = jax.topologies.make_mesh(topo, mesh_shape, axis_names) # see mesh_utils.create_device_mesh.
```

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 524909149
2023-04-17 11:54:33 -07:00
Parker Schuh
c2b15a1eb8 Break out aot_test from array_test (for serialization and other aot APIs).
PiperOrigin-RevId: 521568985
2023-04-03 14:47:53 -07:00