mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Cleanup: switch to new version of super()
This commit is contained in:
parent
03ec444f0a
commit
63a788b4de
@ -1062,7 +1062,7 @@ class ShapedArray(UnshapedArray):
|
||||
array_abstraction_level = 1
|
||||
|
||||
def __init__(self, shape, dtype, weak_type=False, named_shape={}):
|
||||
super(ShapedArray, self).__init__(dtype, weak_type=weak_type)
|
||||
super().__init__(dtype, weak_type=weak_type)
|
||||
self.shape = canonicalize_shape(shape)
|
||||
self.named_shape = dict(named_shape)
|
||||
|
||||
@ -1141,8 +1141,8 @@ class ConcreteArray(ShapedArray):
|
||||
array_abstraction_level = 0
|
||||
|
||||
def __init__(self, val, weak_type=False):
|
||||
super(ConcreteArray, self).__init__(np.shape(val), np.result_type(val),
|
||||
weak_type=weak_type)
|
||||
super().__init__(np.shape(val), np.result_type(val),
|
||||
weak_type=weak_type)
|
||||
# Note: canonicalized self.dtype doesn't necessarily match self.val
|
||||
self.val = val
|
||||
assert self.dtype != np.dtype('O'), val
|
||||
|
@ -228,7 +228,7 @@ def wrap_init(f, params={}) -> WrappedFun:
|
||||
class _CacheLocalContext(threading.local):
|
||||
|
||||
def __init__(self):
|
||||
super(_CacheLocalContext, self).__init__()
|
||||
super().__init__()
|
||||
self.most_recent_entry = None
|
||||
|
||||
|
||||
|
@ -879,7 +879,7 @@ class JaxTestCase(parameterized.TestCase):
|
||||
# assert core.reset_trace_state()
|
||||
|
||||
def setUp(self):
|
||||
super(JaxTestCase, self).setUp()
|
||||
super().setUp()
|
||||
config.update('jax_enable_checks', True)
|
||||
# We use the adler32 hash for two reasons.
|
||||
# a) it is deterministic run to run, unlike hash() which is randomized.
|
||||
|
@ -61,7 +61,7 @@ all_shapes = nonempty_array_shapes + empty_array_shapes
|
||||
|
||||
class DLPackTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super(DLPackTest, self).setUp()
|
||||
super().setUp()
|
||||
if jtu.device_under_test() == "tpu":
|
||||
self.skipTest("DLPack not supported on TPU")
|
||||
|
||||
@ -194,7 +194,7 @@ class DLPackTest(jtu.JaxTestCase):
|
||||
class CudaArrayInterfaceTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(CudaArrayInterfaceTest, self).setUp()
|
||||
super().setUp()
|
||||
if jtu.device_under_test() != "gpu":
|
||||
self.skipTest("__cuda_array_interface__ is only supported on GPU")
|
||||
|
||||
|
@ -62,7 +62,7 @@ class AbstractSparseArray(core.ShapedArray):
|
||||
|
||||
def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
|
||||
named_shape={}):
|
||||
super(AbstractSparseArray, self).__init__(shape, dtype)
|
||||
super().__init__(shape, dtype)
|
||||
self.index_dtype = index_dtype
|
||||
self.nnz = nnz
|
||||
self.data_aval = core.ShapedArray((nnz,), dtype, weak_type, named_shape)
|
||||
|
@ -42,7 +42,7 @@ config.parse_flags_with_absl()
|
||||
class ShardedJitTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ShardedJitTest, self).setUp()
|
||||
super().setUp()
|
||||
if jtu.device_under_test() not in ["tpu", "gpu"]:
|
||||
raise SkipTest
|
||||
if jtu.device_under_test() == "gpu":
|
||||
@ -279,7 +279,7 @@ class ShardedJitTest(jtu.JaxTestCase):
|
||||
class ShardedJitErrorsTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(ShardedJitErrorsTest, self).setUp()
|
||||
super().setUp()
|
||||
if jtu.device_under_test() not in ["tpu", "gpu"]:
|
||||
raise SkipTest
|
||||
|
||||
@ -329,7 +329,7 @@ class ShardedJitTestNoTpu(jtu.JaxTestCase):
|
||||
class PmapOfShardedJitTest(jtu.JaxTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(PmapOfShardedJitTest, self).setUp()
|
||||
super().setUp()
|
||||
if jtu.device_under_test() not in ["tpu", "gpu"]:
|
||||
raise SkipTest
|
||||
if jtu.device_under_test() == "gpu":
|
||||
|
Loading…
x
Reference in New Issue
Block a user