Updated the repr of GPU devices to be more consistent with TPUs/CPUs.

For example, `cuda(id=0)` will now be `CudaDevice(id=0)`

PiperOrigin-RevId: 651393690
This commit is contained in:
Gleb Pobudzey 2024-07-11 06:53:32 -07:00 committed by jax authors
parent e8d9a54b1b
commit 46103f6ff3
3 changed files with 65 additions and 0 deletions

View File

@ -21,6 +21,8 @@ Remember to align the itemized text with the first line of an item within a list
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
be passed *before* `index_map`. The old argument order is deprecated and
will be removed in a future release.
* Updated the repr of gpu devices to be more consistent
with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
* Deprecations
* Removed a number of previously-deprecated internal APIs related to
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,

View File

@ -38,6 +38,11 @@ jax_test(
tags = ["test_cpu_thunks"],
)
jax_test(
name = "device_test",
srcs = ["device_test.py"],
)
jax_test(
name = "dynamic_api_test",
srcs = ["dynamic_api_test.py"],

58
tests/device_test.py Normal file
View File

@ -0,0 +1,58 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
jax.config.parse_flags_with_absl()
class DeviceTest(jtu.JaxTestCase):
def test_repr(self):
device = jax.devices()[0]
# TODO(pobudzey): Add a test for rocm devices when available.
if jtu.is_device_cuda():
if xla_extension_version < 276:
self.skipTest('requires jaxlib 0.4.31')
self.assertEqual(device.platform, 'gpu')
self.assertEqual(repr(device), 'CudaDevice(id=0)')
elif jtu.test_device_matches(['tpu']):
self.assertEqual(device.platform, 'tpu')
self.assertEqual(
repr(device),
'TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)',
)
elif jtu.test_device_matches(['cpu']):
self.assertEqual(device.platform, 'cpu')
self.assertEqual(repr(device), 'CpuDevice(id=0)')
def test_str(self):
device = jax.devices()[0]
# TODO(pobudzey): Add a test for rocm devices when available.
if jtu.is_device_cuda():
self.assertEqual(str(device), 'cuda:0')
elif jtu.test_device_matches(['tpu']):
self.assertEqual(str(device), 'TPU_0(process=0,(0,0,0,0))')
elif jtu.test_device_matches(['cpu']):
self.assertEqual(str(device), 'TFRT_CPU_0')
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())