Move flake8 & mypy checks to pre-commit

This commit is contained in:
Jake VanderPlas 2021-02-10 13:44:40 -08:00
parent 97f249b101
commit e159d67e7e
6 changed files with 22 additions and 45 deletions

View File

@ -11,50 +11,16 @@ on:
- master
jobs:
pre-commit:
lint_and_typecheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v1
uses: actions/setup-python@v2
with:
python-version: 3.8
- uses: pre-commit/action@v2.0.0
lint_and_typecheck:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.7]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
uses: actions/cache@v2
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
pip install -r build/test-requirements.txt
- name: Lint with flake8
run: |
flake8 .
- name: Type check with mypy
run: |
mypy jax
build:
name: "build ${{ matrix.name-prefix }} (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"
runs-on: ${{ matrix.os }}

View File

@ -8,8 +8,19 @@
# 'pre-commit run --all'
repos:
- repo: https://github.com/mwouts/jupytext
rev: v1.10.0
hooks:
- id: jupytext
args: [--sync]
- repo: https://gitlab.com/pycqa/flake8
rev: '3.8.4'
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.800'
hooks:
- id: mypy
files: jax/
- repo: https://github.com/mwouts/jupytext
rev: v1.10.0
hooks:
- id: jupytext
args: [--sync]

View File

@ -186,7 +186,7 @@ def build_wheel(sources_path, output_path):
for wheel in glob.glob(os.path.join(sources_path, "dist", "*.whl")):
output_file = os.path.join(output_path, os.path.basename(wheel))
sys.stderr.write(f"Output wheel: {output_file}\n\n")
sys.stderr.write(f"To install the newly-built jaxlib wheel, run:\n")
sys.stderr.write("To install the newly-built jaxlib wheel, run:\n")
sys.stderr.write(f" pip install {output_file}\n\n")
shutil.copy(wheel, output_path)

View File

@ -82,7 +82,7 @@ def canonicalize_dtype(dtype):
# Default dtypes corresponding to Python scalars.
python_scalar_dtypes = {
python_scalar_dtypes : dict = {
bool: np.dtype(bool_),
int: np.dtype(int_),
float: np.dtype(float_),

View File

@ -748,7 +748,7 @@ def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
remat_call_p = core.CallPrimitive('remat_call')
remat_call_p: core.Primitive = core.CallPrimitive('remat_call')
remat_call = remat_call_p.bind
remat_call_p.def_impl(core.call_impl)

View File

@ -235,7 +235,7 @@ def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
# TODO(mattjj): remove the Python-side registry when the C++-side registry is
# sufficiently queryable that we can express _replace_nones. That may mean once
# we have a flatten_one function.
_RegistryEntry = collections.namedtuple("RegistryEntry", ["to_iter", "from_iter"])
_RegistryEntry = collections.namedtuple("_RegistryEntry", ["to_iter", "from_iter"])
_registry = {
tuple: _RegistryEntry(lambda xs: (xs, None), lambda _, xs: tuple(xs)),
list: _RegistryEntry(lambda xs: (xs, None), lambda _, xs: list(xs)),