Add internal jaxlib function for fetching the topology from

a set of devices. We may want to make this topology serializable
or usable as a cache key.

PiperOrigin-RevId: 552931150
This commit is contained in:
Parker Schuh 2023-08-01 14:53:24 -07:00 committed by jax authors
parent 0116d196a7
commit 614bbcc626

View File

@ -16,19 +16,22 @@
import contextlib
import unittest
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax.experimental import topologies
from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import (
serialize, deserialize_and_load)
from jax.experimental import topologies
deserialize_and_load,
serialize,
)
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np
from jax import config
config.parse_flags_with_absl()
prev_xla_flags = None
@ -100,6 +103,20 @@ class JaxAotTest(jtu.JaxTestCase):
lower_and_load(ref_mesh).as_text(), lower_and_load(aot_mesh).as_text()
)
@unittest.skipIf(xla_extension_version < 175, 'Test requires jaxlib 0.4.15')
def test_get_topology_from_devices(self):
try:
aot_topo = topologies.get_topology_desc(
platform=jax.devices()[0].platform
)
except NotImplementedError:
raise unittest.SkipTest('PJRT Topology not supported')
topo = xc.get_topology_for_devices(aot_topo.devices)
self.assertEqual(
topo.platform_version, aot_topo.devices[0].client.platform_version
)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())