mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #12107 from jakevdp:fix-pmap-test
PiperOrigin-RevId: 470062747
This commit is contained in:
commit
7219cdc9aa
@ -137,6 +137,8 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
def pmap(self):
|
||||
return src_api._python_pmap
|
||||
|
||||
# TODO(yashkatariya): Re-enable when unsafe_buffer_pointer is implemented
|
||||
@unittest.skipIf(config.jax_array, "Array does not yet implement unsafe_buffer_pointer")
|
||||
def testDeviceBufferToArray(self):
|
||||
sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
|
||||
|
||||
@ -144,12 +146,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
# sda.device_buffers, which isn't supported, and instead ensure fast slices
|
||||
# of the arrays returned by pmap are set up correctly.
|
||||
# buf = sda.device_buffers[-1]
|
||||
# TODO(yashkatariya): Don't read the private `_arrays` method. When devices()
|
||||
# is exposed on Array, use that here.
|
||||
if config.jax_array:
|
||||
buf = sda[-1]._arrays[0]
|
||||
else:
|
||||
buf = sda[-1]
|
||||
buf = sda[-1]
|
||||
|
||||
view = jnp.array(buf, copy=False)
|
||||
self.assertArraysEqual(sda[-1], view)
|
||||
@ -158,13 +155,8 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
copy = jnp.array(buf, copy=True)
|
||||
self.assertArraysEqual(sda[-1], copy)
|
||||
if config.jax_array:
|
||||
self.assertEqual(buf.device(), copy._arrays[0].device())
|
||||
self.assertNotEqual(buf.unsafe_buffer_pointer(),
|
||||
copy._arrays[0].unsafe_buffer_pointer())
|
||||
else:
|
||||
self.assertEqual(buf.device(), copy.device())
|
||||
self.assertNotEqual(buf.unsafe_buffer_pointer(), copy.unsafe_buffer_pointer())
|
||||
self.assertEqual(buf.device(), copy.device())
|
||||
self.assertNotEqual(buf.unsafe_buffer_pointer(), copy.unsafe_buffer_pointer())
|
||||
|
||||
def _getMeshShape(self, device_mesh_shape):
|
||||
device_count = jax.device_count()
|
||||
|
Loading…
x
Reference in New Issue
Block a user