mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix code quality issues (#4302)
Changes: - Fix unnecessary generator - Iterate dictionary directly instead of calling .keys() - Remove global statement at the module level - Use list() instead of a list comprehension - Use with statement to open the file - Merge isinstance calls
This commit is contained in:
parent
d74e81cc8b
commit
40e20242db
@ -112,9 +112,9 @@ def download_and_verify_bazel():
|
||||
sys.stdout.write("\n")
|
||||
|
||||
# Verify that the downloaded Bazel binary has the expected SHA256.
|
||||
downloaded_file = open(tmp_path, "rb")
|
||||
contents = downloaded_file.read()
|
||||
downloaded_file.close()
|
||||
with open(tmp_path, "rb") as downloaded_file:
|
||||
contents = downloaded_file.read()
|
||||
|
||||
digest = hashlib.sha256(contents).hexdigest()
|
||||
if digest != package.sha256:
|
||||
print(
|
||||
@ -123,9 +123,8 @@ def download_and_verify_bazel():
|
||||
sys.exit(-1)
|
||||
|
||||
# Write the file as the bazel file name.
|
||||
out_file = open(package.file, "wb")
|
||||
out_file.write(contents)
|
||||
out_file.close()
|
||||
with open(package.file, "wb") as out_file:
|
||||
out_file.write(contents)
|
||||
|
||||
# Mark the file as executable.
|
||||
st = os.stat(package.file)
|
||||
@ -223,15 +222,14 @@ build:short_logs --output_filter=DONT_MATCH_ANYTHING
|
||||
|
||||
|
||||
def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs):
|
||||
f = open("../.bazelrc", "w")
|
||||
f.write(BAZELRC_TEMPLATE.format(**kwargs))
|
||||
if cuda_toolkit_path:
|
||||
f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n"
|
||||
.format(cuda_toolkit_path=cuda_toolkit_path))
|
||||
if cudnn_install_path:
|
||||
f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
|
||||
.format(cudnn_install_path=cudnn_install_path))
|
||||
f.close()
|
||||
with open("../.bazelrc", "w") as f:
|
||||
f.write(BAZELRC_TEMPLATE.format(**kwargs))
|
||||
if cuda_toolkit_path:
|
||||
f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n"
|
||||
.format(cuda_toolkit_path=cuda_toolkit_path))
|
||||
if cudnn_install_path:
|
||||
f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
|
||||
.format(cudnn_install_path=cudnn_install_path))
|
||||
|
||||
|
||||
BANNER = r"""
|
||||
|
@ -16,7 +16,6 @@ from setuptools import setup
|
||||
from glob import glob
|
||||
import os
|
||||
|
||||
global __version__
|
||||
__version__ = None
|
||||
|
||||
with open('jaxlib/version.py') as f:
|
||||
|
@ -67,7 +67,7 @@ def _make_concrete_python_scalar(t, x):
|
||||
np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
|
||||
weak_type=True)
|
||||
|
||||
for t in dtypes.python_scalar_dtypes.keys():
|
||||
for t in dtypes.python_scalar_dtypes:
|
||||
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
|
||||
ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)
|
||||
|
||||
|
@ -1197,8 +1197,8 @@ def axis_frame(axis_name):
|
||||
for frame in reversed(frames):
|
||||
if frame.name == axis_name:
|
||||
return frame
|
||||
else:
|
||||
raise NameError("unbound axis name: {}".format(axis_name))
|
||||
|
||||
raise NameError("unbound axis name: {}".format(axis_name))
|
||||
|
||||
def axis_index(axis_name):
|
||||
"""Return the index along the mapped axis ``axis_name``.
|
||||
|
@ -785,7 +785,7 @@ def get_num_partitions(*partitions):
|
||||
if len(partition_specs) == 0:
|
||||
# Everything is specified as replicated (all Nones).
|
||||
return None
|
||||
num_partitions_set = set(np.prod(spec) for spec in partition_specs)
|
||||
num_partitions_set = {np.prod(spec) for spec in partition_specs}
|
||||
if len(num_partitions_set) > 1:
|
||||
raise ValueError(
|
||||
f"All partition specs must use the same number of total partitions, "
|
||||
@ -1291,8 +1291,8 @@ class DynamicAxisEnv(list):
|
||||
for frame in reversed(self):
|
||||
if frame.name == axis_name:
|
||||
return frame
|
||||
else:
|
||||
assert False
|
||||
|
||||
raise AssertionError
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
|
@ -278,7 +278,7 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
|
||||
ValueError if input devices are inconsistent.
|
||||
"""
|
||||
try:
|
||||
device, = set(d for d in devices if d is not None) or (None,)
|
||||
device, = {d for d in devices if d is not None} or (None,)
|
||||
return device
|
||||
except ValueError as err:
|
||||
msg = "primitive arguments must be colocated on the same device, got {}"
|
||||
|
@ -3125,7 +3125,7 @@ def _concatenate_shape_rule(*operands, **kwargs):
|
||||
msg = "All objects to concatenate must be arrays, got {}."
|
||||
op = next(op for op in operands if not isinstance(op, UnshapedArray))
|
||||
raise TypeError(msg.format(type(op)))
|
||||
if len(set(operand.ndim for operand in operands)) != 1:
|
||||
if len({operand.ndim for operand in operands}) != 1:
|
||||
msg = "Cannot concatenate arrays with different ranks, got {}."
|
||||
raise TypeError(msg.format(", ".join(str(o.ndim) for o in operands)))
|
||||
shapes = np.array([operand.shape for operand in operands])
|
||||
|
@ -2413,8 +2413,7 @@ def associative_scan(fn, elems, reverse=False):
|
||||
results = lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
|
||||
[elem[2::2] for elem in elems])
|
||||
else:
|
||||
results = lowered_fn([odd_elem for odd_elem in odd_elems],
|
||||
[elem[2::2] for elem in elems])
|
||||
results = lowered_fn(list(odd_elems), [elem[2::2] for elem in elems])
|
||||
|
||||
# The first element of a scan is the same as the first element
|
||||
# of the original `elems`.
|
||||
|
@ -167,7 +167,7 @@ def _validate_axis_index_groups(axis_index_groups):
|
||||
if any(len(g) != len_0 for g in axis_index_groups):
|
||||
raise ValueError("axis_index_groups must all be the same size")
|
||||
axis_space = range(len_0 * len(axis_index_groups))
|
||||
if set(i for g in axis_index_groups for i in g) != set(axis_space):
|
||||
if {i for g in axis_index_groups for i in g} != set(axis_space):
|
||||
raise ValueError("axis_index_groups must cover all indices exactly once")
|
||||
|
||||
def ppermute(x, axis_name, perm):
|
||||
|
@ -263,7 +263,7 @@ def host_id(backend: str = None):
|
||||
|
||||
def host_ids(backend: str = None):
|
||||
"""Returns a sorted list of all host IDs."""
|
||||
return sorted(list(set(d.host_id for d in devices(backend))))
|
||||
return sorted({d.host_id for d in devices(backend)})
|
||||
|
||||
|
||||
def host_count(backend: str = None):
|
||||
|
@ -176,12 +176,12 @@ def irfft2(a, s=None, axes=(-2,-1), norm=None):
|
||||
|
||||
@_wraps(np.fft.fftfreq)
|
||||
def fftfreq(n, d=1.0):
|
||||
if isinstance(n, list) or isinstance(n, tuple):
|
||||
if isinstance(n, (list, tuple)):
|
||||
raise ValueError(
|
||||
"The n argument of jax.numpy.fft.fftfreq only takes an int. "
|
||||
"Got n = %s." % list(n))
|
||||
|
||||
elif isinstance(d, list) or isinstance(d, tuple):
|
||||
elif isinstance(d, (list, tuple)):
|
||||
raise ValueError(
|
||||
"The d argument of jax.numpy.fft.fftfreq only takes a single value. "
|
||||
"Got d = %s." % list(d))
|
||||
@ -208,12 +208,12 @@ def fftfreq(n, d=1.0):
|
||||
|
||||
@_wraps(np.fft.rfftfreq)
|
||||
def rfftfreq(n, d=1.0):
|
||||
if isinstance(n, list) or isinstance(n, tuple):
|
||||
if isinstance(n, (list, tuple)):
|
||||
raise ValueError(
|
||||
"The n argument of jax.numpy.fft.rfftfreq only takes an int. "
|
||||
"Got n = %s." % list(n))
|
||||
|
||||
elif isinstance(d, list) or isinstance(d, tuple):
|
||||
elif isinstance(d, (list, tuple)):
|
||||
raise ValueError(
|
||||
"The d argument of jax.numpy.fft.rfftfreq only takes a single value. "
|
||||
"Got d = %s." % list(d))
|
||||
|
@ -3346,7 +3346,7 @@ def lexsort(keys, axis=-1):
|
||||
keys = tuple(keys)
|
||||
if len(keys) == 0:
|
||||
raise TypeError("need sequence of keys with len > 0 in lexsort")
|
||||
if len(set(shape(key) for key in keys)) > 1:
|
||||
if len({shape(key) for key in keys}) > 1:
|
||||
raise ValueError("all keys need to be the same shape")
|
||||
if ndim(keys[0]) == 0:
|
||||
return np.int64(0)
|
||||
@ -3769,7 +3769,7 @@ def _index_to_gather(x_shape, idx):
|
||||
idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
|
||||
advanced_pairs = (
|
||||
(asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
|
||||
if (isinstance(e, Sequence) or isinstance(e, ndarray)))
|
||||
if isinstance(e, (Sequence, ndarray)))
|
||||
advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
|
||||
for e, i, j in advanced_pairs)
|
||||
advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
|
||||
@ -3838,8 +3838,7 @@ def _index_to_gather(x_shape, idx):
|
||||
except TypeError:
|
||||
abstract_i = None
|
||||
# Handle basic int indexes.
|
||||
if (isinstance(abstract_i, ConcreteArray) or
|
||||
isinstance(abstract_i, ShapedArray)) and _int(abstract_i):
|
||||
if isinstance(abstract_i, (ConcreteArray,ShapedArray)) and _int(abstract_i):
|
||||
if x_shape[x_axis] == 0:
|
||||
# XLA gives error when indexing into an axis of size 0
|
||||
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
|
||||
@ -3939,8 +3938,8 @@ def _index_to_gather(x_shape, idx):
|
||||
def _should_unpack_list_index(x):
|
||||
"""Helper for _eliminate_deprecated_list_indexing."""
|
||||
return (isinstance(x, ndarray) and np.ndim(x) != 0
|
||||
or isinstance(x, Sequence)
|
||||
or isinstance(x, slice) or x is Ellipsis or x is None)
|
||||
or isinstance(x, (Sequence, slice))
|
||||
or x is Ellipsis or x is None)
|
||||
|
||||
def _eliminate_deprecated_list_indexing(idx):
|
||||
# "Basic slicing is initiated if the selection object is a non-array,
|
||||
|
@ -140,7 +140,7 @@ def _normalize_tolerance(tol):
|
||||
if isinstance(tol, dict):
|
||||
return {np.dtype(k): v for k, v in tol.items()}
|
||||
else:
|
||||
return {k: tol for k in _default_tolerance.keys()}
|
||||
return {k: tol for k in _default_tolerance}
|
||||
|
||||
def join_tolerance(tol1, tol2):
|
||||
tol1 = _normalize_tolerance(tol1)
|
||||
|
@ -94,7 +94,7 @@ def jax_to_hlo(fn, input_shapes, constants=None):
|
||||
if not constants:
|
||||
constants = {}
|
||||
|
||||
overlapping_args = set(arg_name for arg_name, _ in input_shapes) & set(
|
||||
overlapping_args = {arg_name for arg_name, _ in input_shapes} & set(
|
||||
constants.keys())
|
||||
if overlapping_args:
|
||||
raise ValueError(
|
||||
|
1
setup.py
1
setup.py
@ -14,7 +14,6 @@
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
global __version__
|
||||
__version__ = None
|
||||
|
||||
with open('jax/version.py') as f:
|
||||
|
@ -329,7 +329,7 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
# test modeling the code in https://github.com/google/jax/issues/54
|
||||
|
||||
def func(xs):
|
||||
return jnp.array([x for x in xs])
|
||||
return jnp.array(list(xs))
|
||||
|
||||
xs = jnp.ones((5, 1))
|
||||
jacrev(func)(xs) # don't crash
|
||||
|
Loading…
x
Reference in New Issue
Block a user