Merge pull request #9290 from fehiepsi:named

PiperOrigin-RevId: 438290209
This commit is contained in:
jax authors 2022-03-30 06:54:10 -07:00
commit 17fc5bd02e
2 changed files with 8 additions and 1 deletions

View File

@ -1679,7 +1679,7 @@ class NamedShape:
return (self.__positional, self.__named) == (other.__positional, other.__named)
if isinstance(other, tuple):
return not self.__named and self.__positional == other
raise TypeError(f"NamedShape doesn't support comparisons with {type(other)}")
return False
def __hash__(self):
named = frozenset(self.__named.items())

View File

@ -516,6 +516,13 @@ class JaxprTypeChecks(jtu.JaxTestCase):
aval3 = core.ShapedArray((2, 3), np.float32, False, {'i': 5})
self.assertFalse(core.typecompat(aval1, aval3))
def test_named_shape_comparision(self):
self.assertTrue(core.NamedShape(2, 3) == (2, 3))
self.assertFalse(core.NamedShape(2, i=3) == (2,))
self.assertFalse(core.NamedShape(2, i=3) == (2, 3))
self.assertFalse(core.NamedShape(2, i=3) == None)
self.assertFalse(core.NamedShape() == [])
@jtu.with_config(jax_dynamic_shapes=True)
class DynamicShapesTest(jtu.JaxTestCase):