diff --git a/build/build.py b/build/build.py index 7192b86e2..333996d22 100755 --- a/build/build.py +++ b/build/build.py @@ -112,9 +112,9 @@ def download_and_verify_bazel(): sys.stdout.write("\n") # Verify that the downloaded Bazel binary has the expected SHA256. - downloaded_file = open(tmp_path, "rb") - contents = downloaded_file.read() - downloaded_file.close() + with open(tmp_path, "rb") as downloaded_file: + contents = downloaded_file.read() + digest = hashlib.sha256(contents).hexdigest() if digest != package.sha256: print( @@ -123,9 +123,8 @@ def download_and_verify_bazel(): sys.exit(-1) # Write the file as the bazel file name. - out_file = open(package.file, "wb") - out_file.write(contents) - out_file.close() + with open(package.file, "wb") as out_file: + out_file.write(contents) # Mark the file as executable. st = os.stat(package.file) @@ -223,15 +222,14 @@ build:short_logs --output_filter=DONT_MATCH_ANYTHING def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs): - f = open("../.bazelrc", "w") - f.write(BAZELRC_TEMPLATE.format(**kwargs)) - if cuda_toolkit_path: - f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n" - .format(cuda_toolkit_path=cuda_toolkit_path)) - if cudnn_install_path: - f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n" - .format(cudnn_install_path=cudnn_install_path)) - f.close() + with open("../.bazelrc", "w") as f: + f.write(BAZELRC_TEMPLATE.format(**kwargs)) + if cuda_toolkit_path: + f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n" + .format(cuda_toolkit_path=cuda_toolkit_path)) + if cudnn_install_path: + f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n" + .format(cudnn_install_path=cudnn_install_path)) BANNER = r""" diff --git a/build/setup.py b/build/setup.py index 7976ab2ad..644ff857a 100644 --- a/build/setup.py +++ b/build/setup.py @@ -16,7 +16,6 @@ from setuptools import setup from glob import glob import os -global __version__ __version__ = None with open('jaxlib/version.py') as f: diff --git a/jax/abstract_arrays.py b/jax/abstract_arrays.py index 7b98e8b5a..3898818cf 100644 --- a/jax/abstract_arrays.py +++ b/jax/abstract_arrays.py @@ -67,7 +67,7 @@ def _make_concrete_python_scalar(t, x): np.array(x, dtype=dtypes.python_scalar_dtypes[t]), weak_type=True) -for t in dtypes.python_scalar_dtypes.keys(): +for t in dtypes.python_scalar_dtypes: core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t) ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t) diff --git a/jax/core.py b/jax/core.py index e805fb7d7..bec3384ca 100644 --- a/jax/core.py +++ b/jax/core.py @@ -1197,8 +1197,8 @@ def axis_frame(axis_name): for frame in reversed(frames): if frame.name == axis_name: return frame - else: - raise NameError("unbound axis name: {}".format(axis_name)) + + raise NameError("unbound axis name: {}".format(axis_name)) def axis_index(axis_name): """Return the index along the mapped axis ``axis_name``. diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 5821ec2b4..9627a5e4d 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -785,7 +785,7 @@ def get_num_partitions(*partitions): if len(partition_specs) == 0: # Everything is specified as replicated (all Nones). return None - num_partitions_set = set(np.prod(spec) for spec in partition_specs) + num_partitions_set = {np.prod(spec) for spec in partition_specs} if len(num_partitions_set) > 1: raise ValueError( f"All partition specs must use the same number of total partitions, " @@ -1291,8 +1291,8 @@ class DynamicAxisEnv(list): for frame in reversed(self): if frame.name == axis_name: return frame - else: - assert False + + raise AssertionError @property def sizes(self): diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 01d07f3df..22e053fbc 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -278,7 +278,7 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De ValueError if input devices are inconsistent. """ try: - device, = set(d for d in devices if d is not None) or (None,) + device, = {d for d in devices if d is not None} or (None,) return device except ValueError as err: msg = "primitive arguments must be colocated on the same device, got {}" diff --git a/jax/lax/lax.py b/jax/lax/lax.py index 8e4d16148..d81a0eec2 100644 --- a/jax/lax/lax.py +++ b/jax/lax/lax.py @@ -3125,7 +3125,7 @@ def _concatenate_shape_rule(*operands, **kwargs): msg = "All objects to concatenate must be arrays, got {}." op = next(op for op in operands if not isinstance(op, UnshapedArray)) raise TypeError(msg.format(type(op))) - if len(set(operand.ndim for operand in operands)) != 1: + if len({operand.ndim for operand in operands}) != 1: msg = "Cannot concatenate arrays with different ranks, got {}." raise TypeError(msg.format(", ".join(str(o.ndim) for o in operands))) shapes = np.array([operand.shape for operand in operands]) diff --git a/jax/lax/lax_control_flow.py b/jax/lax/lax_control_flow.py index eb4850a1d..bd838afc6 100644 --- a/jax/lax/lax_control_flow.py +++ b/jax/lax/lax_control_flow.py @@ -2413,8 +2413,7 @@ def associative_scan(fn, elems, reverse=False): results = lowered_fn([odd_elem[:-1] for odd_elem in odd_elems], [elem[2::2] for elem in elems]) else: - results = lowered_fn([odd_elem for odd_elem in odd_elems], - [elem[2::2] for elem in elems]) + results = lowered_fn(list(odd_elems), [elem[2::2] for elem in elems]) # The first element of a scan is the same as the first element # of the original `elems`. diff --git a/jax/lax/lax_parallel.py b/jax/lax/lax_parallel.py index 5d22d68e7..b294e60a2 100644 --- a/jax/lax/lax_parallel.py +++ b/jax/lax/lax_parallel.py @@ -167,7 +167,7 @@ def _validate_axis_index_groups(axis_index_groups): if any(len(g) != len_0 for g in axis_index_groups): raise ValueError("axis_index_groups must all be the same size") axis_space = range(len_0 * len(axis_index_groups)) - if set(i for g in axis_index_groups for i in g) != set(axis_space): + if {i for g in axis_index_groups for i in g} != set(axis_space): raise ValueError("axis_index_groups must cover all indices exactly once") def ppermute(x, axis_name, perm): diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index f5dc96aef..1c25a0900 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -263,7 +263,7 @@ def host_id(backend: str = None): def host_ids(backend: str = None): """Returns a sorted list of all host IDs.""" - return sorted(list(set(d.host_id for d in devices(backend)))) + return sorted({d.host_id for d in devices(backend)}) def host_count(backend: str = None): diff --git a/jax/numpy/fft.py b/jax/numpy/fft.py index 07cd71cda..00b656369 100644 --- a/jax/numpy/fft.py +++ b/jax/numpy/fft.py @@ -176,12 +176,12 @@ def irfft2(a, s=None, axes=(-2,-1), norm=None): @_wraps(np.fft.fftfreq) def fftfreq(n, d=1.0): - if isinstance(n, list) or isinstance(n, tuple): + if isinstance(n, (list, tuple)): raise ValueError( "The n argument of jax.numpy.fft.fftfreq only takes an int. " "Got n = %s." % list(n)) - elif isinstance(d, list) or isinstance(d, tuple): + elif isinstance(d, (list, tuple)): raise ValueError( "The d argument of jax.numpy.fft.fftfreq only takes a single value. " "Got d = %s." % list(d)) @@ -208,12 +208,12 @@ def fftfreq(n, d=1.0): @_wraps(np.fft.rfftfreq) def rfftfreq(n, d=1.0): - if isinstance(n, list) or isinstance(n, tuple): + if isinstance(n, (list, tuple)): raise ValueError( "The n argument of jax.numpy.fft.rfftfreq only takes an int. " "Got n = %s." % list(n)) - elif isinstance(d, list) or isinstance(d, tuple): + elif isinstance(d, (list, tuple)): raise ValueError( "The d argument of jax.numpy.fft.rfftfreq only takes a single value. " "Got d = %s." % list(d)) diff --git a/jax/numpy/lax_numpy.py b/jax/numpy/lax_numpy.py index 9d3af2e6e..c12af4638 100644 --- a/jax/numpy/lax_numpy.py +++ b/jax/numpy/lax_numpy.py @@ -3346,7 +3346,7 @@ def lexsort(keys, axis=-1): keys = tuple(keys) if len(keys) == 0: raise TypeError("need sequence of keys with len > 0 in lexsort") - if len(set(shape(key) for key in keys)) > 1: + if len({shape(key) for key in keys}) > 1: raise ValueError("all keys need to be the same shape") if ndim(keys[0]) == 0: return np.int64(0) @@ -3769,7 +3769,7 @@ def _index_to_gather(x_shape, idx): idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None] advanced_pairs = ( (asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones) - if (isinstance(e, Sequence) or isinstance(e, ndarray))) + if isinstance(e, (Sequence, ndarray))) advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j) for e, i, j in advanced_pairs) advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs) @@ -3838,8 +3838,7 @@ def _index_to_gather(x_shape, idx): except TypeError: abstract_i = None # Handle basic int indexes. - if (isinstance(abstract_i, ConcreteArray) or - isinstance(abstract_i, ShapedArray)) and _int(abstract_i): + if isinstance(abstract_i, (ConcreteArray,ShapedArray)) and _int(abstract_i): if x_shape[x_axis] == 0: # XLA gives error when indexing into an axis of size 0 raise IndexError(f"index is out of bounds for axis {x_axis} with size 0") @@ -3939,8 +3938,8 @@ def _index_to_gather(x_shape, idx): def _should_unpack_list_index(x): """Helper for _eliminate_deprecated_list_indexing.""" return (isinstance(x, ndarray) and np.ndim(x) != 0 - or isinstance(x, Sequence) - or isinstance(x, slice) or x is Ellipsis or x is None) + or isinstance(x, (Sequence, slice)) + or x is Ellipsis or x is None) def _eliminate_deprecated_list_indexing(idx): # "Basic slicing is initiated if the selection object is a non-array, diff --git a/jax/test_util.py b/jax/test_util.py index c47de7042..30c98a291 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -140,7 +140,7 @@ def _normalize_tolerance(tol): if isinstance(tol, dict): return {np.dtype(k): v for k, v in tol.items()} else: - return {k: tol for k in _default_tolerance.keys()} + return {k: tol for k in _default_tolerance} def join_tolerance(tol1, tol2): tol1 = _normalize_tolerance(tol1) diff --git a/jax/tools/jax_to_hlo.py b/jax/tools/jax_to_hlo.py index 98b58937c..3985dab4e 100644 --- a/jax/tools/jax_to_hlo.py +++ b/jax/tools/jax_to_hlo.py @@ -94,7 +94,7 @@ def jax_to_hlo(fn, input_shapes, constants=None): if not constants: constants = {} - overlapping_args = set(arg_name for arg_name, _ in input_shapes) & set( + overlapping_args = {arg_name for arg_name, _ in input_shapes} & set( constants.keys()) if overlapping_args: raise ValueError( diff --git a/setup.py b/setup.py index 0f758ff44..89cf89e82 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,6 @@ from setuptools import setup, find_packages -global __version__ __version__ = None with open('jax/version.py') as f: diff --git a/tests/batching_test.py b/tests/batching_test.py index 5744387ad..79945f3d8 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -329,7 +329,7 @@ class BatchingTest(jtu.JaxTestCase): # test modeling the code in https://github.com/google/jax/issues/54 def func(xs): - return jnp.array([x for x in xs]) + return jnp.array(list(xs)) xs = jnp.ones((5, 1)) jacrev(func)(xs) # don't crash