Fix code quality issues (#4302)

Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
This commit is contained in:
Srijan Saurav 2020-09-17 21:51:18 +05:30 committed by GitHub
parent d74e81cc8b
commit 40e20242db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 36 additions and 42 deletions

View File

@ -112,9 +112,9 @@ def download_and_verify_bazel():
sys.stdout.write("\n")
# Verify that the downloaded Bazel binary has the expected SHA256.
downloaded_file = open(tmp_path, "rb")
contents = downloaded_file.read()
downloaded_file.close()
with open(tmp_path, "rb") as downloaded_file:
contents = downloaded_file.read()
digest = hashlib.sha256(contents).hexdigest()
if digest != package.sha256:
print(
@ -123,9 +123,8 @@ def download_and_verify_bazel():
sys.exit(-1)
# Write the file as the bazel file name.
out_file = open(package.file, "wb")
out_file.write(contents)
out_file.close()
with open(package.file, "wb") as out_file:
out_file.write(contents)
# Mark the file as executable.
st = os.stat(package.file)
@ -223,15 +222,14 @@ build:short_logs --output_filter=DONT_MATCH_ANYTHING
def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs):
f = open("../.bazelrc", "w")
f.write(BAZELRC_TEMPLATE.format(**kwargs))
if cuda_toolkit_path:
f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n"
.format(cuda_toolkit_path=cuda_toolkit_path))
if cudnn_install_path:
f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
.format(cudnn_install_path=cudnn_install_path))
f.close()
with open("../.bazelrc", "w") as f:
f.write(BAZELRC_TEMPLATE.format(**kwargs))
if cuda_toolkit_path:
f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n"
.format(cuda_toolkit_path=cuda_toolkit_path))
if cudnn_install_path:
f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
.format(cudnn_install_path=cudnn_install_path))
BANNER = r"""

View File

@ -16,7 +16,6 @@ from setuptools import setup
from glob import glob
import os
global __version__
__version__ = None
with open('jaxlib/version.py') as f:

View File

@ -67,7 +67,7 @@ def _make_concrete_python_scalar(t, x):
np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
weak_type=True)
for t in dtypes.python_scalar_dtypes.keys():
for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)

View File

@ -1197,8 +1197,8 @@ def axis_frame(axis_name):
for frame in reversed(frames):
if frame.name == axis_name:
return frame
else:
raise NameError("unbound axis name: {}".format(axis_name))
raise NameError("unbound axis name: {}".format(axis_name))
def axis_index(axis_name):
"""Return the index along the mapped axis ``axis_name``.

View File

