mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Clearer test names.
This commit is contained in:
parent
5077807c8b
commit
defe71228c
@ -1639,7 +1639,7 @@ class PileTest(jtu.JaxTestCase):
|
||||
data = jax.lax.broadcasted_iota('int32', (3, 5, 7), 1)
|
||||
self.assertAllClose(p.data, data)
|
||||
|
||||
def test_einsum_ragged_tensor(self):
|
||||
def test_einsum_with_ragged_tensor_dimension(self):
|
||||
x_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
|
||||
def fprop_layer(x_size):
|
||||
one_d = jnp.arange(x_size, dtype='int32')
|
||||
@ -1652,7 +1652,7 @@ class PileTest(jtu.JaxTestCase):
|
||||
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[3,bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+,2,7\]')
|
||||
self.assertEqual(p.data.shape, (3, 3, 5, 2, 7))
|
||||
|
||||
def test_einsum_ragged_tensor_and_contract(self):
|
||||
def test_einsum_with_ragged_tensor_and_contract_dimensions(self):
|
||||
ragged_sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
|
||||
def fprop_layer(ragged_size):
|
||||
one_d = jnp.arange(ragged_size, dtype='int32')
|
||||
|
Loading…
x
Reference in New Issue
Block a user