mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Merge pull request #9290 from fehiepsi:named
PiperOrigin-RevId: 438290209
This commit is contained in:
commit
17fc5bd02e
@ -1679,7 +1679,7 @@ class NamedShape:
|
||||
return (self.__positional, self.__named) == (other.__positional, other.__named)
|
||||
if isinstance(other, tuple):
|
||||
return not self.__named and self.__positional == other
|
||||
raise TypeError(f"NamedShape doesn't support comparisons with {type(other)}")
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
named = frozenset(self.__named.items())
|
||||
|
@ -516,6 +516,13 @@ class JaxprTypeChecks(jtu.JaxTestCase):
|
||||
aval3 = core.ShapedArray((2, 3), np.float32, False, {'i': 5})
|
||||
self.assertFalse(core.typecompat(aval1, aval3))
|
||||
|
||||
def test_named_shape_comparision(self):
|
||||
self.assertTrue(core.NamedShape(2, 3) == (2, 3))
|
||||
self.assertFalse(core.NamedShape(2, i=3) == (2,))
|
||||
self.assertFalse(core.NamedShape(2, i=3) == (2, 3))
|
||||
self.assertFalse(core.NamedShape(2, i=3) == None)
|
||||
self.assertFalse(core.NamedShape() == [])
|
||||
|
||||
|
||||
@jtu.with_config(jax_dynamic_shapes=True)
|
||||
class DynamicShapesTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user