@ -785,7 +785,7 @@ def get_num_partitions(*partitions):
if len(partition_specs) == 0:
# Everything is specified as replicated (all Nones).
return None
num_partitions_set = set(np.prod(spec) for spec in partition_specs)
num_partitions_set = {np.prod(spec) for spec in partition_specs}
if len(num_partitions_set) > 1:
raise ValueError(
f"All partition specs must use the same number of total partitions, "
@ -1291,8 +1291,8 @@ class DynamicAxisEnv(list):
for frame in reversed(self):
if frame.name == axis_name:
return frame
else:
assert False
raise AssertionError
@property
def sizes(self):

View File

@ -278,7 +278,7 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
ValueError if input devices are inconsistent.
"""
try:
device, = set(d for d in devices if d is not None) or (None,)
device, = {d for d in devices if d is not None} or (None,)
return device
except ValueError as err:
msg = "primitive arguments must be colocated on the same device, got {}"

View File

@ -3125,7 +3125,7 @@ def _concatenate_shape_rule(*operands, **kwargs):
msg = "All objects to concatenate must be arrays, got {}."
op = next(op for op in operands if not isinstance(op, UnshapedArray))
raise TypeError(msg.format(type(op)))
if len(set(operand.ndim for operand in operands)) != 1:
if len({operand.ndim for operand in operands}) != 1:
msg = "Cannot concatenate arrays with different ranks, got {}."
raise TypeError(msg.format(", ".join(str(o.ndim) for o in operands)))
shapes = np.array([operand.shape for operand in operands])

View File

@ -2413,8 +2413,7 @@ def associative_scan(fn, elems, reverse=False):
results = lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
[elem[2::2] for elem in elems])
else:
results = lowered_fn([odd_elem for odd_elem in odd_elems],
[elem[2::2] for elem in elems])
results = lowered_fn(list(odd_elems), [elem[2::2] for elem in elems])
# The first element of a scan is the same as the first element
# of the original `elems`.

View File

@ -167,7 +167,7 @@ def _validate_axis_index_groups(axis_index_groups):
if any(len(g) != len_0 for g in axis_index_groups):
raise ValueError("axis_index_groups must all be the same size")
axis_space = range(len_0 * len(axis_index_groups))
if set(i for g in axis_index_groups for i in g) != set(axis_space):
if {i for g in axis_index_groups for i in g} != set(axis_space):
raise ValueError("axis_index_groups must cover all indices exactly once")
def ppermute(x, axis_name, perm):

View File

@ -263,7 +263,7 @@ def host_id(backend: str = None):
def host_ids(backend: str = None):
"""Returns a sorted list of all host IDs."""
return sorted(list(set(d.host_id for d in devices(backend))))
return sorted({d.host_id for d in devices(backend)})
def host_count(backend: str = None):

View File

@ -176,12 +176,12 @@ def irfft2(a, s=None, axes=(-2,-1), norm=None):
@_wraps(np.fft.fftfreq)
def fftfreq(n, d=1.0):
if isinstance(n, list) or isinstance(n, tuple):
if isinstance(n, (list, tuple)):
raise ValueError(
"The n argument of jax.numpy.fft.fftfreq only takes an int. "
"Got n = %s." % list(n))
elif isinstance(d, list) or isinstance(d, tuple):
elif isinstance(d, (list, tuple)):
raise ValueError(
"The d argument of jax.numpy.fft.fftfreq only takes a single value. "
"Got d = %s." % list(d))
@ -208,12 +208,12 @@ def fftfreq(n, d=1.0):
@_wraps(np.fft.rfftfreq)
def rfftfreq(n, d=1.0):
if isinstance(n, list) or isinstance(n, tuple):
if isinstance(n, (list, tuple)):
raise ValueError(
"The n argument of jax.numpy.fft.rfftfreq only takes an int. "
"Got n = %s." % list(n))
elif isinstance(d, list) or isinstance(d, tuple):
elif isinstance(d, (list, tuple)):
raise ValueError(
"The d argument of jax.numpy.fft.rfftfreq only takes a single value. "
"Got d = %s." % list(d))

View File

@ -3346,7 +3346,7 @@ def lexsort(keys, axis=-1):
keys = tuple(keys)
if len(keys) == 0:
raise TypeError("need sequence of keys with len > 0 in lexsort")
if len(set(shape(key) for key in keys)) > 1:
if len({shape(key) for key in keys}) > 1:
raise ValueError("all keys need to be the same shape")
if ndim(keys[0]) == 0:
return np.int64(0)
@ -3769,7 +3769,7 @@ def _index_to_gather(x_shape, idx):
idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
advanced_pairs = (
(asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
if (isinstance(e, Sequence) or isinstance(e, ndarray)))
if isinstance(e, (Sequence, ndarray)))
advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
for e, i, j in advanced_pairs)
advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
@ -3838,8 +3838,7 @@ def _index_to_gather(x_shape, idx):
except TypeError:
abstract_i = None
# Handle basic int indexes.
if (isinstance(abstract_i, ConcreteArray) or
isinstance(abstract_i, ShapedArray)) and _int(abstract_i):
if isinstance(abstract_i, (ConcreteArray,ShapedArray)) and _int(abstract_i):
if x_shape[x_axis] == 0:
# XLA gives error when indexing into an axis of size 0
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
@ -3939,8 +3938,8 @@ def _index_to_gather(x_shape, idx):
def _should_unpack_list_index(x):
"""Helper for _eliminate_deprecated_list_indexing."""
return (isinstance(x, ndarray) and np.ndim(x) != 0
or isinstance(x, Sequence)
or isinstance(x, slice) or x is Ellipsis or x is None)
or isinstance(x, (Sequence, slice))
or x is Ellipsis or x is None)
def _eliminate_deprecated_list_indexing(idx):
# "Basic slicing is initiated if the selection object is a non-array,

View File

@ -140,7 +140,7 @@ def _normalize_tolerance(tol):
if isinstance(tol, dict):
return {np.dtype(k): v for k, v in tol.items()}
else:
return {k: tol for k in _default_tolerance.keys()}
return {k: tol for k in _default_tolerance}
def join_tolerance(tol1, tol2):
tol1 = _normalize_tolerance(tol1)

View File

@ -94,7 +94,7 @@ def jax_to_hlo(fn, input_shapes, constants=None):
if not constants:
constants = {}
overlapping_args = set(arg_name for arg_name, _ in input_shapes) & set(
overlapping_args = {arg_name for arg_name, _ in input_shapes} & set(
constants.keys())
if overlapping_args:
raise ValueError(

View File

@ -14,7 +14,6 @@
from setuptools import setup, find_packages
global __version__
__version__ = None
with open('jax/version.py') as f:

View File

@ -329,7 +329,7 @@ class BatchingTest(jtu.JaxTestCase):
# test modeling the code in https://github.com/google/jax/issues/54
def func(xs):
return jnp.array([x for x in xs])
return jnp.array(list(xs))
xs = jnp.ones((5, 1))
jacrev(func)(xs) # don't crash