mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00

We're switching to the new terminology to avoid confusion in cases where multiple jax processes are running on a single host, and each process has a unique process_index/host_id. This keeps aliases for the old `host_id` APIs for now, but these will eventually be removed. This was originally committed in b77ef5138b631378e6a8ceb8bafc94fe91239bae, but reverted in 14acd070c2afb11c81fc91f43790577cd48cbf67 due to Google-internal test failures from renaming the local_devices argument name. This change is identical except it also adds staging for the argument name change.
62 lines
2.4 KiB
Python
62 lines
2.4 KiB
Python
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
from absl.testing import absltest
|
|
from jax.lib import xla_bridge as xb
|
|
from jax.lib import xla_client as xc
|
|
from jax import test_util as jtu
|
|
|
|
|
|
class XlaBridgeTest(absltest.TestCase):
  """Checks compile options, parameter replication and local device lookup.

  All checks go through the `xb` (xla_bridge) helpers and compare against
  literal expected strings, so they are sensitive to the exact text produced
  by the underlying XLA client.
  """

  def test_set_device_assignment_no_partition(self):
    """A flat device list with one partition gives one computation row."""
    compile_options = xb.get_compile_options(
        num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3])
    expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: "
                                  "0 1 2 3 \n")
    # repr() is the idiomatic spelling of the direct __repr__() call.
    self.assertEqual(repr(compile_options.device_assignment),
                     expected_device_assignment)

  def test_set_device_assignment_with_partition(self):
    """A nested device list with two partitions gives two computation rows."""
    compile_options = xb.get_compile_options(
        num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]])
    # The expected text groups devices per computation: "0 2" and "1 3".
    expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: "
                                  "0 2 \nComputation 1: 1 3 \n")
    self.assertEqual(repr(compile_options.device_assignment),
                     expected_device_assignment)

  def test_parameter_replication_default(self):
    """Without the extra arguments, the HLO text never mentions replication."""
    c = xb.make_computation_builder("test")
    _ = xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
    built_c = c.Build()
    # assertNotIn instead of a bare `assert`: bare asserts are removed under
    # `python -O`, which would turn this test into a no-op.
    self.assertNotIn("replication", built_c.as_hlo_text())

  def test_parameter_replication(self):
    """Passing `""` and `False` as trailing arguments marks the parameter."""
    c = xb.make_computation_builder("test")
    scalar_f32 = xc.Shape.array_shape(xc.PrimitiveType.F32, ())
    _ = xb.parameter(c, 0, scalar_f32, "", False)
    built_c = c.Build()
    # assertIn instead of a bare `assert` so the check survives `python -O`
    # and reports the HLO text when it fails.
    self.assertIn("parameter_replication={false}", built_c.as_hlo_text())

  def test_local_devices(self):
    """local_devices() is non-empty and rejects unknown process/backend."""
    self.assertNotEmpty(xb.local_devices())
    with self.assertRaisesRegex(ValueError, "Unknown process_index 100"):
      xb.local_devices(100)
    with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
      xb.local_devices(backend="foo")
|
|
|
|
|
|
if __name__ == "__main__":
  # When executed as a script, run the tests with the loader from
  # jax.test_util rather than absltest's default loader.
  jax_test_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_test_loader)
|