Fix the pytype error. PyType is looking for a __init__ method. This does not change the behavior of the class.

```
Function PartitionSpec.__init__ expects 1 arg(s), got 3 [wrong-arg-count]
         Expected: (self)
  Actually passed: (self, _, _)
```

PiperOrigin-RevId: 441211351
This commit is contained in:
Yash Katariya 2022-04-12 09:34:58 -07:00 committed by jax authors
parent a2c2d9af91
commit 3136004c62

View File

@ -2188,6 +2188,10 @@ class PartitionSpec(tuple):
create a separate class for this so JAX's pytree utilities can distinguish it
from a tuple that should be treated as a pytree.
"""
def __init__(self, *partitions):
pass
def __new__(cls, *partitions):
return tuple.__new__(PartitionSpec, partitions)