- Add a topology factory registration to xla_bridge.py.
- Move discovery and registration of plugins to the first time backends() or make_pjrt_topology() is called.
PiperOrigin-RevId: 609544983
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.
For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.
Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.
This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.
PiperOrigin-RevId: 551604974
Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.
PiperOrigin-RevId: 529437419
Start off with two functions: one for retrieving the attached topology, and the other for producing a mesh from the topology (modeling how `mesh_utils` might be adapted).
Use as:
```
topo = jax.topologies.get_attached_topology() // Discovers local devices.
mesh = jax.topologies.make_mesh(topo, mesh_shape, axis_names) # see mesh_utils.create_device_mesh.
```
Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 524909149