24 Commits

Author SHA1 Message Date
Peter Hawkins
14cb7453f0 Add a C++ implementation of a toplogical sort.
This is an exact port of the current Python implementation to C++ for speed.

I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.

PiperOrigin-RevId: 737014989
2025-03-14 16:04:25 -07:00
George Necula
817b3e5757 [better_errors] Continue adding debug info to Jaxprs (step 7)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
2025-02-09 18:14:33 +02:00
Dougal
1c9b23c566 Stop using generators for linear_util transformations.
They lead to confusing code, nasty bugs, and unhelpful (but terse!) stack traces.
2024-11-13 13:47:07 -08:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Jake VanderPlas
b9c7b9bb4f Remove obsolete jaxlib version checks 2023-07-12 11:53:55 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Peter Hawkins
74384e6a87 Add a C++ safe_zip implementation.
Benchmark results on my workstation:
```
name                                 old cpu/op   new cpu/op   delta
safe_zip/arg_lengths:0/num_args:1    1.22µs ± 1%  0.28µs ± 8%  -77.33%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:1    1.28µs ± 1%  0.34µs ± 6%  -73.18%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:1    1.28µs ± 1%  0.38µs ± 5%  -70.26%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:1    1.38µs ± 1%  0.51µs ± 3%  -63.26%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:1   1.61µs ± 1%  0.69µs ± 3%  -56.93%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:1  5.39µs ± 1%  3.83µs ± 2%  -29.03%  (p=0.008 n=5+5)
safe_zip/arg_lengths:0/num_args:2    1.46µs ± 1%  0.32µs ± 4%  -78.30%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:2    1.52µs ± 1%  0.39µs ± 4%  -74.20%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:2    1.53µs ± 1%  0.44µs ± 4%  -71.38%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:2    1.66µs ± 2%  0.60µs ± 3%  -63.96%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:2   1.90µs ± 1%  0.82µs ± 3%  -56.66%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:2  6.51µs ± 1%  4.80µs ± 0%  -26.23%  (p=0.016 n=5+4)
safe_zip/arg_lengths:0/num_args:3    1.62µs ± 1%  0.36µs ± 4%  -77.95%  (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:3    1.68µs ± 1%  0.44µs ± 3%  -73.75%  (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:3    1.69µs ± 1%  0.50µs ± 3%  -70.48%  (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:3    1.83µs ± 1%  0.68µs ± 2%  -62.73%  (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:3   2.12µs ± 1%  0.96µs ± 1%  -54.71%  (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:3  7.34µs ± 2%  5.89µs ± 1%  -19.74%  (p=0.008 n=5+5)
```

In addition, improve the length mismatch error for safe_map and define __module__ on both functions.

PiperOrigin-RevId: 523475834
2023-04-11 12:43:04 -07:00
Matthew Johnson
08ca4ed34d update skipping logic 2023-04-10 21:22:49 -07:00
Peter Hawkins
0dbd467cea Add a C++ implementation of safe map.
Before (argument names reversed, oops, fixed in code):

```
name                                 time/op
safe_map/num_args:0/arg_lengths:1    1.43µs ± 1%
safe_map/num_args:1/arg_lengths:1    1.61µs ± 1%
safe_map/num_args:2/arg_lengths:1    1.72µs ± 0%
safe_map/num_args:5/arg_lengths:1    2.14µs ± 1%
safe_map/num_args:10/arg_lengths:1   2.87µs ± 1%
safe_map/num_args:100/arg_lengths:1  15.6µs ± 1%
safe_map/num_args:0/arg_lengths:2    1.65µs ± 0%
safe_map/num_args:1/arg_lengths:2    1.83µs ± 1%
safe_map/num_args:2/arg_lengths:2    1.97µs ± 1%
safe_map/num_args:5/arg_lengths:2    2.41µs ± 1%
safe_map/num_args:10/arg_lengths:2   3.22µs ± 2%
safe_map/num_args:100/arg_lengths:2  17.0µs ± 2%
safe_map/num_args:0/arg_lengths:3    1.83µs ± 1%
safe_map/num_args:1/arg_lengths:3    2.02µs ± 1%
safe_map/num_args:2/arg_lengths:3    2.16µs ± 1%
safe_map/num_args:5/arg_lengths:3    2.63µs ± 1%
safe_map/num_args:10/arg_lengths:3   3.48µs ± 1%
safe_map/num_args:100/arg_lengths:3  18.1µs ± 1%
```

After:
```
name                                 time/op
safe_map/num_args:0/arg_lengths:1     409ns ± 1%
safe_map/num_args:1/arg_lengths:1     602ns ± 5%
safe_map/num_args:2/arg_lengths:1     777ns ± 4%
safe_map/num_args:5/arg_lengths:1    1.21µs ± 3%
safe_map/num_args:10/arg_lengths:1   1.93µs ± 2%
safe_map/num_args:100/arg_lengths:1  14.7µs ± 0%
safe_map/num_args:0/arg_lengths:2     451ns ± 1%
safe_map/num_args:1/arg_lengths:2     652ns ± 0%
safe_map/num_args:2/arg_lengths:2     850ns ± 4%
safe_map/num_args:5/arg_lengths:2    1.32µs ± 3%
safe_map/num_args:10/arg_lengths:2   2.11µs ± 2%
safe_map/num_args:100/arg_lengths:2  16.0µs ± 1%
safe_map/num_args:0/arg_lengths:3     496ns ± 1%
safe_map/num_args:1/arg_lengths:3     718ns ± 5%
safe_map/num_args:2/arg_lengths:3     919ns ± 4%
safe_map/num_args:5/arg_lengths:3    1.43µs ± 2%
safe_map/num_args:10/arg_lengths:3   2.30µs ± 2%
safe_map/num_args:100/arg_lengths:3  17.3µs ± 1%
```
PiperOrigin-RevId: 523263207
2023-04-10 18:09:56 -07:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Parker Schuh
35a5012eea Fix and add test for weakref_lru_cache asan issue.
PiperOrigin-RevId: 474684516
2022-09-15 16:31:14 -07:00
Parker Schuh
df1c478ec5 Fix race condition for weakref destructor by catching rare exceptions. 2022-04-01 12:04:36 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Jake Vanderplas
9ee4ef1107
Cleanup: de-lint tests directory & add flake8 to travis (#3304)
* Cleanup: fix lint errors in tests/*.py

* Add flake8 step to travis

* add setup.cfg
2020-06-02 19:25:47 -07:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests (#3276) 2020-06-01 11:49:35 -07:00
Peter Hawkins
b543652332
Replace np -> jnp, onp -> np in tests. (#2969) 2020-05-05 14:59:16 -04:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
George Necula
4b2a5a1f1b Attempt a fix for PY2 2020-01-05 10:12:31 -08:00
George Necula
df374fa3a2 Removed unused imports 2020-01-05 16:46:29 +01:00
George Necula
528a69f32e Added some more documentation to the linear_util module
Also cleaned up the inconsistent way of importing the module.
Prefer importing with qualified name 'lu.transformation' rather
than just 'transformation'.
2020-01-05 16:40:26 +01:00