The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008.
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:
Migrate more tests from jtu.cases_from_list to jtu.sample_product.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
Context: If an AssertionError is thrown inside a test and traceback filtering
is enabled, most of the stack-trace is swallowed (due to
https://bugs.python.org/issue24959).
PiperOrigin-RevId: 428729211
[jaxlib] allow empty traceback overwrites
If an error is raised within JAX (under an API boundary frame), but prior to entering any user code, then all frames in between are JAX-internal. In this case, our filtered traceback ought to be trivial, i.e. empty of any frames at all.
Prior to this change, we did not handle this edge case consistently with the non-trivial case: any trivial filtered traceback was modified to comprise a single JAX-internal frame (namely, the inner-most one). With this change, the filtered traceback can be completely empty and result in omission of all JAX-internal frames.
Before:
```
Traceback (most recent call last):
File "tb.py", line 10, in <module>
jit(f)(A())
File "jax/_src/api.py", line 2850, in _check_arg
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
TypeError: Argument ... is not a valid JAX type.
```
After:
```
Traceback (most recent call last):
File "tb.py", line 10, in <module>
jit(f)(A())
TypeError: Argument ... is not a valid JAX type.
```
PiperOrigin-RevId: 422962976
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...
PiperOrigin-RevId: 400858477
After this change, the filtered stack trace is attached to the main exception, and the unfiltered stack trace becomes a __cause__ exception.
PiperOrigin-RevId: 371509766
This extra source info is still only on jaxpr staging tracers, but those
seem to be the most common culprits. I moved the `_line_info` attribute
to the base Tracer class in core.py in anticipation of populating it for
more traces than just DynamicJaxprTrace, but I'll leave that extension
to follow-up.
I adapted the main escaped tracer error messages in core.py, and also
slightly generalized and debugged source_info_util functions (thanks for
explaining the path prefix bug, @froystig !).
1. Build on Windows
2. Fix OverflowError
When calling `key = random.PRNGKey(0)` OverflowError: Python int too
large to convert to C long for casting value 4294967295 (0xFFFFFFFF)
from python int to int32.
3. fix file path in regex of errors_test
4. handle ValueError of os.path.commonpath