Remove the usage of _arrays from tests

PiperOrigin-RevId: 505871063
This commit is contained in:
Yash Katariya 2023-01-30 20:01:58 -08:00 committed by jax authors
parent 82bd889120
commit 8a4de1f86a
8 changed files with 47 additions and 54 deletions

View File

@ -982,8 +982,7 @@ class BufferDonationTestCase(JaxTestCase):
def _assertDeleted(self, x, deleted):
if hasattr(x, "_arrays"):
for buffer in x._arrays:
self.assertEqual(buffer.is_deleted(), deleted)
self.assertEqual(x.is_deleted(), deleted)
elif hasattr(x, "device_buffer"):
self.assertEqual(x.device_buffer.is_deleted(), deleted)
else:

View File

@ -89,10 +89,7 @@ def _get_metadata(arr):
dtype = 'bfloat16'
else:
dtype = np.dtype(arr.dtype).str
if isinstance(arr, array.ArrayImpl):
local_shape = arr._arrays[0].shape
else:
local_shape = arr.addressable_data(0).shape
local_shape = arr.addressable_data(0).shape
return {
'compressor': {
'id': 'gzip'

View File

@ -721,7 +721,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
if config.jax_array:
out = jitted_f(2)
self.assertIsInstance(out.sharding, sharding.SingleDeviceSharding)
self.assertIsInstance(out._arrays[0], device_array.Buffer)
self.assertIsInstance(out, array.ArrayImpl)
else:
self.assertIsInstance(jitted_f(2), device_array.Buffer)

View File

@ -94,8 +94,8 @@ class JaxArrayTest(jtu.JaxTestCase):
arr, global_data = create_array(
input_shape, sharding.NamedSharding(global_mesh, mesh_axes))
for s in arr.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], global_data[s.index])
self.assertTrue(dispatch.is_single_device_sharding(s.data.sharding))
self.assertArraysEqual(s.data, global_data[s.index])
self.assertArraysEqual(arr._value, global_data)
self.assertArraysEqual(arr._npy_value, global_data)
@ -418,7 +418,7 @@ class JaxArrayTest(jtu.JaxTestCase):
x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
y = jax.pmap(jnp.sin)(x)
self.assertArraysEqual([a.device() for a in y],
[a.device() for a in y._arrays])
y.sharding._device_assignment)
sin_x = iter(np.sin(x))
for i, j in zip(iter(y), sin_x):

View File

@ -2704,8 +2704,8 @@ class LazyConstantTest(jtu.JaxTestCase):
y = lax.convert_element_type(x, dtype_out)
self.assertEqual(y.dtype, dtype_out)
if config.jax_array:
x_buf = x._arrays[0]
y_buf = y._arrays[0]
x_buf = x
y_buf = y
else:
x_buf = x.device_buffer
y_buf = y.device_buffer

View File

