initiate an experimental topologies module

Start off with two functions: one for retrieving the attached topology, and the other for producing a mesh from the topology (modeling how `mesh_utils` might be adapted).

Use as:
```
    topo = jax.topologies.get_attached_topology() // Discovers local devices.
    mesh = jax.topologies.make_mesh(topo, mesh_shape, axis_names) # see mesh_utils.create_device_mesh.
```

Co-authored-by: Roy Frostig <frostig@google.com>
PiperOrigin-RevId: 524909149
This commit is contained in:
Parker Schuh 2023-04-17 11:53:57 -07:00 committed by jax authors
parent b035a2b61b
commit 97c70f2171
2 changed files with 78 additions and 0 deletions

View File

@ -0,0 +1,45 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
FrozenSet, Union, cast)
import numpy as np
import jax
from jax.experimental import mesh_utils
from jax._src.lib import xla_client as xc
Device = xc.Device
class Topology(abc.ABC):
def __init__(self, devices: List[Device]):
self.devices: List[Device] = devices
def get_attached_topology(platform=None) -> Topology:
return Topology(jax.devices(backend=platform))
# -- future mesh_utils --
def make_mesh(topo: Topology, mesh_shape: Sequence[int],
axis_names: Tuple[str, ...],
*, contiguous_submeshes: bool = False
) -> jax.sharding.Mesh:
devices = mesh_utils.create_device_mesh(
mesh_shape, list(topo.devices), contiguous_submeshes=contiguous_submeshes)
return jax.sharding.Mesh(devices, axis_names)

View File

@ -20,12 +20,15 @@ from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import (
serialize, deserialize_and_load)
from jax.experimental import topologies
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec as P
from jax.config import config
@ -97,6 +100,36 @@ class JaxAotTest(jtu.JaxTestCase):
jax.pmap(lambda x: x * x).lower(
np.zeros((len(jax.devices()), 4), dtype=np.float32)))
def test_topology_pjit_serialize(self):
self.check_for_compile_options()
@jax.jit
def fn(x):
return x * x
def lower_and_load(mesh):
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
x_shape = jax.ShapeDtypeStruct(
shape=(16, 16),
dtype=jnp.dtype('float32'),
sharding=s)
lowered = fn.lower(x_shape)
serialized, in_tree, out_tree = serialize(lowered.compile())
compiled = deserialize_and_load(serialized, in_tree, out_tree)
return compiled
topo = topologies.get_attached_topology()
n = max(1, len(topo.devices) // 2)
mesh_shape = (len(topo.devices) // n, n)
mesh = topologies.make_mesh(topo, mesh_shape, ('x', 'y'))
ref_mesh = jax.sharding.Mesh(
mesh_utils.create_device_mesh(
mesh_shape, jax.devices()),
('x', 'y'))
self.assertEqual(lower_and_load(ref_mesh).as_text(),
lower_and_load(mesh).as_text())
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())