mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Remove the usage of _arrays
from tests
PiperOrigin-RevId: 505871063
This commit is contained in:
parent
82bd889120
commit
8a4de1f86a
@ -982,8 +982,7 @@ class BufferDonationTestCase(JaxTestCase):
|
||||
|
||||
def _assertDeleted(self, x, deleted):
|
||||
if hasattr(x, "_arrays"):
|
||||
for buffer in x._arrays:
|
||||
self.assertEqual(buffer.is_deleted(), deleted)
|
||||
self.assertEqual(x.is_deleted(), deleted)
|
||||
elif hasattr(x, "device_buffer"):
|
||||
self.assertEqual(x.device_buffer.is_deleted(), deleted)
|
||||
else:
|
||||
|
@ -89,10 +89,7 @@ def _get_metadata(arr):
|
||||
dtype = 'bfloat16'
|
||||
else:
|
||||
dtype = np.dtype(arr.dtype).str
|
||||
if isinstance(arr, array.ArrayImpl):
|
||||
local_shape = arr._arrays[0].shape
|
||||
else:
|
||||
local_shape = arr.addressable_data(0).shape
|
||||
local_shape = arr.addressable_data(0).shape
|
||||
return {
|
||||
'compressor': {
|
||||
'id': 'gzip'
|
||||
|
@ -721,7 +721,7 @@ class CPPJitTest(jtu.BufferDonationTestCase):
|
||||
if config.jax_array:
|
||||
out = jitted_f(2)
|
||||
self.assertIsInstance(out.sharding, sharding.SingleDeviceSharding)
|
||||
self.assertIsInstance(out._arrays[0], device_array.Buffer)
|
||||
self.assertIsInstance(out, array.ArrayImpl)
|
||||
else:
|
||||
self.assertIsInstance(jitted_f(2), device_array.Buffer)
|
||||
|
||||
|
@ -94,8 +94,8 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
arr, global_data = create_array(
|
||||
input_shape, sharding.NamedSharding(global_mesh, mesh_axes))
|
||||
for s in arr.addressable_shards:
|
||||
self.assertLen(s.data._arrays, 1)
|
||||
self.assertArraysEqual(s.data._arrays[0], global_data[s.index])
|
||||
self.assertTrue(dispatch.is_single_device_sharding(s.data.sharding))
|
||||
self.assertArraysEqual(s.data, global_data[s.index])
|
||||
self.assertArraysEqual(arr._value, global_data)
|
||||
self.assertArraysEqual(arr._npy_value, global_data)
|
||||
|
||||
@ -418,7 +418,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
|
||||
y = jax.pmap(jnp.sin)(x)
|
||||
self.assertArraysEqual([a.device() for a in y],
|
||||
[a.device() for a in y._arrays])
|
||||
y.sharding._device_assignment)
|
||||
|
||||
sin_x = iter(np.sin(x))
|
||||
for i, j in zip(iter(y), sin_x):
|
||||
|
@ -2704,8 +2704,8 @@ class LazyConstantTest(jtu.JaxTestCase):
|
||||
y = lax.convert_element_type(x, dtype_out)
|
||||
self.assertEqual(y.dtype, dtype_out)
|
||||
if config.jax_array:
|
||||
x_buf = x._arrays[0]
|
||||
y_buf = y._arrays[0]
|
||||
x_buf = x
|
||||
y_buf = y
|
||||
else:
|
||||
x_buf = x.device_buffer
|
||||
y_buf = y.device_buffer
|
||||
|
@ -379,8 +379,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
self.assertIsInstance(actual, array.ArrayImpl)
|
||||
self.assertLen(actual.addressable_shards, 2)
|
||||
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
|
||||
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
|
||||
# Annotation from with_sharding_constraint
|
||||
@ -408,8 +407,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
self.assertIsInstance(actual, array.ArrayImpl)
|
||||
self.assertLen(actual.addressable_shards, 2)
|
||||
self.assertAllClose(np.asarray(actual._arrays[0]), expected,
|
||||
check_dtypes=False)
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
|
||||
hlo = f.lower(np.ones(shape)).compiler_ir(dialect="hlo")
|
||||
# Annotation from with_sharding_constraint
|
||||
@ -1745,8 +1743,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.shape, (8, 8))
|
||||
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
|
||||
for s in out.addressable_shards:
|
||||
self.assertLen(s.data._arrays, 1)
|
||||
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
|
||||
self.assertLen(s.data.devices(), 1)
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
self.assertArraysEqual(out._value, expected_matrix_mul)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@ -1772,8 +1770,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.shape, (8, 8))
|
||||
self.assertEqual(out.addressable_shards[0].data.shape, shard_shape)
|
||||
for s in out.addressable_shards:
|
||||
self.assertLen(s.data._arrays, 1)
|
||||
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
|
||||
self.assertLen(s.data.devices(), 1)
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
self.assertArraysEqual(out._value, expected_matrix_mul)
|
||||
|
||||
def test_numpy_array_input_assume_fully_replicated(self):
|
||||
@ -1792,7 +1790,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(out, array.ArrayImpl)
|
||||
for s in out.addressable_shards:
|
||||
self.assertEqual(s.data.shape, (2, 1))
|
||||
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
self.assertArraysEqual(out._value, input_data)
|
||||
|
||||
def test_numpy_array_input(self):
|
||||
@ -1811,7 +1809,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(out, array.ArrayImpl)
|
||||
for s in out.addressable_shards:
|
||||
self.assertEqual(s.data.shape, (2, 1))
|
||||
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
self.assertArraysEqual(out._value, input_data)
|
||||
|
||||
@jax_array(True)
|
||||
@ -1823,8 +1821,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out.shape, (8, 2))
|
||||
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
|
||||
for s in out.addressable_shards:
|
||||
self.assertLen(s.data._arrays, 1)
|
||||
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
|
||||
self.assertLen(s.data.devices(), 1)
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
self.assertArraysEqual(out._value, input_data)
|
||||
|
||||
global_input_shape = (8, 2)
|
||||
@ -1872,25 +1870,25 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
|
||||
for s in out1.addressable_shards:
|
||||
self.assertArraysEqual(
|
||||
s.data._arrays[0], (input_data @ input_data.T)[s.index])
|
||||
s.data, (input_data @ input_data.T)[s.index])
|
||||
|
||||
self.assertIsInstance(out2, array.ArrayImpl)
|
||||
self.assertEqual(out2.shape, (8, 2))
|
||||
self.assertEqual(out2.addressable_shards[0].data.shape, s2_shape)
|
||||
for s in out2.addressable_shards:
|
||||
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
|
||||
self.assertIsInstance(out3, array.ArrayImpl)
|
||||
self.assertEqual(out3.shape, (8, 2))
|
||||
self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
|
||||
for s in out3.addressable_shards:
|
||||
self.assertArraysEqual(s.data._arrays[0], (input_data * 2)[s.index])
|
||||
self.assertArraysEqual(s.data, (input_data * 2)[s.index])
|
||||
|
||||
self.assertIsInstance(out4, array.ArrayImpl)
|
||||
self.assertEqual(out4.shape, (8, 2))
|
||||
self.assertEqual(out4.addressable_shards[0].data.shape, s4_shape)
|
||||
for s in out4.addressable_shards:
|
||||
self.assertArraysEqual(s.data._arrays[0], input_data)
|
||||
self.assertArraysEqual(s.data, input_data)
|
||||
|
||||
def test_in_axis_resources_mismatch_error(self):
|
||||
global_input_shape = (8, 2)
|
||||
@ -2205,10 +2203,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
f = pjit(lambda x: x, out_axis_resources=s)
|
||||
out = f(arr)
|
||||
self.assertArraysEqual([o.device() for o in out._arrays], list(mesh.devices.flat))
|
||||
self.assertTrue(out.sharding.is_equivalent_to(arr.sharding, arr.ndim))
|
||||
self.assertArraysEqual(out, inp_data)
|
||||
out2 = f(out)
|
||||
self.assertArraysEqual([o.device() for o in out2._arrays], list(mesh.devices.flat))
|
||||
self.assertTrue(out2.sharding.is_equivalent_to(out.sharding, out.ndim))
|
||||
self.assertArraysEqual(out2, inp_data)
|
||||
|
||||
@jax_array(True)
|
||||
|
@ -1248,12 +1248,12 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
# Test that 'ans' was properly replicated across devices.
|
||||
if config.jax_array:
|
||||
bufs = ans._arrays
|
||||
ans_devices = ans.sharding._device_assignment
|
||||
else:
|
||||
bufs = ans.device_buffers
|
||||
ans_devices = [b.device() for b in ans.device_buffers]
|
||||
# TODO(mattjj,sharadmv): fix physical layout with eager pmap, remove 'if'
|
||||
if not config.jax_disable_jit:
|
||||
self.assertEqual([b.device() for b in bufs], devices)
|
||||
self.assertEqual(ans_devices, devices)
|
||||
|
||||
def testPmapConstantError(self):
|
||||
device_count = jax.device_count()
|
||||
@ -1288,25 +1288,25 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
# Test that 'ans' was properly replicated across devices.
|
||||
expected_sharded = self.pmap(self.pmap(lambda x: x))(expected)
|
||||
if config.jax_array:
|
||||
ans_db = ans._arrays
|
||||
expected_db = expected_sharded._arrays
|
||||
self.assertTrue(ans.sharding._device_assignment,
|
||||
expected_sharded.sharding._device_assignment)
|
||||
else:
|
||||
ans_db = ans.device_buffers
|
||||
expected_db = expected_sharded.device_buffers
|
||||
self.assertEqual([b.device() for b in ans_db],
|
||||
[b.device() for b in expected_db])
|
||||
self.assertEqual([b.device() for b in ans_db],
|
||||
[b.device() for b in expected_db])
|
||||
|
||||
f = self.pmap(self.pmap(lambda x: (x, 3)))
|
||||
x_sharded, ans = f(x)
|
||||
if config.jax_array:
|
||||
ans_db = ans._arrays
|
||||
x_sharded_db = x_sharded._arrays
|
||||
self.assertEqual(ans.sharding._device_assignment,
|
||||
x_sharded.sharding._device_assignment)
|
||||
else:
|
||||
ans_db = ans.device_buffers
|
||||
x_sharded_db = x_sharded.device_buffers
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
self.assertEqual([b.device() for b in ans_db],
|
||||
[b.device() for b in x_sharded_db])
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
self.assertEqual([b.device() for b in ans_db],
|
||||
[b.device() for b in x_sharded_db])
|
||||
|
||||
@unittest.skip("Nested pmaps with devices not yet implemented")
|
||||
def testNestedPmapConstantDevices(self):
|
||||
@ -1327,13 +1327,12 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
# Test that 'ans' was properly replicated across devices.
|
||||
expected_sharded = self.pmap(self.pmap(lambda x: x), devices=devices)(expected)
|
||||
if config.jax_array:
|
||||
ans_bufs = ans._arrays
|
||||
expected_sharded_bufs = expected_sharded._arrays
|
||||
self.assertTrue(ans.sharding == expected_sharded.sharding)
|
||||
else:
|
||||
ans_bufs = ans.device_buffers
|
||||
expected_sharded_bufs = expected_sharded.device_buffers
|
||||
self.assertEqual([b.device() for b in ans_bufs],
|
||||
[b.device() for b in expected_sharded_bufs])
|
||||
self.assertEqual([b.device() for b in ans_bufs],
|
||||
[b.device() for b in expected_sharded_bufs])
|
||||
|
||||
def testNestedPmapConstantError(self):
|
||||
if config.jax_disable_jit:
|
||||
@ -2954,7 +2953,7 @@ class ArrayPmapTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertIsInstance(out, array.ArrayImpl)
|
||||
for s in out.addressable_shards:
|
||||
self.assertArraysEqual(s.data._arrays[0], expected[s.index])
|
||||
self.assertArraysEqual(s.data, expected[s.index])
|
||||
self.assertArraysEqual(out, expected)
|
||||
|
||||
def test_pmap_double_input_array_output_array(self):
|
||||
@ -2973,8 +2972,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(out1, array.ArrayImpl)
|
||||
self.assertIsInstance(out2, array.ArrayImpl)
|
||||
for s1, s2 in safe_zip(out1.addressable_shards, out2.addressable_shards):
|
||||
self.assertArraysEqual(s1.data._arrays[0], input_data[s1.index])
|
||||
self.assertArraysEqual(s2.data._arrays[0], input_data[s2.index])
|
||||
self.assertArraysEqual(s1.data, input_data[s1.index])
|
||||
self.assertArraysEqual(s2.data, input_data[s2.index])
|
||||
self.assertArraysEqual(out1, input_data)
|
||||
self.assertArraysEqual(out2, input_data)
|
||||
|
||||
@ -2999,8 +2998,8 @@ class ArrayPmapTest(jtu.JaxTestCase):
|
||||
self.assertEqual(out1.shape, (2,))
|
||||
self.assertEqual(out2.shape, (dc, dc, 2))
|
||||
for i, (s1, s2) in enumerate(safe_zip(out1.addressable_shards, out2.addressable_shards)):
|
||||
self.assertArraysEqual(s1.data._arrays[0], input_data[i])
|
||||
self.assertArraysEqual(s2.data._arrays[0], input_data)
|
||||
self.assertArraysEqual(s1.data, input_data[i])
|
||||
self.assertArraysEqual(s2.data, input_data)
|
||||
|
||||
def test_pmap_array_sharding_mismatch(self):
|
||||
input_shape = (jax.device_count(), 2)
|
||||
|
@ -1277,7 +1277,7 @@ class XMapArrayTest(XMapTestCase):
|
||||
self.assertEqual(out.addressable_shards[0].data.shape, (2, 1))
|
||||
self.assertDictEqual(out.sharding.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out.addressable_shards:
|
||||
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
|
||||
self.assertArraysEqual(s.data, input_data[s.index])
|
||||
|
||||
def test_xmap_array_double_input(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
@ -1301,14 +1301,14 @@ class XMapArrayTest(XMapTestCase):
|
||||
self.assertEqual(out1.addressable_shards[0].data.shape, (2,))
|
||||
self.assertDictEqual(out1.sharding.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out1.addressable_shards:
|
||||
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
self.assertIsInstance(out2, array.ArrayImpl)
|
||||
self.assertEqual(out2.shape, (8,))
|
||||
self.assertEqual(out2.addressable_shards[0].data.shape, (4,))
|
||||
self.assertDictEqual(out2.sharding.mesh.shape, {'x': 4, 'y': 2})
|
||||
for s in out2.addressable_shards:
|
||||
self.assertArraysEqual(s.data._arrays[0], expected_matrix_mul[s.index])
|
||||
self.assertArraysEqual(s.data, expected_matrix_mul[s.index])
|
||||
|
||||
def test_xmap_array_sharding_mismatch(self):
|
||||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user