@ -379,8 +379,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, array.ArrayImpl)
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
check_dtypes=False)
self.assertAllClose(actual, expected, check_dtypes=False)
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
# Annotation from with_sharding_constraint
@ -408,8 +407,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertIsInstance(actual, array.ArrayImpl)
self.assertLen(actual.addressable_shards, 2)
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
check_dtypes=False)
self.assertAllClose(actual, expected, check_dtypes=False)
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
# Annotation from with_sharding_constraint
@ -1745,8 +1743,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
for s in out.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
self.assertLen(s.data.devices(), 1)
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
self.assertArraysEqual(out._value, expected_matrix_mul)
@parameterized.named_parameters(
@ -1772,8 +1770,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
for s in out.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
self.assertLen(s.data.devices(), 1)
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
self.assertArraysEqual(out._value, expected_matrix_mul)
def test_numpy_array_input_assume_fully_replicated(self):
@ -1792,7 +1790,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out, array.ArrayImpl)
for s in out.addressable_shards:
self.assertEqual(s.data.shape, (2, 1))
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertArraysEqual(s.data, input_data[s.index])
self.assertArraysEqual(out._value, input_data)
def test_numpy_array_input(self):
@ -1811,7 +1809,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertIsInstance(out, array.ArrayImpl)
for s in out.addressable_shards:
self.assertEqual(s.data.shape, (2, 1))
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertArraysEqual(s.data, input_data[s.index])
self.assertArraysEqual(out._value, input_data)
@jax_array(True)
@ -1823,8 +1821,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(out.shape, (8, 2))
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
for s in out.addressable_shards:
self.assertLen(s.data._arrays, 1)
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertLen(s.data.devices(), 1)
self.assertArraysEqual(s.data, input_data[s.index])
self.assertArraysEqual(out._value, input_data)
global_input_shape = (8, 2)
@ -1872,25 +1870,25 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
for s in out1.addressable_shards:
self.assertArraysEqual(
s.data._arrays[0], (input_data @ input_data.T)[s.index])
s.data, (input_data @ input_data.T)[s.index])
self.assertIsInstance(out2, array.ArrayImpl)
self.assertEqual(out2.shape, (8, 2))
self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
for s in out2.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertArraysEqual(s.data, input_data[s.index])
self.assertIsInstance(out3, array.ArrayImpl)
self.assertEqual(out3.shape, (8, 2))
self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
for s in out3.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], (input_data * 2)[s.index])
self.assertArraysEqual(s.data, (input_data * 2)[s.index])
self.assertIsInstance(out4, array.ArrayImpl)
self.assertEqual(out4.shape, (8, 2))
self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
for s in out4.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data)
self.assertArraysEqual(s.data, input_data)
def test_in_axis_resources_mismatch_error(self):
global_input_shape = (8, 2)
@ -2205,10 +2203,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
f = pjit(lambda x: x, out_axis_resources=s)
out = f(arr)
self.assertArraysEqual([o.device() for o in out._arrays], list(mesh.devices.flat))
self.assertTrue(out.sharding.is_equivalent_to(arr.sharding, arr.ndim))
self.assertArraysEqual(out, inp_data)
out2 = f(out)
self.assertArraysEqual([o.device() for o in out2._arrays], list(mesh.devices.flat))
self.assertTrue(out2.sharding.is_equivalent_to(out.sharding, out.ndim))
self.assertArraysEqual(out2, inp_data)
@jax_array(True)

View File

