Merge pull request #12107 from jakevdp:fix-pmap-test

PiperOrigin-RevId: 470062747
This commit is contained in:
jax authors 2022-08-25 13:01:24 -07:00
commit 7219cdc9aa

View File

@ -137,6 +137,8 @@ class PythonPmapTest(jtu.JaxTestCase):
def pmap(self):
return src_api._python_pmap
# TODO(yashkatariya): Re-enable when unsafe_buffer_pointer is implemented
@unittest.skipIf(config.jax_array, "Array does not yet implement unsafe_buffer_pointer")
def testDeviceBufferToArray(self):
sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2)))
@ -144,12 +146,7 @@ class PythonPmapTest(jtu.JaxTestCase):
# sda.device_buffers, which isn't supported, and instead ensure fast slices
# of the arrays returned by pmap are set up correctly.
# buf = sda.device_buffers[-1]
# TODO(yashkatariya): Don't read the private `_arrays` method. When devices()
# is exposed on Array, use that here.
if config.jax_array:
buf = sda[-1]._arrays[0]
else:
buf = sda[-1]
buf = sda[-1]
view = jnp.array(buf, copy=False)
self.assertArraysEqual(sda[-1], view)
@ -158,13 +155,8 @@ class PythonPmapTest(jtu.JaxTestCase):
copy = jnp.array(buf, copy=True)
self.assertArraysEqual(sda[-1], copy)
if config.jax_array:
self.assertEqual(buf.device(), copy._arrays[0].device())
self.assertNotEqual(buf.unsafe_buffer_pointer(),
copy._arrays[0].unsafe_buffer_pointer())
else:
self.assertEqual(buf.device(), copy.device())
self.assertNotEqual(buf.unsafe_buffer_pointer(), copy.unsafe_buffer_pointer())
self.assertEqual(buf.device(), copy.device())
self.assertNotEqual(buf.unsafe_buffer_pointer(), copy.unsafe_buffer_pointer())
def _getMeshShape(self, device_mesh_shape):
device_count = jax.device_count()