4 Commits

Author SHA1 Message Date
Skye Wanderman-Milne
bcee442390 Improve TPU v2 and v3 mesh_utils.create_device_mesh logic.
* Fixes a bug when a non-3D mesh was requested
* Adds new logic when requesting a single-host mesh
* Extends logic to v2 as well as v3
2022-03-08 22:47:10 +00:00
Jake VanderPlas
da3aaa1960 Add deprecation warning to JaxTestCase and JaxTestLoader 2022-02-17 14:58:58 -08:00
Skye Wanderman-Milne
17b0866bbe Add contiguous_submeshes option to mesh_utils.create_device_mesh().
Unless you're using GlobalDeviceArrays, the device mesh passed to pjit
must be composed of contiguous submeshes for each process (i.e. each
process's local devices must all be next to each other in the full
mesh and form a rectangular submesh). This change teaches
`create_device_mesh` how to output meshes that satisfy this
constraint in some common cases.

This isn't the default behavior because the resulting meshes are a
little awkward and magical, and eventually we'd like using
GlobalDeviceArrays to be the common use case.
2021-12-10 00:01:12 +00:00
Qiao Zhang
64569abb46 Upstream mesh utils to JAX core.
Co-authored-by: James Bradbury <jekbradbury@google.com>
Co-authored-by: Anselm Levskaya <levskaya@google.com>
PiperOrigin-RevId: 415136597
2021-12-08 17:21:58 -08:00