@ -1248,12 +1248,12 @@ class PythonPmapTest(jtu.JaxTestCase):
# Test that 'ans' was properly replicated across devices.
if config.jax_array:
bufs = ans._arrays
ans_devices = ans.sharding._device_assignment
else:
bufs = ans.device_buffers
ans_devices = [b.device() for b in ans.device_buffers]
# TODO(mattjj,sharadmv): fix physical layout with eager pmap, remove 'if'
if not config.jax_disable_jit:
self.assertEqual([b.device() for b in bufs], devices)
self.assertEqual(ans_devices, devices)
def testPmapConstantError(self):
device_count = jax.device_count()
@ -1288,25 +1288,25 @@ class PythonPmapTest(jtu.JaxTestCase):
# Test that 'ans' was properly replicated across devices.
expected_sharded = self.pmap(self.pmap(lambda x: x))(expected)
if config.jax_array:
ans_db = ans._arrays
expected_db = expected_sharded._arrays
self.assertTrue(ans.sharding._device_assignment,
expected_sharded.sharding._device_assignment)
else:
ans_db = ans.device_buffers
expected_db = expected_sharded.device_buffers
self.assertEqual([b.device() for b in ans_db],
[b.device() for b in expected_db])
self.assertEqual([b.device() for b in ans_db],
[b.device() for b in expected_db])
f = self.pmap(self.pmap(lambda x: (x, 3)))
x_sharded, ans = f(x)
if config.jax_array:
ans_db = ans._arrays
x_sharded_db = x_sharded._arrays
self.assertEqual(ans.sharding._device_assignment,
x_sharded.sharding._device_assignment)
else:
ans_db = ans.device_buffers
x_sharded_db = x_sharded.device_buffers
self.assertAllClose(ans, expected, check_dtypes=False)
self.assertEqual([b.device() for b in ans_db],
[b.device() for b in x_sharded_db])
self.assertAllClose(ans, expected, check_dtypes=False)
self.assertEqual([b.device() for b in ans_db],
[b.device() for b in x_sharded_db])
@unittest.skip("Nested pmaps with devices not yet implemented")
def testNestedPmapConstantDevices(self):
@ -1327,13 +1327,12 @@ class PythonPmapTest(jtu.JaxTestCase):
# Test that 'ans' was properly replicated across devices.
expected_sharded = self.pmap(self.pmap(lambda x: x), devices=devices)(expected)
if config.jax_array:
ans_bufs = ans._arrays
expected_sharded_bufs = expected_sharded._arrays
self.assertTrue(ans.sharding == expected_sharded.sharding)
else:
ans_bufs = ans.device_buffers
expected_sharded_bufs = expected_sharded.device_buffers
self.assertEqual([b.device() for b in ans_bufs],
[b.device() for b in expected_sharded_bufs])
self.assertEqual([b.device() for b in ans_bufs],
[b.device() for b in expected_sharded_bufs])
def testNestedPmapConstantError(self):
if config.jax_disable_jit:
@ -2954,7 +2953,7 @@ class ArrayPmapTest(jtu.JaxTestCase):
self.assertIsInstance(out, array.ArrayImpl)
for s in out.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], expected[s.index])
self.assertArraysEqual(s.data, expected[s.index])
self.assertArraysEqual(out, expected)
def test_pmap_double_input_array_output_array(self):
@ -2973,8 +2972,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
self.assertIsInstance(out1, array.ArrayImpl)
self.assertIsInstance(out2, array.ArrayImpl)
for s1, s2 in safe_zip(out1.addressable_shards, out2.addressable_shards):
self.assertArraysEqual(s1.data._arrays[0], input_data[s1.index])
self.assertArraysEqual(s2.data._arrays[0], input_data[s2.index])
self.assertArraysEqual(s1.data, input_data[s1.index])
self.assertArraysEqual(s2.data, input_data[s2.index])
self.assertArraysEqual(out1, input_data)
self.assertArraysEqual(out2, input_data)
@ -2999,8 +2998,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
self.assertEqual(out1.shape, (2,))
self.assertEqual(out2.shape, (dc, dc, 2))
for i, (s1, s2) in enumerate(safe_zip(out1.addressable_shards, out2.addressable_shards)):
self.assertArraysEqual(s1.data._arrays[0], input_data[i])
self.assertArraysEqual(s2.data._arrays[0], input_data)
self.assertArraysEqual(s1.data, input_data[i])
self.assertArraysEqual(s2.data, input_data)
def test_pmap_array_sharding_mismatch(self):
input_shape = (jax.device_count(), 2)

View File

@ -1277,7 +1277,7 @@ class XMapArrayTest(XMapTestCase):
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
self.assertDictEqual(out.sharding.mesh.shape, {'x': 4, 'y': 2})
for s in out.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertArraysEqual(s.data, input_data[s.index])
def test_xmap_array_double_input(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
@ -1301,14 +1301,14 @@ class XMapArrayTest(XMapTestCase):
self.assertEqual(out1.addressable_shards[0].data.shape, (2,))
self.assertDictEqual(out1.sharding.mesh.shape, {'x': 4, 'y': 2})
for s in out1.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
self.assertIsInstance(out2, array.ArrayImpl)
self.assertEqual(out2.shape, (8,))
self.assertEqual(out2.addressable_shards[0].data.shape, (4,))
self.assertDictEqual(out2.sharding.mesh.shape, {'x': 4, 'y': 2})
for s in out2.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
def test_xmap_array_sharding_mismatch(self):
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))