mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
CI: error if docstring rewrite fails
This commit is contained in:
parent
f539c9b9bd
commit
f2222bb1cf
2
.github/workflows/ci-build.yaml
vendored
2
.github/workflows/ci-build.yaml
vendored
@ -88,10 +88,12 @@ jobs:
|
||||
env:
|
||||
JAX_NUM_GENERATED_CASES: ${{ matrix.num_generated_cases }}
|
||||
JAX_ENABLE_X64: ${{ matrix.enable-x64 }}
|
||||
JAX_ENABLE_CHECKS: true
|
||||
run: |
|
||||
pip install -e .
|
||||
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
|
||||
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
|
||||
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
|
||||
pytest -n auto --tb=short tests examples
|
||||
|
||||
|
||||
|
@ -17,6 +17,8 @@ import re
|
||||
import textwrap
|
||||
from typing import Callable, NamedTuple, Optional, Dict, Sequence
|
||||
|
||||
from jax._src.config import config
|
||||
|
||||
_parameter_break = re.compile("\n(?=[A-Za-z_])")
|
||||
_section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE)
|
||||
_numpy_signature_re = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\([\w\W]*?\)$', re.MULTILINE)
|
||||
@ -151,6 +153,8 @@ def _wraps(fun: Optional[Callable], update_doc: bool = True, lax_description: st
|
||||
if kept_sections:
|
||||
docstr += "\n" + "\n\n".join(kept_sections) + "\n"
|
||||
except:
|
||||
if config.jax_enable_checks:
|
||||
raise
|
||||
docstr = fun.__doc__
|
||||
|
||||
op.__doc__ = docstr
|
||||
|
Loading…
x
Reference in New Issue
Block a user