Mirror of https://github.com/ROCm/jax.git, synced 2025-04-15 19:36:06 +00:00
Fix code quality issues (#4302)
Changes:
- Fix unnecessary generator
- Iterate dictionary directly instead of calling .keys()
- Remove global statement at the module level
- Use list() instead of a list comprehension
- Use with statement to open the file
- Merge isinstance calls
This commit is contained in:
parent d74e81cc8b
commit 40e20242db
@@ -112,9 +112,9 @@ def download_and_verify_bazel():
     sys.stdout.write("\n")

     # Verify that the downloaded Bazel binary has the expected SHA256.
-    downloaded_file = open(tmp_path, "rb")
-    contents = downloaded_file.read()
-    downloaded_file.close()
+    with open(tmp_path, "rb") as downloaded_file:
+      contents = downloaded_file.read()
     digest = hashlib.sha256(contents).hexdigest()
     if digest != package.sha256:
       print(
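Note (illustrative, not part of the commit): the point of the `with` form is that the file is closed even if the body raises. A minimal, self-contained sketch of the same read-then-hash pattern, with a made-up function name:

    import hashlib

    def sha256_of(path):
        # The context manager closes the file even if read() or the
        # hashing below raises, unlike a bare open()/close() pair.
        with open(path, "rb") as f:
            contents = f.read()
        return hashlib.sha256(contents).hexdigest()

    # Example usage (any existing file works):
    # print(sha256_of("some_downloaded_binary"))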
@@ -123,9 +123,8 @@ def download_and_verify_bazel():
       sys.exit(-1)

     # Write the file as the bazel file name.
-    out_file = open(package.file, "wb")
-    out_file.write(contents)
-    out_file.close()
+    with open(package.file, "wb") as out_file:
+      out_file.write(contents)

     # Mark the file as executable.
     st = os.stat(package.file)
@@ -223,15 +222,14 @@ build:short_logs --output_filter=DONT_MATCH_ANYTHING


 def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, **kwargs):
-  f = open("../.bazelrc", "w")
-  f.write(BAZELRC_TEMPLATE.format(**kwargs))
-  if cuda_toolkit_path:
-    f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n"
-            .format(cuda_toolkit_path=cuda_toolkit_path))
-  if cudnn_install_path:
-    f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
-            .format(cudnn_install_path=cudnn_install_path))
-  f.close()
+  with open("../.bazelrc", "w") as f:
+    f.write(BAZELRC_TEMPLATE.format(**kwargs))
+    if cuda_toolkit_path:
+      f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n"
+              .format(cuda_toolkit_path=cuda_toolkit_path))
+    if cudnn_install_path:
+      f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n"
+              .format(cudnn_install_path=cudnn_install_path))


 BANNER = r"""
@@ -16,7 +16,6 @@ from setuptools import setup
 from glob import glob
 import os

-global __version__
 __version__ = None

 with open('jaxlib/version.py') as f:
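Background (general Python behavior, not from the diff): a `global` statement at module scope has no effect, because a module-level assignment already creates a module-global name; `global` only matters inside a function body. A small sketch with a hypothetical setter:

    # At module scope this assignment already defines a global;
    # a preceding "global __version__" would be redundant.
    __version__ = None

    def _set_version(v):
        # Inside a function, "global" is what makes the assignment
        # rebind the module-level name instead of creating a local.
        global __version__
        __version__ = v

    _set_version("0.0.1")
    assert __version__ == "0.0.1"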
@@ -67,7 +67,7 @@ def _make_concrete_python_scalar(t, x):
       np.array(x, dtype=dtypes.python_scalar_dtypes[t]),
       weak_type=True)

-for t in dtypes.python_scalar_dtypes.keys():
+for t in dtypes.python_scalar_dtypes:
   core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
   ad_util.jaxval_zeros_likers[t] = partial(_zeros_like_python_scalar, t)
@@ -1197,8 +1197,8 @@ def axis_frame(axis_name):
   for frame in reversed(frames):
     if frame.name == axis_name:
       return frame
-  else:
-    raise NameError("unbound axis name: {}".format(axis_name))
+  raise NameError("unbound axis name: {}".format(axis_name))

 def axis_index(axis_name):
   """Return the index along the mapped axis ``axis_name``.
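For context (illustrative only, toy data): the `else` clause of a `for` loop runs whenever the loop finishes without hitting `break`. Because this loop exits via `return`, the `else:` added nothing, and the `raise` can simply follow the loop:

    def find_frame(frames, axis_name):
        # Hypothetical stand-in for the lookup above: returning from inside
        # the loop means a trailing raise behaves the same as a for/else.
        for frame in reversed(frames):
            if frame == axis_name:
                return frame
        raise NameError("unbound axis name: {}".format(axis_name))

    assert find_frame(["i", "j"], "j") == "j"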
@@ -785,7 +785,7 @@ def get_num_partitions(*partitions):
   if len(partition_specs) == 0:
     # Everything is specified as replicated (all Nones).
     return None
-  num_partitions_set = set(np.prod(spec) for spec in partition_specs)
+  num_partitions_set = {np.prod(spec) for spec in partition_specs}
   if len(num_partitions_set) > 1:
     raise ValueError(
         f"All partition specs must use the same number of total partitions, "
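The change here is purely syntactic; a sketch with toy tuples rather than JAX partition specs: a set comprehension builds the same set as set() over a generator expression, without the intermediate generator object.

    import numpy as np

    specs = [(2, 2), (4, 1), (1, 4)]  # toy partition specs

    # Equivalent results; the comprehension skips the throwaway generator.
    assert set(np.prod(spec) for spec in specs) == {np.prod(spec) for spec in specs}
    print({np.prod(spec) for spec in specs})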
@@ -1291,8 +1291,8 @@ class DynamicAxisEnv(list):
     for frame in reversed(self):
       if frame.name == axis_name:
         return frame
-    else:
-      assert False
+    raise AssertionError

   @property
   def sizes(self):
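One practical difference behind this change (general Python behavior, not specific to JAX): `assert` statements are stripped when the interpreter runs with optimizations (`python -O`), so `assert False` can silently fall through, whereas `raise AssertionError` always raises. A toy sketch:

    def unreachable_with_assert():
        # Removed entirely under "python -O"; execution would continue.
        assert False

    def unreachable_with_raise():
        # Always raises, regardless of interpreter flags.
        raise AssertionError

    try:
        unreachable_with_raise()
    except AssertionError:
        print("raised as expected")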
@@ -278,7 +278,7 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
     ValueError if input devices are inconsistent.
   """
   try:
-    device, = set(d for d in devices if d is not None) or (None,)
+    device, = {d for d in devices if d is not None} or (None,)
     return device
   except ValueError as err:
     msg = "primitive arguments must be colocated on the same device, got {}"
@@ -3125,7 +3125,7 @@ def _concatenate_shape_rule(*operands, **kwargs):
     msg = "All objects to concatenate must be arrays, got {}."
     op = next(op for op in operands if not isinstance(op, UnshapedArray))
     raise TypeError(msg.format(type(op)))
-  if len(set(operand.ndim for operand in operands)) != 1:
+  if len({operand.ndim for operand in operands}) != 1:
     msg = "Cannot concatenate arrays with different ranks, got {}."
     raise TypeError(msg.format(", ".join(str(o.ndim) for o in operands)))
   shapes = np.array([operand.shape for operand in operands])
@@ -2413,8 +2413,7 @@ def associative_scan(fn, elems, reverse=False):
       results = lowered_fn([odd_elem[:-1] for odd_elem in odd_elems],
                            [elem[2::2] for elem in elems])
     else:
-      results = lowered_fn([odd_elem for odd_elem in odd_elems],
-                           [elem[2::2] for elem in elems])
+      results = lowered_fn(list(odd_elems), [elem[2::2] for elem in elems])

     # The first element of a scan is the same as the first element
     # of the original `elems`.
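The replaced comprehension copied the sequence element by element; `list()` does the same thing more directly. A toy illustration (plain Python lists, not scan operands):

    odd_elems = ["a", "b", "c"]

    # [x for x in xs] and list(xs) produce equal shallow copies.
    assert [x for x in odd_elems] == list(odd_elems)
    assert list(odd_elems) is not odd_elems  # still a new list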
@@ -167,7 +167,7 @@ def _validate_axis_index_groups(axis_index_groups):
   if any(len(g) != len_0 for g in axis_index_groups):
     raise ValueError("axis_index_groups must all be the same size")
   axis_space = range(len_0 * len(axis_index_groups))
-  if set(i for g in axis_index_groups for i in g) != set(axis_space):
+  if {i for g in axis_index_groups for i in g} != set(axis_space):
     raise ValueError("axis_index_groups must cover all indices exactly once")

 def ppermute(x, axis_name, perm):
@@ -263,7 +263,7 @@ def host_id(backend: str = None):

 def host_ids(backend: str = None):
   """Returns a sorted list of all host IDs."""
-  return sorted(list(set(d.host_id for d in devices(backend))))
+  return sorted({d.host_id for d in devices(backend)})


 def host_count(backend: str = None):
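Here the extra `list()` call was doubly redundant: `sorted()` accepts any iterable, including a set, and always returns a new list. A small sketch with made-up host IDs:

    host_ids = [3, 1, 3, 2, 1]  # hypothetical per-device host IDs

    # sorted() already returns a list, so wrapping the set in list() first
    # only adds an extra copy.
    assert sorted(list(set(host_ids))) == sorted({h for h in host_ids}) == [1, 2, 3]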
@@ -176,12 +176,12 @@ def irfft2(a, s=None, axes=(-2,-1), norm=None):

 @_wraps(np.fft.fftfreq)
 def fftfreq(n, d=1.0):
-  if isinstance(n, list) or isinstance(n, tuple):
+  if isinstance(n, (list, tuple)):
     raise ValueError(
           "The n argument of jax.numpy.fft.fftfreq only takes an int. "
           "Got n = %s." % list(n))

-  elif isinstance(d, list) or isinstance(d, tuple):
+  elif isinstance(d, (list, tuple)):
     raise ValueError(
           "The d argument of jax.numpy.fft.fftfreq only takes a single value. "
           "Got d = %s." % list(d))
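`isinstance` accepts a tuple of types, so two checks joined by `or` collapse into one call. A standalone sketch (hypothetical helper, unrelated to the FFT code itself):

    def reject_sequences(n):
        # Equivalent to: isinstance(n, list) or isinstance(n, tuple)
        if isinstance(n, (list, tuple)):
            raise ValueError("expected a single int, got %s" % (n,))
        return int(n)

    assert reject_sequences(8) == 8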
@@ -208,12 +208,12 @@ def fftfreq(n, d=1.0):

 @_wraps(np.fft.rfftfreq)
 def rfftfreq(n, d=1.0):
-  if isinstance(n, list) or isinstance(n, tuple):
+  if isinstance(n, (list, tuple)):
     raise ValueError(
           "The n argument of jax.numpy.fft.rfftfreq only takes an int. "
           "Got n = %s." % list(n))

-  elif isinstance(d, list) or isinstance(d, tuple):
+  elif isinstance(d, (list, tuple)):
     raise ValueError(
           "The d argument of jax.numpy.fft.rfftfreq only takes a single value. "
           "Got d = %s." % list(d))
@@ -3346,7 +3346,7 @@ def lexsort(keys, axis=-1):
   keys = tuple(keys)
   if len(keys) == 0:
     raise TypeError("need sequence of keys with len > 0 in lexsort")
-  if len(set(shape(key) for key in keys)) > 1:
+  if len({shape(key) for key in keys}) > 1:
     raise ValueError("all keys need to be the same shape")
   if ndim(keys[0]) == 0:
     return np.int64(0)
@@ -3769,7 +3769,7 @@ def _index_to_gather(x_shape, idx):
   idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
   advanced_pairs = (
     (asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
-    if (isinstance(e, Sequence) or isinstance(e, ndarray)))
+    if isinstance(e, (Sequence, ndarray)))
   advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
                     for e, i, j in advanced_pairs)
   advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
@@ -3838,8 +3838,7 @@ def _index_to_gather(x_shape, idx):
     except TypeError:
       abstract_i = None
     # Handle basic int indexes.
-    if (isinstance(abstract_i, ConcreteArray) or
-        isinstance(abstract_i, ShapedArray)) and _int(abstract_i):
+    if isinstance(abstract_i, (ConcreteArray,ShapedArray)) and _int(abstract_i):
       if x_shape[x_axis] == 0:
         # XLA gives error when indexing into an axis of size 0
         raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
@@ -3939,8 +3938,8 @@ def _index_to_gather(x_shape, idx):
 def _should_unpack_list_index(x):
   """Helper for _eliminate_deprecated_list_indexing."""
   return (isinstance(x, ndarray) and np.ndim(x) != 0
-          or isinstance(x, Sequence)
-          or isinstance(x, slice) or x is Ellipsis or x is None)
+          or isinstance(x, (Sequence, slice))
+          or x is Ellipsis or x is None)

 def _eliminate_deprecated_list_indexing(idx):
   # "Basic slicing is initiated if the selection object is a non-array,
@@ -140,7 +140,7 @@ def _normalize_tolerance(tol):
   if isinstance(tol, dict):
     return {np.dtype(k): v for k, v in tol.items()}
   else:
-    return {k: tol for k in _default_tolerance.keys()}
+    return {k: tol for k in _default_tolerance}

 def join_tolerance(tol1, tol2):
   tol1 = _normalize_tolerance(tol1)
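Iterating a dict yields its keys, so `.keys()` is redundant in comprehensions and membership tests. A tiny sketch with made-up tolerance values (only the variable name echoes the diff):

    import numpy as np

    _default_tolerance = {np.dtype("float32"): 1e-6, np.dtype("float64"): 1e-15}

    # Iterating the dict and iterating dict.keys() visit the same keys.
    assert {k: 0.5 for k in _default_tolerance} == {k: 0.5 for k in _default_tolerance.keys()}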
@@ -94,7 +94,7 @@ def jax_to_hlo(fn, input_shapes, constants=None):
   if not constants:
     constants = {}

-  overlapping_args = set(arg_name for arg_name, _ in input_shapes) & set(
+  overlapping_args = {arg_name for arg_name, _ in input_shapes} & set(
       constants.keys())
   if overlapping_args:
     raise ValueError(
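The rewritten line still relies on plain set intersection to spot argument names that appear both as inputs and as constants; a standalone sketch with invented names:

    input_shapes = [("x", (2, 2)), ("y", (3,))]   # hypothetical (name, shape) pairs
    constants = {"y": 1.0, "seed": 0}

    overlapping_args = {arg_name for arg_name, _ in input_shapes} & set(constants.keys())
    if overlapping_args:
        print("overlapping:", overlapping_args)   # prints: overlapping: {'y'}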
setup.py
@@ -14,7 +14,6 @@

 from setuptools import setup, find_packages

-global __version__
 __version__ = None

 with open('jax/version.py') as f:
@@ -329,7 +329,7 @@ class BatchingTest(jtu.JaxTestCase):
     # test modeling the code in https://github.com/google/jax/issues/54

     def func(xs):
-      return jnp.array([x for x in xs])
+      return jnp.array(list(xs))

     xs = jnp.ones((5, 1))
     jacrev(func)(xs) # don't crash