mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add internal jaxlib function for fetching the topology from
a set of devices. We may want to make this topology serializable or usable as a cache key. PiperOrigin-RevId: 552931150
This commit is contained in:
parent
0116d196a7
commit
614bbcc626
@ -16,19 +16,22 @@
|
||||
import contextlib
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import config
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.experimental import topologies
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.experimental.serialize_executable import (
|
||||
serialize, deserialize_and_load)
|
||||
from jax.experimental import topologies
|
||||
deserialize_and_load,
|
||||
serialize,
|
||||
)
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import numpy as np
|
||||
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
prev_xla_flags = None
|
||||
@ -100,6 +103,20 @@ class JaxAotTest(jtu.JaxTestCase):
|
||||
lower_and_load(ref_mesh).as_text(), lower_and_load(aot_mesh).as_text()
|
||||
)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 175, 'Test requires jaxlib 0.4.15')
|
||||
def test_get_topology_from_devices(self):
|
||||
try:
|
||||
aot_topo = topologies.get_topology_desc(
|
||||
platform=jax.devices()[0].platform
|
||||
)
|
||||
except NotImplementedError:
|
||||
raise unittest.SkipTest('PJRT Topology not supported')
|
||||
|
||||
topo = xc.get_topology_for_devices(aot_topo.devices)
|
||||
self.assertEqual(
|
||||
topo.platform_version, aot_topo.devices[0].client.platform_version
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user