33 Commits

Author SHA1 Message Date
Michael Hudgins
2e808f2836 Merge pull request #26279 from MichaelHudgins:tsan-resultstore
PiperOrigin-RevId: 723918760
2025-02-06 14:55:57 +00:00
Peter Hawkins
7b53c2f39d Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Patrick Kidger
5e276d0935 Tracebacks no longer have JAX-internal frames prepended by default 2023-08-03 11:38:38 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Peter Hawkins
2ba0396ddb Add changes accidentally omitted from
https://github.com/google/jax/pull/12717
2022-10-10 19:11:58 +00:00
Peter Hawkins
c657449528 Copybara import of the project:
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:

Migrate more tests from jtu.cases_from_list to jtu.sample_product.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
78cb9f8492 Avoid more direct references to jax._src without imports.
Change in preparation for not exporting jax._src by default.

PiperOrigin-RevId: 469725340
2022-08-24 07:51:28 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00
Lena Martens
0d9990e4f3 Run all tests with jax_traceback_filtering=off.
Context: If an AssertionError is thrown inside a test and traceback filtering
is enabled, most of the stack-trace is swallowed (due to
https://bugs.python.org/issue24959).
PiperOrigin-RevId: 428729211
2022-02-15 02:42:58 -08:00
Roy Frostig
312a33e31b [jax] completely truncate trivial filtered tracebacks
[jaxlib] allow empty traceback overwrites

If an error is raised within JAX (under an API boundary frame), but prior to entering any user code, then all frames in between are JAX-internal. In this case, our filtered traceback ought to be trivial, i.e. empty of any frames at all.

Prior to this change, we did not handle this edge case consistently with the non-trivial case: any trivial filtered traceback was modified to comprise a single JAX-internal frame (namely, the inner-most one). With this change, the filtered traceback can be completely empty and result in omission of all JAX-internal frames.

Before:

```
Traceback (most recent call last):
  File "tb.py", line 10, in <module>
    jit(f)(A())
  File "jax/_src/api.py", line 2850, in _check_arg
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
TypeError: Argument ... is not a valid JAX type.
```

After:

```
Traceback (most recent call last):
  File "tb.py", line 10, in <module>
    jit(f)(A())
TypeError: Argument ... is not a valid JAX type.
```
PiperOrigin-RevId: 422962976
2022-01-19 19:42:02 -08:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Roy Frostig
d3615ef279 mark custom derivative calls as API boundaries for stack trace filtering 2021-09-21 16:47:27 -07:00
Peter Hawkins
1c9dbd12c4 Remove Python 3.6 compatibility code. 2021-07-29 09:09:02 -04:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
2882286b50 Add a --jax_traceback_filtering flag to control the traceback filtering mode.
Add a new traceback filtering mode that uses __tracebackhide__, and use it in IPython.
2021-06-02 16:25:37 -04:00
Peter Hawkins
2c92bc9155 [JAX] Include pre-transformed stack traces as additional context to JAX exceptions, where present.
PiperOrigin-RevId: 371695248
2021-05-03 07:48:47 -07:00
Peter Hawkins
a80aab0a38 Fix cloudpickle breakage.
PiperOrigin-RevId: 371517484
2021-05-01 14:57:41 -07:00
Peter Hawkins
e8c340623c [JAX] Switch the order of the filtered and unfiltered stack traces in exceptions.
After this change, the filtered stack trace is attached to the main exception, and the unfiltered stack trace becomes a __cause__ exception.

PiperOrigin-RevId: 371509766
2021-05-01 12:41:37 -07:00
Jake VanderPlas
0796bfe6e7 errors: add NonConcreteBooleanIndexError & debugging tips 2021-03-23 11:23:20 -07:00
Jake VanderPlas
e9195ba626 Fix URL in custom errors 2021-03-16 09:10:10 -07:00
Roy Frostig
1283a9654b mark lax higher-order functions for stack trace filtering 2021-02-24 21:16:20 -08:00
Matthew Johnson
886b26ffeb add source line info to more escaped tracer errors
This extra source info is still only on jaxpr staging tracers, but those
seem to be the most common culprits. I moved the `_line_info` attribute
to the base Tracer class in core.py in anticipation of populating it for
more traces than just DynamicJaxprTrace, but I'll leave that extension
to follow-up.

I adapted the main escaped tracer error messages in core.py, and also
slightly generalized and debugged source_info_util functions (thanks for
explaining the path prefix bug, @froystig !).
2021-01-18 19:00:04 -08:00
Cloud Han
a6acce58e0 Build on Windows
1. Build on Windows

2. Fix OverflowError

    When calling `key = random.PRNGKey(0)` OverflowError: Python int too
    large to convert to C long for casting value 4294967295 (0xFFFFFFFF)
    from python int to int32.

3. fix file path in regex of errors_test

4. handle ValueError of os.path.commonpath
2020-11-19 23:33:06 +08:00
Peter Hawkins
13db2c3742 Fix mypy error caused by cyclic import dependency.
Add tests that source_info_util is using the same directory as the jax root module.
2020-11-18 10:20:33 -05:00
Peter Hawkins
81b6cd29ff [JAX] Move traceback_util.py into jax._src.
traceback_util is a JAX-private API.

PiperOrigin-RevId: 340659195
2020-11-04 09:02:59 -08:00
Roy Frostig
dbca9e682c unrevert #3674 (revert #3791) 2020-08-17 18:13:58 -07:00
Roy Frostig
fa2a0275c8 revert #3674 2020-07-17 15:44:51 -07:00
Roy Frostig
6416ca0e9d append filtered stack traces to error messages raised under transformations 2020-07-16 17:12:09 -07:00