Clearer test names.

This commit is contained in:
Alexey Radul 2023-06-23 18:00:10 -04:00
parent 5077807c8b
commit defe71228c

View File

@ -1639,7 +1639,7 @@ class PileTest(jtu.JaxTestCase):
data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
self.assertAllClose(p.data, data)
def test_einsum_ragged_tensor(self):
def test_einsum_with_ragged_tensor_dimension(self):
x_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
def fprop_layer(x_size):
one_d = jnp.arange(x_size, dtype='int32')
@ -1652,7 +1652,7 @@ class PileTest(jtu.JaxTestCase):
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
self.assertEqual(p.data.shape, (3, 3, 5, 2, 7))
def test_einsum_ragged_tensor_and_contract(self):
def test_einsum_with_ragged_tensor_and_contract_dimensions(self):
ragged_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
def fprop_layer(ragged_size):
one_d = jnp.arange(ragged_size, dtype='int32')