mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
initiate an experimental topologies module
Start off with two functions: one for retrieving the attached topology, and the other for producing a mesh from the topology (modeling how `mesh_utils` might be adapted). Use as: ``` topo = jax.topologies.get_attached_topology() // Discovers local devices. mesh = jax.topologies.make_mesh(topo, mesh_shape, axis_names) # see mesh_utils.create_device_mesh. ``` Co-authored-by: Roy Frostig <frostig@google.com> PiperOrigin-RevId: 524909149
This commit is contained in:
parent
b035a2b61b
commit
97c70f2171
45
jax/experimental/topologies.py
Normal file
45
jax/experimental/topologies.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from typing import (Sequence, List, Tuple, Optional, Mapping, Dict, Set,
|
||||
FrozenSet, Union, cast)
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.experimental import mesh_utils
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
Device = xc.Device
|
||||
|
||||
|
||||
class Topology(abc.ABC):
|
||||
def __init__(self, devices: List[Device]):
|
||||
self.devices: List[Device] = devices
|
||||
|
||||
|
||||
def get_attached_topology(platform=None) -> Topology:
|
||||
return Topology(jax.devices(backend=platform))
|
||||
|
||||
|
||||
# -- future mesh_utils --
|
||||
|
||||
def make_mesh(topo: Topology, mesh_shape: Sequence[int],
|
||||
axis_names: Tuple[str, ...],
|
||||
*, contiguous_submeshes: bool = False
|
||||
) -> jax.sharding.Mesh:
|
||||
devices = mesh_utils.create_device_mesh(
|
||||
mesh_shape, list(topo.devices), contiguous_submeshes=contiguous_submeshes)
|
||||
return jax.sharding.Mesh(devices, axis_names)
|
@ -20,12 +20,15 @@ from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.experimental.serialize_executable import (
|
||||
serialize, deserialize_and_load)
|
||||
from jax.experimental import topologies
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jax.config import config
|
||||
@ -97,6 +100,36 @@ class JaxAotTest(jtu.JaxTestCase):
|
||||
jax.pmap(lambda x: x * x).lower(
|
||||
np.zeros((len(jax.devices()), 4), dtype=np.float32)))
|
||||
|
||||
def test_topology_pjit_serialize(self):
|
||||
self.check_for_compile_options()
|
||||
|
||||
@jax.jit
|
||||
def fn(x):
|
||||
return x * x
|
||||
|
||||
def lower_and_load(mesh):
|
||||
s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
|
||||
x_shape = jax.ShapeDtypeStruct(
|
||||
shape=(16, 16),
|
||||
dtype=jnp.dtype('float32'),
|
||||
sharding=s)
|
||||
lowered = fn.lower(x_shape)
|
||||
serialized, in_tree, out_tree = serialize(lowered.compile())
|
||||
compiled = deserialize_and_load(serialized, in_tree, out_tree)
|
||||
return compiled
|
||||
|
||||
topo = topologies.get_attached_topology()
|
||||
n = max(1, len(topo.devices) // 2)
|
||||
mesh_shape = (len(topo.devices) // n, n)
|
||||
|
||||
mesh = topologies.make_mesh(topo, mesh_shape, ('x', 'y'))
|
||||
ref_mesh = jax.sharding.Mesh(
|
||||
mesh_utils.create_device_mesh(
|
||||
mesh_shape, jax.devices()),
|
||||
('x', 'y'))
|
||||
self.assertEqual(lower_and_load(ref_mesh).as_text(),
|
||||
lower_and_load(mesh).as_text())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user