Cleanup: switch to new version of super()

Jake VanderPlas 2021-08-05 13:11:07 -07:00
parent 03ec444f0a
commit 63a788b4de
6 changed files with 11 additions and 11 deletions
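
The change is mechanical: each explicit Python 2 style call `super(ClassName, self).method(...)` becomes the zero-argument form `super().method(...)`, which Python 3 resolves to the same bound method. A minimal sketch of the pattern, using illustrative Base/Child names rather than the classes touched by this commit:

    class Base:
      def __init__(self, dtype, weak_type=False):
        self.dtype = dtype
        self.weak_type = weak_type

    class Child(Base):
      def __init__(self, shape, dtype, weak_type=False):
        # Old spelling, kept for Python 2 compatibility:
        #   super(Child, self).__init__(dtype, weak_type=weak_type)
        # New spelling; the two are equivalent in Python 3:
        super().__init__(dtype, weak_type=weak_type)
        self.shape = shape

Both spellings dispatch along the same MRO; the zero-argument form simply avoids repeating the class name, so renaming a class cannot silently break its super() calls.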

File 1 of 6

@@ -1062,7 +1062,7 @@ class ShapedArray(UnshapedArray):
   array_abstraction_level = 1
   def __init__(self, shape, dtype, weak_type=False, named_shape={}):
-    super(ShapedArray, self).__init__(dtype, weak_type=weak_type)
+    super().__init__(dtype, weak_type=weak_type)
     self.shape = canonicalize_shape(shape)
     self.named_shape = dict(named_shape)
@@ -1141,8 +1141,8 @@ class ConcreteArray(ShapedArray):
   array_abstraction_level = 0
   def __init__(self, val, weak_type=False):
-    super(ConcreteArray, self).__init__(np.shape(val), np.result_type(val),
-                                        weak_type=weak_type)
+    super().__init__(np.shape(val), np.result_type(val),
+                     weak_type=weak_type)
     # Note: canonicalized self.dtype doesn't necessarily match self.val
     self.val = val
     assert self.dtype != np.dtype('O'), val

File 2 of 6

@@ -228,7 +228,7 @@ def wrap_init(f, params={}) -> WrappedFun:
 class _CacheLocalContext(threading.local):
   def __init__(self):
-    super(_CacheLocalContext, self).__init__()
+    super().__init__()
     self.most_recent_entry = None

File 3 of 6

@@ -879,7 +879,7 @@ class JaxTestCase(parameterized.TestCase):
   # assert core.reset_trace_state()
   def setUp(self):
-    super(JaxTestCase, self).setUp()
+    super().setUp()
     config.update('jax_enable_checks', True)
     # We use the adler32 hash for two reasons.
     # a) it is deterministic run to run, unlike hash() which is randomized.

File 4 of 6

@@ -61,7 +61,7 @@ all_shapes = nonempty_array_shapes + empty_array_shapes
 class DLPackTest(jtu.JaxTestCase):
   def setUp(self):
-    super(DLPackTest, self).setUp()
+    super().setUp()
     if jtu.device_under_test() == "tpu":
       self.skipTest("DLPack not supported on TPU")
@@ -194,7 +194,7 @@ class DLPackTest(jtu.JaxTestCase):
 class CudaArrayInterfaceTest(jtu.JaxTestCase):
   def setUp(self):
-    super(CudaArrayInterfaceTest, self).setUp()
+    super().setUp()
     if jtu.device_under_test() != "gpu":
       self.skipTest("__cuda_array_interface__ is only supported on GPU")

File 5 of 6

@@ -62,7 +62,7 @@ class AbstractSparseArray(core.ShapedArray):
   def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
                named_shape={}):
-    super(AbstractSparseArray, self).__init__(shape, dtype)
+    super().__init__(shape, dtype)
     self.index_dtype = index_dtype
     self.nnz = nnz
     self.data_aval = core.ShapedArray((nnz,), dtype, weak_type, named_shape)

File 6 of 6

@@ -42,7 +42,7 @@ config.parse_flags_with_absl()
 class ShardedJitTest(jtu.JaxTestCase):
   def setUp(self):
-    super(ShardedJitTest, self).setUp()
+    super().setUp()
     if jtu.device_under_test() not in ["tpu", "gpu"]:
       raise SkipTest
     if jtu.device_under_test() == "gpu":
@@ -279,7 +279,7 @@ class ShardedJitTest(jtu.JaxTestCase):
 class ShardedJitErrorsTest(jtu.JaxTestCase):
   def setUp(self):
-    super(ShardedJitErrorsTest, self).setUp()
+    super().setUp()
     if jtu.device_under_test() not in ["tpu", "gpu"]:
       raise SkipTest
@@ -329,7 +329,7 @@ class ShardedJitTestNoTpu(jtu.JaxTestCase):
 class PmapOfShardedJitTest(jtu.JaxTestCase):
   def setUp(self):
-    super(PmapOfShardedJitTest, self).setUp()
+    super().setUp()
     if jtu.device_under_test() not in ["tpu", "gpu"]:
       raise SkipTest
     if jtu.device_under_test() == "gpu":
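
Why the zero-argument form is a drop-in replacement, as a small self-contained check (the class names here are illustrative, not from the diff): inside a method body the Python 3 compiler provides a `__class__` closure cell, so `super()` is equivalent to `super(__class__, self)`.

    class A:
      def hello(self):
        return "hello from A"

    class B(A):
      def hello(self):
        s = super()  # same as super(B, self)
        # The super object records the class it was created in and the instance:
        assert s.__thisclass__ is B and s.__self__ is self
        return s.hello()

    print(B().hello())  # prints "hello from A"

The only constraint is that the zero-argument form works only inside a method defined in a class body, which holds for every call site changed in this commit.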