Merge pull request #18539 from NeilGirdhar:ruff

PiperOrigin-RevId: 583105786
This commit is contained in:
jax authors 2023-11-16 11:15:19 -08:00
commit 7657a0fb15
50 changed files with 158 additions and 138 deletions

View File

@ -21,10 +21,10 @@ repos:
# only include python files
files: \.py$
- repo: https://github.com/pycqa/flake8
rev: '6.1.0'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.5
hooks:
- id: flake8
- id: ruff
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.6.1'

View File

@ -162,9 +162,10 @@ possible. The `git rebase -i` command might be useful to this end.
### Linting and Type-checking
JAX uses [mypy](https://mypy.readthedocs.io/) and [flake8](https://flake8.pycqa.org/)
to statically test code quality; the easiest way to run these checks locally is via
the [pre-commit](https://pre-commit.com/) framework:
JAX uses [mypy](https://mypy.readthedocs.io/) and
[ruff](https://docs.astral.sh/ruff/) to statically test code quality; the
easiest way to run these checks locally is via the
[pre-commit](https://pre-commit.com/) framework:
```bash
pip install pre-commit

View File

@ -339,20 +339,20 @@ pre-commit run mypy
## Linting
JAX uses the [flake8](https://flake8.pycqa.org/) linter to ensure code quality. You can check
your local changes by running:
JAX uses the [ruff](https://docs.astral.sh/ruff/) linter to ensure code
quality. You can check your local changes by running:
```
pip install flake8
flake8 jax
pip install ruff
ruff jax
```
Alternatively, you can use the [pre-commit](https://pre-commit.com/) framework to run this
on all staged files in your git repository, automatically using the same flake8 version as
on all staged files in your git repository, automatically using the same ruff version as
the GitHub tests:
```
pre-commit run flake8
pre-commit run ruff
```
## Update documentation

View File

@ -55,8 +55,8 @@ from jax._src import pjit
from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, ShapedArray
from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple, apply_flat_fun_nokwargs,
check_callable, debug_info, result_paths, flat_out_axes, debug_info_final)

View File

@ -1470,8 +1470,9 @@ def manual_proto(
mesh_shape = list(named_mesh_shape.values())
axis_order = {axis: i for i, axis in enumerate(mesh.axis_names)}
manual_axes = list(sorted(manual_axes_set, key=str))
replicated_axes = list(axis for axis in mesh.axis_names if axis not in manual_axes_set)
manual_axes = sorted(manual_axes_set, key=str)
replicated_axes = [axis for axis in mesh.axis_names
if axis not in manual_axes_set]
tad_perm = ([axis_order[a] for a in replicated_axes] +
[axis_order[a] for a in manual_axes])

View File

@ -123,7 +123,7 @@ def print_histogram(histogram: dict[Any, int]):
count_width = max(len(str(v)) for v in histogram.values())
count_fmt = '{:>' + str(count_width) + 'd}'
pairs = [(v, k) for k, v in histogram.items()]
for count, name in reversed(sorted(pairs)):
for count, name in sorted(pairs, reverse=True):
print(count_fmt.format(count), name)

View File

@ -870,7 +870,7 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm):
full_perm = np.zeros((len(replica_groups), len(perm), 2), np.int64)
for i, grp in enumerate(replica_groups):
grp = list(sorted(grp))
grp = sorted(grp)
for j, (src, dst) in enumerate(perm):
full_perm[i, j, 0] = grp[src]
full_perm[i, j, 1] = grp[dst]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
# ruff: noqa: F401
import jaxlib.mlir.ir as ir
import jaxlib.mlir.passmanager as passmanager

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
# ruff: noqa: F401
import jaxlib.mlir.dialects.builtin as builtin
import jaxlib.mlir.dialects.chlo as chlo
import jaxlib.mlir.dialects.mhlo as mhlo

View File

@ -152,4 +152,4 @@ class NDIndexer:
is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
other_indexers, _ = partition_list(is_int_indexing, self.indices)
other_shape = [s.size for s in other_indexers] # type: ignore
return tuple((*self.int_indexer_shape, *other_shape))
return (*self.int_indexer_shape, *other_shape)

View File

@ -311,7 +311,7 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
const_avals, const_ref_avals = partition_list(is_const_ref, all_const_avals)
const_avals = map(state.AbstractRef, const_avals)
merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
arg_avals = list(var.aval for var in jaxpr.invars)
arg_avals = [var.aval for var in jaxpr.invars]
in_avals = [*merged_const_avals, *arg_avals]
num_consts = len(merged_const_avals)

View File

@ -1306,9 +1306,9 @@ def _maybe_pattern_match_fori_loop(ctx: TritonLoweringRuleContext, *args,
else:
return None
jaxpr = body_jaxpr.jaxpr
new_invars = tuple((*jaxpr.invars[:body_nconsts],
jaxpr.invars[body_nconsts],
*jaxpr.invars[body_nconsts + 2:]))
new_invars = (*jaxpr.invars[:body_nconsts],
jaxpr.invars[body_nconsts],
*jaxpr.invars[body_nconsts + 2:])
new_outvars = tuple(jaxpr.outvars[2:])
jaxpr = jaxpr.replace(
eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1:],
@ -1551,7 +1551,7 @@ def pallas_call_lowering(
print(grid_mapping)
compilation_result = compile_jaxpr(
jaxpr,
tuple((*in_shapes, *out_shapes)),
(*in_shapes, *out_shapes),
grid_mapping,
name,
num_warps,

View File

@ -95,7 +95,7 @@ def dctn(x: Array, type: int = 2,
return dct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)
if s is not None:
ns = {a: n for a, n in zip(axes, s)}
ns = dict(zip(axes, s))
pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
x = lax.pad(x, jnp.array(0, x.dtype), pads)
@ -153,7 +153,7 @@ def idctn(x: Array, type: int = 2,
return idct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)
if s is not None:
ns = {a: n for a, n in zip(axes, s)}
ns = dict(zip(axes, s))
pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
x = lax.pad(x, jnp.array(0, x.dtype), pads)

View File

@ -823,7 +823,7 @@ def _generate_key_paths_(
# handle namedtuple as a special case, based on heuristic
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
for k, c in key_children:
yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf)
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
else:
yield key_path, tree # strict leaf type

View File

@ -631,11 +631,11 @@ def backends() -> dict[str, xla_client.Client]:
platform_registrations = list(
zip(platforms, priorities, fail_quietly_list))
else:
platform_registrations = list(
platform_registrations = [
(platform, registration.priority, registration.fail_quietly)
for platform, registration
in _backend_factories.items()
)
]
default_priority = -1000
for platform, priority, fail_quietly in platform_registrations:
try:

View File

@ -1185,7 +1185,7 @@ def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]:
for d in a.shape:
if is_poly_dim(d):
dim_vars = dim_vars.union(d.get_vars())
return sorted(tuple(dim_vars))
return sorted(dim_vars)
class CachingShapeEvaluator:

View File

@ -3244,13 +3244,13 @@ def _eig(operand: TfVal, compute_left_eigenvectors: bool,
"to True.")
raise NotImplementedError(msg)
elif not (compute_left_eigenvectors or compute_right_eigenvectors):
return tuple([tf.linalg.eigvals(operand)])
return (tf.linalg.eigvals(operand),)
elif compute_right_eigenvectors:
return tuple(tf.linalg.eig(operand))
else: # compute_left_eigenvectors == True
wH, vl = tf.linalg.eig(tf.linalg.adjoint(operand))
wHH = tf.math.conj(wH)
return tuple([wHH, vl])
return (wHH, vl)
tf_impl[lax.linalg.eig_p] = _eig

View File

@ -206,7 +206,7 @@ class CompatTestBase(jtu.JaxTestCase):
custom_call_re = r"stablehlo.custom_call\s*@([^\(]+)\("
current_custom_call_targets = sorted(
list(set(re.findall(custom_call_re, module_str))))
set(re.findall(custom_call_re, module_str)))
np.set_printoptions(threshold=sys.maxsize, floatmode="unique")
# Print the current test data to simplify updating the test.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, complex64

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, complex64

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, complex64

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, int32, float32, complex64

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, complex64

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, int32, float32, complex64

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, complex64

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, complex64

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, uint32

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, int32

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, int32

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, int32

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
import datetime
from numpy import array, float32, int32

View File

@ -94,7 +94,7 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
dim_vars_tuple = tuple(dim_vars)
# All combinations of values
for dim_values in itertools.product(*([(1, 2, 5, 10)] * len(dim_vars_tuple))):
env = {d: dv for d, dv in zip(dim_vars_tuple, dim_values)}
env = dict(zip(dim_vars_tuple, dim_values))
def eval(d: shape_poly.DimSize):
return d.evaluate(env) if core.is_symbolic_dim(d) else d # type: ignore

View File

@ -358,7 +358,7 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
granule_dict = collections.defaultdict(list)
for dev in devices:
granule_dict[getattr(dev, attr)].append(dev)
granules = list(granule_dict[key] for key in sorted(granule_dict.keys()))
granules = [granule_dict[key] for key in sorted(granule_dict.keys())]
if np.prod(dcn_mesh_shape) != len(granules):
raise ValueError(
f'Number of slices {len(granules)} must equal the product of '

View File

@ -243,9 +243,9 @@ def host_local_array_to_global_array_impl(
arrays = [x.data for x in arr.addressable_shards]
else:
arr = xla.canonicalize_dtype(arr)
arrays = list(
arrays = [
arr[index]
for d, index in local_sharding.devices_indices_map(arr.shape).items())
for d, index in local_sharding.devices_indices_map(arr.shape).items()]
global_aval = _local_to_global_aval(
core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)
@ -350,9 +350,9 @@ def global_array_to_host_local_array_impl(
else:
# numpy array can show up here during AD.
arr = xla.canonicalize_dtype(arr)
arrays = list(
arrays = [
arr[index]
for d, index in local_sharding.devices_indices_map(arr.shape).items())
for d, index in local_sharding.devices_indices_map(arr.shape).items()]
return pxla.batched_device_put(
local_aval, local_sharding, arrays,
list(global_mesh.local_mesh.devices.flat))

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# ruff: noqa
from jax._src.pjit import (
pjit as pjit,

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
# ruff: noqa: F401
from jax._src.lib import (
version_str as __version__,
xla_client as xla_client,

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
# ruff: noqa: F401
from jax._src.xla_bridge import (
default_backend as default_backend,
get_backend as get_backend,

View File

@ -14,8 +14,8 @@
"""Python bindings for the MLIR TPU dialect."""
# flake8: noqa: F401
# flake8: noqa: F403
# ruff: noqa: F401
# ruff: noqa: F403
# pylint: disable=g-bad-import-order
@ -32,7 +32,7 @@ _cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python")
@_cext.register_operation(_Dialect, replace=True)
class TraceOp(TraceOp):
class TraceOp(TraceOp): # noqa: F405
"""An extension to the automatically generated TraceOp bindings."""
def __init__(self, results, message, level, *, loc=None, ip=None):
@ -45,7 +45,7 @@ class TraceOp(TraceOp):
@_cext.register_operation(_Dialect, replace=True)
class RegionOp(RegionOp):
class RegionOp(RegionOp): # noqa: F405
"""An extension to the automatically generated RegionOp bindings."""
def __init__(self, *, loc=None, ip=None):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
# ruff: noqa: F401
from .mosaic.python import apply_vector_layout # pytype: disable=import-error
from .mosaic.python import infer_memref_layout # pytype: disable=import-error

View File

@ -105,3 +105,85 @@ enable = "c-extension-no-member"
[tool.pylint.format]
indent-string=" "
[tool.ruff]
preview = true
exclude = [
".git",
"build",
"__pycache__",
]
ignore = [
# Unnecessary collection call
"C408",
# Unnecessary map usage
"C417",
# Object names too complex
"C901",
# Local variable is assigned to but never used
"F841",
# Raise with from clause inside except block
"B904",
]
line-length = 88
indent-width = 2
select = [
"B9",
"C",
"F",
"W",
"YTT",
"ASYNC",
"E225",
"E227",
"E228",
]
target-version = "py39"
[tool.ruff.mccabe]
max-complexity = 18
[tool.ruff.per-file-ignores]
# F811: Redefinition of unused name.
"docs/autodidax.py" = ["F811"]
# Note: we don't use jax/*.py because this matches contents of jax/_src
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]
"jax/ad_checkpoint.py" = ["F401"]
"jax/api_util.py" = ["F401"]
"jax/cloud_tpu_init.py" = ["F401"]
"jax/core.py" = ["F401"]
"jax/custom_batching.py" = ["F401"]
"jax/custom_derivatives.py" = ["F401"]
"jax/custom_transpose.py" = ["F401"]
"jax/debug.py" = ["F401"]
"jax/distributed.py" = ["F401"]
"jax/dlpack.py" = ["F401"]
"jax/dtypes.py" = ["F401"]
"jax/errors.py" = ["F401"]
"jax/experimental/*.py" = ["F401"]
"jax/extend/*.py" = ["F401"]
"jax/flatten_util.py" = ["F401"]
"jax/interpreters/ad.py" = ["F401"]
"jax/interpreters/batching.py" = ["F401"]
"jax/interpreters/mlir.py" = ["F401"]
"jax/interpreters/partial_eval.py" = ["F401"]
"jax/interpreters/pxla.py" = ["F401"]
"jax/interpreters/xla.py" = ["F401"]
"jax/lax/*.py" = ["F401"]
"jax/linear_util.py" = ["F401"]
"jax/monitoring.py" = ["F401"]
"jax/nn/*.py" = ["F401"]
"jax/numpy/*.py" = ["F401"]
"jax/prng.py" = ["F401"]
"jax/profiler.py" = ["F401"]
"jax/random.py" = ["F401"]
"jax/scipy/*.py" = ["F401"]
"jax/sharding.py" = ["F401"]
"jax/stages.py" = ["F401"]
"jax/test_util.py" = ["F401"]
"jax/tree_util.py" = ["F401"]
"jax/typing.py" = ["F401"]
"jax/util.py" = ["F401"]
# F821: Undefined name.
"jax/numpy/__init__.pyi" = ["F821"]

View File

@ -1,63 +0,0 @@
[flake8]
max-line-length = 88
ignore =
# object names too complex
C901
# four-space indents
E111, E114
# line continuations
E121
# line breaks around binary operators
W503, W504
max-complexity = 18
select = B,C,F,W,T4,B9,E225,E227,E228
exclude =
.git,
build,
__pycache__
per-file-ignores =
# F811: redefinition of unused name
docs/autodidax.py:F811
# F401: unused imports
# Note: we don't use jax/*.py because this matches contents of jax/_src
__init__.py:F401
jax/abstract_arrays.py:F401
jax/ad_checkpoint.py:F401
jax/api_util.py:F401
jax/cloud_tpu_init.py:F401
jax/core.py:F401
jax/custom_batching.py:F401
jax/custom_derivatives.py:F401
jax/custom_transpose.py:F401
jax/debug.py:F401
jax/distributed.py:F401
jax/dlpack.py:F401
jax/dtypes.py:F401
jax/errors.py:F401
jax/extend/*.py:F401
jax/flatten_util.py:F401
jax/interpreters/ad.py:F401
jax/interpreters/batching.py:F401
jax/interpreters/mlir.py:F401
jax/interpreters/partial_eval.py:F401
jax/interpreters/pxla.py:F401
jax/interpreters/xla.py:F401
jax/linear_util.py:F401
jax/monitoring.py:F401
jax/prng.py:F401
jax/profiler.py:F401
jax/random.py:F401
jax/sharding.py:F401
jax/stages.py:F401
jax/test_util.py:F401
jax/tree_util.py:F401
jax/typing.py:F401
jax/util.py:F401
jax/_src/api.py:F401
jax/_src/numpy/lax_numpy.py:F401
jax/_src/typing.py:F401
jax/experimental/*.py:F401
jax/lax/*.py:F401
jax/nn/*.py:F401
jax/numpy/*.py:F401
jax/scipy/*.py:F401

View File

@ -50,11 +50,10 @@ def compute_recall(result_neighbors, ground_truth_neighbors) -> float:
) == 2, "shape = [num_queries, ground_truth_neighbors_per_query]"
assert result_neighbors.shape[0] == ground_truth_neighbors.shape[0]
gt_sets = [set(np.asarray(x)) for x in ground_truth_neighbors]
hits = sum(
len(list(x
for x in nn_per_q
if x.item() in gt_sets[q]))
for q, nn_per_q in enumerate(result_neighbors))
hits = sum(len([x
for x in nn_per_q
if x.item() in gt_sets[q]])
for q, nn_per_q in enumerate(result_neighbors))
return hits / ground_truth_neighbors.size

View File

@ -62,7 +62,7 @@ class SparseArray:
return self.data.shape[0]
def __repr__(self):
return repr(list((tuple(ind), d) for ind, d in zip(self.indices, self.data)))
return repr([(tuple(ind), d) for ind, d in zip(self.indices, self.data)])
class AbstractSparseArray(core.ShapedArray):

View File

@ -83,7 +83,7 @@ for effect in _testing_effects.values():
# and just doubles its argument.
testing_primitive_with_effect_p = core.Primitive("testing_primitive_with_effect")
testing_primitive_with_effect_p.def_effectful_abstract_eval(
lambda aval, *x, effect_class_name: (aval, set([_testing_effects[effect_class_name]])))
lambda aval, *x, effect_class_name: (aval, {_testing_effects[effect_class_name]}))
def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str):
if "Ordered" in effect_class_name:

View File

@ -1125,7 +1125,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
def testLuBatching(self, shape, dtype):
rng = jtu.rand_default(self.rng())
args = [rng(shape, jnp.float32) for _ in range(10)]
expected = list(osp.linalg.lu(x) for x in args)
expected = [osp.linalg.lu(x) for x in args]
ps = np.stack([out[0] for out in expected])
ls = np.stack([out[1] for out in expected])
us = np.stack([out[2] for out in expected])

View File

@ -652,7 +652,7 @@ class TreeTest(jtu.JaxTestCase):
x = ((1, 2), [3, 4, 5])
y = (([3], jnp.array(0)), ([0], 7, [5, 6]))
out = tree_util.tree_map_with_path(
lambda kp, *xs: tuple((kp[0].idx, *xs)), x, y,
lambda kp, *xs: (kp[0].idx, *xs), x, y,
is_leaf=lambda n: isinstance(n, list))
self.assertEqual(out, (((0, 1, [3]),
(0, 2, jnp.array(0))),