36 Commits

Author SHA1 Message Date
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Sergei Lebedev
65d3058944 Migrate a subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
Patrick Kidger
5e276d0935 Tracebacks no longer have JAX-internal frames prepended by default 2023-08-03 11:38:38 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
7fd1e2ff47 Split _src/traceback_util.py into its own Bazel target.
Improve its type annotations.

PiperOrigin-RevId: 515376365
2023-03-09 10:33:47 -08:00
Peter Hawkins
bd2500579a Change definition of util.wraps so pytype can understand it.
@curry is opaque to pytype.

Fix a false positive type error that turns up because pytype doesn't really understand that a functools.partial is a kind of Callable.

PiperOrigin-RevId: 513697380
2023-03-02 18:41:52 -08:00
Roy Frostig
c241ae60b1 add blank line, mainly to trigger/test source sync
PiperOrigin-RevId: 506414439
2023-02-01 13:56:29 -08:00
Peter Hawkins
cd84eb10a6 Add a number of missing function cross-references in the docs. 2022-11-07 12:00:26 -05:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
23f6ef4e6b Fix Python 3.11 compatibility problems.
Also needs https://github.com/tensorflow/tensorflow/pull/57085
2022-08-10 19:45:24 +00:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00
Roy Frostig
312a33e31b [jax] completely truncate trivial filtered tracebacks
[jaxlib] allow empty traceback overwrites

If an error is raised within JAX (under an API boundary frame), but prior to entering any user code, then all frames in between are JAX-internal. In this case, our filtered traceback ought to be trivial, i.e. empty of any frames at all.

Prior to this change, we did not handle this edge case consistently with the non-trivial case: any trivial filtered traceback was modified to comprise a single JAX-internal frame (namely, the inner-most one). With this change, the filtered traceback can be completely empty and result in omission of all JAX-internal frames.

Before:

```
Traceback (most recent call last):
  File "tb.py", line 10, in <module>
    jit(f)(A())
  File "jax/_src/api.py", line 2850, in _check_arg
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
TypeError: Argument ... is not a valid JAX type.
```

After:

```
Traceback (most recent call last):
  File "tb.py", line 10, in <module>
    jit(f)(A())
TypeError: Argument ... is not a valid JAX type.
```
PiperOrigin-RevId: 422962976
2022-01-19 19:42:02 -08:00
Neil Girdhar
9f46c3c92a Annotate api_boundary 2021-12-22 00:19:31 -05:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
1c9dbd12c4 Remove Python 3.6 compatibility code. 2021-07-29 09:09:02 -04:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
2882286b50 Add a --jax_traceback_filtering flag to control the traceback filtering mode.
Add a new traceback filtering mode that uses __tracebackhide__, and use it in IPython.
2021-06-02 16:25:37 -04:00
Peter Hawkins
2c92bc9155 [JAX] Include pre-transformed stack traces as additional context to JAX exceptions, where present.
PiperOrigin-RevId: 371695248
2021-05-03 07:48:47 -07:00
Peter Hawkins
89a6b6d1a0 Fix stack trace filtering for paths that do not exist.
PiperOrigin-RevId: 371521548
2021-05-01 16:13:07 -07:00
Peter Hawkins
a80aab0a38 Fix cloudpickle breakage.
PiperOrigin-RevId: 371517484
2021-05-01 14:57:41 -07:00
Peter Hawkins
e8c340623c [JAX] Switch the order of the filtered and unfiltered stack traces in exceptions.
After this change, the filtered stack trace is attached to the main exception, and the unfiltered stack trace becomes a __cause__ exception.

PiperOrigin-RevId: 371509766
2021-05-01 12:41:37 -07:00
Peter Hawkins
ccc423d680 [JAX] Enable filtered tracebacks on Python 3.6.
[XLA:Python] Add support for converting a fast-traceback into a Python exception traceback.

Add a helper for building traceback objects on Python 3.6. On Python 3.7+ this can be done by calling the traceback type, and we can in essence backport that implementation to Python 3.6.

Consolidate the py_traceback and traceback modules.

PiperOrigin-RevId: 371193212
2021-04-29 13:37:21 -07:00
Roy Frostig
2fc2ff409a tiny change for source sync
PiperOrigin-RevId: 361249206
2021-03-05 16:32:28 -08:00
Roy Frostig
306fb0bf3d Merge pull request #5958 from jakevdp:glossary
PiperOrigin-RevId: 361240568
2021-03-05 15:54:01 -08:00
Roy Frostig
1c20e70839 test source sync
PiperOrigin-RevId: 359687592
2021-02-25 22:07:03 -08:00
Matthew Johnson
eb61395e6b Merge pull request #5823 from jakevdp:jax-101
PiperOrigin-RevId: 359679823
2021-02-25 21:08:04 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Roy Frostig
897916ffc4 test source sync
PiperOrigin-RevId: 350663594
2021-01-07 16:45:11 -08:00
Roy Frostig
432ef31342 tiny change for source sync
PiperOrigin-RevId: 350379734
2021-01-06 10:23:21 -08:00
Roy Frostig
c699329c54 tiny change to test CI
PiperOrigin-RevId: 348870843
2020-12-23 17:58:31 -08:00
Roy Frostig
6cddc2fa77 tiny change for source sync
PiperOrigin-RevId: 348839503
2020-12-23 13:24:58 -08:00
Cloud Han
a6acce58e0 Build on Windows
1. Build on Windows

2. Fix OverflowError

    When calling `key = random.PRNGKey(0)` OverflowError: Python int too
    large to convert to C long for casting value 4294967295 (0xFFFFFFFF)
    from python int to int32.

3. fix file path in regex of errors_test

4. handle ValueError of os.path.commonpath
2020-11-19 23:33:06 +08:00
Peter Hawkins
81b6cd29ff [JAX] Move traceback_util.py into jax._src.
traceback_util is a JAX-private API.

PiperOrigin-RevId: 340659195
2020-11-04 09:02:59 -08:00