Merge pull request #18539 from NeilGirdhar:ruff
PiperOrigin-RevId: 583105786
Commit 7657a0fb15
@@ -21,10 +21,10 @@ repos:
 # only include python files
 files: \.py$

-- repo: https://github.com/pycqa/flake8
-rev: '6.1.0'
+- repo: https://github.com/astral-sh/ruff-pre-commit
+rev: v0.1.5
 hooks:
-- id: flake8
+- id: ruff

 - repo: https://github.com/pre-commit/mirrors-mypy
 rev: 'v1.6.1'
@@ -162,9 +162,10 @@ possible. The `git rebase -i` command might be useful to this end.

 ### Linting and Type-checking

-JAX uses [mypy](https://mypy.readthedocs.io/) and [flake8](https://flake8.pycqa.org/)
-to statically test code quality; the easiest way to run these checks locally is via
-the [pre-commit](https://pre-commit.com/) framework:
+JAX uses [mypy](https://mypy.readthedocs.io/) and
+[ruff](https://docs.astral.sh/ruff/) to statically test code quality; the
+easiest way to run these checks locally is via the
+[pre-commit](https://pre-commit.com/) framework:

 ```bash
 pip install pre-commit
@@ -339,20 +339,20 @@ pre-commit run mypy

 ## Linting

-JAX uses the [flake8](https://flake8.pycqa.org/) linter to ensure code quality. You can check
-your local changes by running:
+JAX uses the [ruff](https://docs.astral.sh/ruff/) linter to ensure code
+quality. You can check your local changes by running:

 ```
-pip install flake8
-flake8 jax
+pip install ruff
+ruff jax
 ```

 Alternatively, you can use the [pre-commit](https://pre-commit.com/) framework to run this
-on all staged files in your git repository, automatically using the same flake8 version as
+on all staged files in your git repository, automatically using the same ruff version as
 the GitHub tests:

 ```
-pre-commit run flake8
+pre-commit run ruff
 ```

 ## Update documentation
@@ -55,8 +55,8 @@ from jax._src import pjit
 from jax._src import xla_bridge as xb
 from jax._src.core import eval_jaxpr, ShapedArray
 from jax._src.api_util import (
-flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
-argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
+flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
+argnums_partial_except, flatten_axes, donation_vector,
 rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
 shaped_abstractify, _ensure_str_tuple, apply_flat_fun_nokwargs,
 check_callable, debug_info, result_paths, flat_out_axes, debug_info_final)
@@ -1470,8 +1470,9 @@ def manual_proto(
 mesh_shape = list(named_mesh_shape.values())
 axis_order = {axis: i for i, axis in enumerate(mesh.axis_names)}

-manual_axes = list(sorted(manual_axes_set, key=str))
-replicated_axes = list(axis for axis in mesh.axis_names if axis not in manual_axes_set)
+manual_axes = sorted(manual_axes_set, key=str)
+replicated_axes = [axis for axis in mesh.axis_names
+if axis not in manual_axes_set]

 tad_perm = ([axis_order[a] for a in replicated_axes] +
 [axis_order[a] for a in manual_axes])
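The two replacements above follow a pattern that recurs throughout this change: dropping a redundant `list()` wrapper when the wrapped call already returns a list, and writing a list comprehension instead of `list()` around a generator expression. A minimal standalone sketch of the before/after idioms (illustrative only, not JAX code):

```python
# Illustration of the simplifications applied throughout this change (not JAX code).
names = {"b", "a", "c"}

# Before: sorted() already returns a list, so the outer list() is redundant.
ordered_old = list(sorted(names, key=str))
# After:
ordered_new = sorted(names, key=str)

# Before: list() wrapped around a generator expression.
upper_old = list(n.upper() for n in ordered_new)
# After: a list comprehension expresses the same thing directly.
upper_new = [n.upper() for n in ordered_new]

assert ordered_old == ordered_new
assert upper_old == upper_new
```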
@@ -123,7 +123,7 @@ def print_histogram(histogram: dict[Any, int]):
 count_width = max(len(str(v)) for v in histogram.values())
 count_fmt = '{:>' + str(count_width) + 'd}'
 pairs = [(v, k) for k, v in histogram.items()]
-for count, name in reversed(sorted(pairs)):
+for count, name in sorted(pairs, reverse=True):
 print(count_fmt.format(count), name)

@@ -870,7 +870,7 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm):

 full_perm = np.zeros((len(replica_groups), len(perm), 2), np.int64)
 for i, grp in enumerate(replica_groups):
-grp = list(sorted(grp))
+grp = sorted(grp)
 for j, (src, dst) in enumerate(perm):
 full_perm[i, j, 0] = grp[src]
 full_perm[i, j, 1] = grp[dst]
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa: F401
+# ruff: noqa: F401

 import jaxlib.mlir.ir as ir
 import jaxlib.mlir.passmanager as passmanager
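The `noqa` directive keeps its meaning under ruff: F401 flags imports that are never referenced in the module, and a file-level suppression is the usual way to mark modules whose imports exist purely for re-export, as in this hunk and the many similar ones below. A small hedged sketch of such a module (the re-exported names here are stand-ins, not the actual JAX imports):

```python
# ruff: noqa: F401
# File-level suppression: ruff skips the "imported but unused" check (F401)
# for this whole module, since these imports exist only to re-export names.
from os.path import join as join                    # re-exported, unused here
from collections import OrderedDict as OrderedDict  # re-exported, unused here
```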
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa: F401
+# ruff: noqa: F401
 import jaxlib.mlir.dialects.builtin as builtin
 import jaxlib.mlir.dialects.chlo as chlo
 import jaxlib.mlir.dialects.mhlo as mhlo
@@ -152,4 +152,4 @@ class NDIndexer:
 is_int_indexing = [not isinstance(i, Slice) for i in self.indices]
 other_indexers, _ = partition_list(is_int_indexing, self.indices)
 other_shape = [s.size for s in other_indexers] # type: ignore
-return tuple((*self.int_indexer_shape, *other_shape))
+return (*self.int_indexer_shape, *other_shape)
@@ -311,7 +311,7 @@ def _hoist_consts_to_refs(jaxpr: jax_core.Jaxpr) -> jax_core.Jaxpr:
 const_avals, const_ref_avals = partition_list(is_const_ref, all_const_avals)
 const_avals = map(state.AbstractRef, const_avals)
 merged_const_avals = merge_lists(is_const_ref, const_avals, const_ref_avals)
-arg_avals = list(var.aval for var in jaxpr.invars)
+arg_avals = [var.aval for var in jaxpr.invars]
 in_avals = [*merged_const_avals, *arg_avals]
 num_consts = len(merged_const_avals)

@@ -1306,9 +1306,9 @@ def _maybe_pattern_match_fori_loop(ctx: TritonLoweringRuleContext, *args,
 else:
 return None
 jaxpr = body_jaxpr.jaxpr
-new_invars = tuple((*jaxpr.invars[:body_nconsts],
-jaxpr.invars[body_nconsts],
-*jaxpr.invars[body_nconsts + 2:]))
+new_invars = (*jaxpr.invars[:body_nconsts],
+jaxpr.invars[body_nconsts],
+*jaxpr.invars[body_nconsts + 2:])
 new_outvars = tuple(jaxpr.outvars[2:])
 jaxpr = jaxpr.replace(
 eqns=jaxpr.eqns[:eqn_index] + jaxpr.eqns[eqn_index + 1:],
@@ -1551,7 +1551,7 @@ def pallas_call_lowering(
 print(grid_mapping)
 compilation_result = compile_jaxpr(
 jaxpr,
-tuple((*in_shapes, *out_shapes)),
+(*in_shapes, *out_shapes),
 grid_mapping,
 name,
 num_warps,
@@ -95,7 +95,7 @@ def dctn(x: Array, type: int = 2,
 return dct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)

 if s is not None:
-ns = {a: n for a, n in zip(axes, s)}
+ns = dict(zip(axes, s))
 pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
 x = lax.pad(x, jnp.array(0, x.dtype), pads)

@@ -153,7 +153,7 @@ def idctn(x: Array, type: int = 2,
 return idct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)

 if s is not None:
-ns = {a: n for a, n in zip(axes, s)}
+ns = dict(zip(axes, s))
 pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
 x = lax.pad(x, jnp.array(0, x.dtype), pads)

@@ -823,7 +823,7 @@ def _generate_key_paths_(
 # handle namedtuple as a special case, based on heuristic
 key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
 for k, c in key_children:
-yield from _generate_key_paths_(tuple((*key_path, k)), c, is_leaf)
+yield from _generate_key_paths_((*key_path, k), c, is_leaf)
 else:
 yield key_path, tree # strict leaf type

@@ -631,11 +631,11 @@ def backends() -> dict[str, xla_client.Client]:
 platform_registrations = list(
 zip(platforms, priorities, fail_quietly_list))
 else:
-platform_registrations = list(
+platform_registrations = [
 (platform, registration.priority, registration.fail_quietly)
 for platform, registration
 in _backend_factories.items()
-)
+]
 default_priority = -1000
 for platform, priority, fail_quietly in platform_registrations:
 try:
@@ -1185,7 +1185,7 @@ def all_dim_vars(args_avals: Sequence[core.AbstractValue]) -> Sequence[str]:
 for d in a.shape:
 if is_poly_dim(d):
 dim_vars = dim_vars.union(d.get_vars())
-return sorted(tuple(dim_vars))
+return sorted(dim_vars)


 class CachingShapeEvaluator:
@@ -3244,13 +3244,13 @@ def _eig(operand: TfVal, compute_left_eigenvectors: bool,
 "to True.")
 raise NotImplementedError(msg)
 elif not (compute_left_eigenvectors or compute_right_eigenvectors):
-return tuple([tf.linalg.eigvals(operand)])
+return (tf.linalg.eigvals(operand),)
 elif compute_right_eigenvectors:
 return tuple(tf.linalg.eig(operand))
 else: # compute_left_eigenvectors == True
 wH, vl = tf.linalg.eig(tf.linalg.adjoint(operand))
 wHH = tf.math.conj(wH)
-return tuple([wHH, vl])
+return (wHH, vl)


 tf_impl[lax.linalg.eig_p] = _eig
@@ -206,7 +206,7 @@ class CompatTestBase(jtu.JaxTestCase):

 custom_call_re = r"stablehlo.custom_call\s*@([^\(]+)\("
 current_custom_call_targets = sorted(
-list(set(re.findall(custom_call_re, module_str))))
+set(re.findall(custom_call_re, module_str)))

 np.set_printoptions(threshold=sys.maxsize, floatmode="unique")
 # Print the current test data to simplify updating the test.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, complex64
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, complex64
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, complex64
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, int32, float32, complex64
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, complex64
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, int32, float32, complex64
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, complex64
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, complex64
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, uint32
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, int32
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, int32
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, int32
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 import datetime
 from numpy import array, float32, int32
@@ -94,7 +94,7 @@ class DimExprTest(tf_test_util.JaxToTfTestCase):
 dim_vars_tuple = tuple(dim_vars)
 # All combinations of values
 for dim_values in itertools.product(*([(1, 2, 5, 10)] * len(dim_vars_tuple))):
-env = {d: dv for d, dv in zip(dim_vars_tuple, dim_values)}
+env = dict(zip(dim_vars_tuple, dim_values))
 def eval(d: shape_poly.DimSize):
 return d.evaluate(env) if core.is_symbolic_dim(d) else d # type: ignore

@@ -358,7 +358,7 @@ def create_hybrid_device_mesh(mesh_shape: Sequence[int],
 granule_dict = collections.defaultdict(list)
 for dev in devices:
 granule_dict[getattr(dev, attr)].append(dev)
-granules = list(granule_dict[key] for key in sorted(granule_dict.keys()))
+granules = [granule_dict[key] for key in sorted(granule_dict.keys())]
 if np.prod(dcn_mesh_shape) != len(granules):
 raise ValueError(
 f'Number of slices {len(granules)} must equal the product of '
@@ -243,9 +243,9 @@ def host_local_array_to_global_array_impl(
 arrays = [x.data for x in arr.addressable_shards]
 else:
 arr = xla.canonicalize_dtype(arr)
-arrays = list(
+arrays = [
 arr[index]
-for d, index in local_sharding.devices_indices_map(arr.shape).items())
+for d, index in local_sharding.devices_indices_map(arr.shape).items()]

 global_aval = _local_to_global_aval(
 core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)
@@ -350,9 +350,9 @@ def global_array_to_host_local_array_impl(
 else:
 # numpy array can show up here during AD.
 arr = xla.canonicalize_dtype(arr)
-arrays = list(
+arrays = [
 arr[index]
-for d, index in local_sharding.devices_indices_map(arr.shape).items())
+for d, index in local_sharding.devices_indices_map(arr.shape).items()]
 return pxla.batched_device_put(
 local_aval, local_sharding, arrays,
 list(global_mesh.local_mesh.devices.flat))
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa
+# ruff: noqa

 from jax._src.pjit import (
 pjit as pjit,
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa: F401
+# ruff: noqa: F401
 from jax._src.lib import (
 version_str as __version__,
 xla_client as xla_client,
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa: F401
+# ruff: noqa: F401
 from jax._src.xla_bridge import (
 default_backend as default_backend,
 get_backend as get_backend,
@@ -14,8 +14,8 @@

 """Python bindings for the MLIR TPU dialect."""

-# flake8: noqa: F401
-# flake8: noqa: F403
+# ruff: noqa: F401
+# ruff: noqa: F403


 # pylint: disable=g-bad-import-order
@@ -32,7 +32,7 @@ _cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python")


 @_cext.register_operation(_Dialect, replace=True)
-class TraceOp(TraceOp):
+class TraceOp(TraceOp): # noqa: F405
 """An extension to the automatically generated TraceOp bindings."""

 def __init__(self, results, message, level, *, loc=None, ip=None):
@@ -45,7 +45,7 @@ class TraceOp(TraceOp):


 @_cext.register_operation(_Dialect, replace=True)
-class RegionOp(RegionOp):
+class RegionOp(RegionOp): # noqa: F405
 """An extension to the automatically generated RegionOp bindings."""

 def __init__(self, *, loc=None, ip=None):
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# flake8: noqa: F401
+# ruff: noqa: F401

 from .mosaic.python import apply_vector_layout # pytype: disable=import-error
 from .mosaic.python import infer_memref_layout # pytype: disable=import-error
@@ -105,3 +105,85 @@ enable = "c-extension-no-member"

 [tool.pylint.format]
 indent-string=" "
+
+[tool.ruff]
+preview = true
+exclude = [
+".git",
+"build",
+"__pycache__",
+]
+ignore = [
+# Unnecessary collection call
+"C408",
+# Unnecessary map usage
+"C417",
+# Object names too complex
+"C901",
+# Local variable is assigned to but never used
+"F841",
+# Raise with from clause inside except block
+"B904",
+]
+line-length = 88
+indent-width = 2
+select = [
+"B9",
+"C",
+"F",
+"W",
+"YTT",
+"ASYNC",
+"E225",
+"E227",
+"E228",
+]
+target-version = "py39"
+
+[tool.ruff.mccabe]
+max-complexity = 18
+
+[tool.ruff.per-file-ignores]
+# F811: Redefinition of unused name.
+"docs/autodidax.py" = ["F811"]
+# Note: we don't use jax/*.py because this matches contents of jax/_src
+"__init__.py" = ["F401"]
+"jax/abstract_arrays.py" = ["F401"]
+"jax/ad_checkpoint.py" = ["F401"]
+"jax/api_util.py" = ["F401"]
+"jax/cloud_tpu_init.py" = ["F401"]
+"jax/core.py" = ["F401"]
+"jax/custom_batching.py" = ["F401"]
+"jax/custom_derivatives.py" = ["F401"]
+"jax/custom_transpose.py" = ["F401"]
+"jax/debug.py" = ["F401"]
+"jax/distributed.py" = ["F401"]
+"jax/dlpack.py" = ["F401"]
+"jax/dtypes.py" = ["F401"]
+"jax/errors.py" = ["F401"]
+"jax/experimental/*.py" = ["F401"]
+"jax/extend/*.py" = ["F401"]
+"jax/flatten_util.py" = ["F401"]
+"jax/interpreters/ad.py" = ["F401"]
+"jax/interpreters/batching.py" = ["F401"]
+"jax/interpreters/mlir.py" = ["F401"]
+"jax/interpreters/partial_eval.py" = ["F401"]
+"jax/interpreters/pxla.py" = ["F401"]
+"jax/interpreters/xla.py" = ["F401"]
+"jax/lax/*.py" = ["F401"]
+"jax/linear_util.py" = ["F401"]
+"jax/monitoring.py" = ["F401"]
+"jax/nn/*.py" = ["F401"]
+"jax/numpy/*.py" = ["F401"]
+"jax/prng.py" = ["F401"]
+"jax/profiler.py" = ["F401"]
+"jax/random.py" = ["F401"]
+"jax/scipy/*.py" = ["F401"]
+"jax/sharding.py" = ["F401"]
+"jax/stages.py" = ["F401"]
+"jax/test_util.py" = ["F401"]
+"jax/tree_util.py" = ["F401"]
+"jax/typing.py" = ["F401"]
+"jax/util.py" = ["F401"]
+# F821: Undefined name.
+"jax/numpy/__init__.pyi" = ["F821"]
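For reference, E225 ("missing whitespace around operator") is one of the pycodestyle checks carried over from the old flake8 `select` list into the ruff configuration above. A small illustrative sketch of what it catches (hypothetical code, not from this repository), using the two-space indent the config sets:

```python
# Flagged: no whitespace around the assignment operator (E225).
def scale_bad(x):
  y=x + 1
  return y

# Compliant form.
def scale_ok(x):
  y = x + 1
  return y
```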
setup.cfg (deleted)
@@ -1,63 +0,0 @@
-[flake8]
-max-line-length = 88
-ignore =
-# object names too complex
-C901
-# four-space indents
-E111, E114
-# line continuations
-E121
-# line breaks around binary operators
-W503, W504
-max-complexity = 18
-select = B,C,F,W,T4,B9,E225,E227,E228
-exclude =
-.git,
-build,
-__pycache__
-per-file-ignores =
-# F811: redefinition of unused name
-docs/autodidax.py:F811
-# F401: unused imports
-# Note: we don't use jax/*.py because this matches contents of jax/_src
-__init__.py:F401
-jax/abstract_arrays.py:F401
-jax/ad_checkpoint.py:F401
-jax/api_util.py:F401
-jax/cloud_tpu_init.py:F401
-jax/core.py:F401
-jax/custom_batching.py:F401
-jax/custom_derivatives.py:F401
-jax/custom_transpose.py:F401
-jax/debug.py:F401
-jax/distributed.py:F401
-jax/dlpack.py:F401
-jax/dtypes.py:F401
-jax/errors.py:F401
-jax/extend/*.py:F401
-jax/flatten_util.py:F401
-jax/interpreters/ad.py:F401
-jax/interpreters/batching.py:F401
-jax/interpreters/mlir.py:F401
-jax/interpreters/partial_eval.py:F401
-jax/interpreters/pxla.py:F401
-jax/interpreters/xla.py:F401
-jax/linear_util.py:F401
-jax/monitoring.py:F401
-jax/prng.py:F401
-jax/profiler.py:F401
-jax/random.py:F401
-jax/sharding.py:F401
-jax/stages.py:F401
-jax/test_util.py:F401
-jax/tree_util.py:F401
-jax/typing.py:F401
-jax/util.py:F401
-jax/_src/api.py:F401
-jax/_src/numpy/lax_numpy.py:F401
-jax/_src/typing.py:F401
-jax/experimental/*.py:F401
-jax/lax/*.py:F401
-jax/nn/*.py:F401
-jax/numpy/*.py:F401
-jax/scipy/*.py:F401
@@ -50,11 +50,10 @@ def compute_recall(result_neighbors, ground_truth_neighbors) -> float:
 ) == 2, "shape = [num_queries, ground_truth_neighbors_per_query]"
 assert result_neighbors.shape[0] == ground_truth_neighbors.shape[0]
 gt_sets = [set(np.asarray(x)) for x in ground_truth_neighbors]
-hits = sum(
-len(list(x
-for x in nn_per_q
-if x.item() in gt_sets[q]))
-for q, nn_per_q in enumerate(result_neighbors))
+hits = sum(len([x
+for x in nn_per_q
+if x.item() in gt_sets[q]])
+for q, nn_per_q in enumerate(result_neighbors))
 return hits / ground_truth_neighbors.size

@@ -62,7 +62,7 @@ class SparseArray:
 return self.data.shape[0]

 def __repr__(self):
-return repr(list((tuple(ind), d) for ind, d in zip(self.indices, self.data)))
+return repr([(tuple(ind), d) for ind, d in zip(self.indices, self.data)])


 class AbstractSparseArray(core.ShapedArray):
@@ -83,7 +83,7 @@ for effect in _testing_effects.values():
 # and just doubles its argument.
 testing_primitive_with_effect_p = core.Primitive("testing_primitive_with_effect")
 testing_primitive_with_effect_p.def_effectful_abstract_eval(
-lambda aval, *x, effect_class_name: (aval, set([_testing_effects[effect_class_name]])))
+lambda aval, *x, effect_class_name: (aval, {_testing_effects[effect_class_name]}))

 def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str):
 if "Ordered" in effect_class_name:
@@ -1125,7 +1125,7 @@ class ScipyLinalgTest(jtu.JaxTestCase):
 def testLuBatching(self, shape, dtype):
 rng = jtu.rand_default(self.rng())
 args = [rng(shape, jnp.float32) for _ in range(10)]
-expected = list(osp.linalg.lu(x) for x in args)
+expected = [osp.linalg.lu(x) for x in args]
 ps = np.stack([out[0] for out in expected])
 ls = np.stack([out[1] for out in expected])
 us = np.stack([out[2] for out in expected])
@@ -652,7 +652,7 @@ class TreeTest(jtu.JaxTestCase):
 x = ((1, 2), [3, 4, 5])
 y = (([3], jnp.array(0)), ([0], 7, [5, 6]))
 out = tree_util.tree_map_with_path(
-lambda kp, *xs: tuple((kp[0].idx, *xs)), x, y,
+lambda kp, *xs: (kp[0].idx, *xs), x, y,
 is_leaf=lambda n: isinstance(n, list))
 self.assertEqual(out, (((0, 1, [3]),
 (0, 2, jnp.array(0))),