Peter Hawkins
14cb7453f0
Add a C++ implementation of a toplogical sort.
...
This is an exact port of the current Python implementation to C++ for speed.
I am being careful not to change the topological order we return in any way in this change, although we may do so in a future change.
PiperOrigin-RevId: 737014989
2025-03-14 16:04:25 -07:00
George Necula
817b3e5757
[better_errors] Continue adding debug info to Jaxprs (step 7)
...
This follows in a series, starting with #26078 and #26313 , adding debug_info to more calls to lu.wrap_init.
Fixes in jet, stateful code, key_reuse, ode, pallas, tests.
2025-02-09 18:14:33 +02:00
Dougal
1c9b23c566
Stop using generators for linear_util transformations.
...
They lead to confusing code, nasty bugs, and unhelpful (but terse!) stack traces.
2024-11-13 13:47:07 -08:00
Jake VanderPlas
f090074d86
Avoid 'from jax import config' imports
...
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Sergei Lebedev
cbcaac2756
MAINT Migrate remaining internal/test modules to use state objects
...
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008 .
2023-10-12 17:32:15 +01:00
Jake VanderPlas
b9c7b9bb4f
Remove obsolete jaxlib version checks
2023-07-12 11:53:55 -07:00
Jake VanderPlas
fbe4f10403
Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Peter Hawkins
74384e6a87
Add a C++ safe_zip implementation.
...
Benchmark results on my workstation:
```
name old cpu/op new cpu/op delta
safe_zip/arg_lengths:0/num_args:1 1.22µs ± 1% 0.28µs ± 8% -77.33% (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:1 1.28µs ± 1% 0.34µs ± 6% -73.18% (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:1 1.28µs ± 1% 0.38µs ± 5% -70.26% (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:1 1.38µs ± 1% 0.51µs ± 3% -63.26% (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:1 1.61µs ± 1% 0.69µs ± 3% -56.93% (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:1 5.39µs ± 1% 3.83µs ± 2% -29.03% (p=0.008 n=5+5)
safe_zip/arg_lengths:0/num_args:2 1.46µs ± 1% 0.32µs ± 4% -78.30% (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:2 1.52µs ± 1% 0.39µs ± 4% -74.20% (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:2 1.53µs ± 1% 0.44µs ± 4% -71.38% (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:2 1.66µs ± 2% 0.60µs ± 3% -63.96% (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:2 1.90µs ± 1% 0.82µs ± 3% -56.66% (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:2 6.51µs ± 1% 4.80µs ± 0% -26.23% (p=0.016 n=5+4)
safe_zip/arg_lengths:0/num_args:3 1.62µs ± 1% 0.36µs ± 4% -77.95% (p=0.008 n=5+5)
safe_zip/arg_lengths:1/num_args:3 1.68µs ± 1% 0.44µs ± 3% -73.75% (p=0.008 n=5+5)
safe_zip/arg_lengths:2/num_args:3 1.69µs ± 1% 0.50µs ± 3% -70.48% (p=0.008 n=5+5)
safe_zip/arg_lengths:5/num_args:3 1.83µs ± 1% 0.68µs ± 2% -62.73% (p=0.008 n=5+5)
safe_zip/arg_lengths:10/num_args:3 2.12µs ± 1% 0.96µs ± 1% -54.71% (p=0.008 n=5+5)
safe_zip/arg_lengths:100/num_args:3 7.34µs ± 2% 5.89µs ± 1% -19.74% (p=0.008 n=5+5)
```
In addition, improve the length mismatch error for safe_map and define __module__ on both functions.
PiperOrigin-RevId: 523475834
2023-04-11 12:43:04 -07:00
Matthew Johnson
08ca4ed34d
update skipping logic
2023-04-10 21:22:49 -07:00
Peter Hawkins
0dbd467cea
Add a C++ implementation of safe map.
...
Before (argument names reversed, oops, fixed in code):
```
name time/op
safe_map/num_args:0/arg_lengths:1 1.43µs ± 1%
safe_map/num_args:1/arg_lengths:1 1.61µs ± 1%
safe_map/num_args:2/arg_lengths:1 1.72µs ± 0%
safe_map/num_args:5/arg_lengths:1 2.14µs ± 1%
safe_map/num_args:10/arg_lengths:1 2.87µs ± 1%
safe_map/num_args:100/arg_lengths:1 15.6µs ± 1%
safe_map/num_args:0/arg_lengths:2 1.65µs ± 0%
safe_map/num_args:1/arg_lengths:2 1.83µs ± 1%
safe_map/num_args:2/arg_lengths:2 1.97µs ± 1%
safe_map/num_args:5/arg_lengths:2 2.41µs ± 1%
safe_map/num_args:10/arg_lengths:2 3.22µs ± 2%
safe_map/num_args:100/arg_lengths:2 17.0µs ± 2%
safe_map/num_args:0/arg_lengths:3 1.83µs ± 1%
safe_map/num_args:1/arg_lengths:3 2.02µs ± 1%
safe_map/num_args:2/arg_lengths:3 2.16µs ± 1%
safe_map/num_args:5/arg_lengths:3 2.63µs ± 1%
safe_map/num_args:10/arg_lengths:3 3.48µs ± 1%
safe_map/num_args:100/arg_lengths:3 18.1µs ± 1%
```
After:
```
name time/op
safe_map/num_args:0/arg_lengths:1 409ns ± 1%
safe_map/num_args:1/arg_lengths:1 602ns ± 5%
safe_map/num_args:2/arg_lengths:1 777ns ± 4%
safe_map/num_args:5/arg_lengths:1 1.21µs ± 3%
safe_map/num_args:10/arg_lengths:1 1.93µs ± 2%
safe_map/num_args:100/arg_lengths:1 14.7µs ± 0%
safe_map/num_args:0/arg_lengths:2 451ns ± 1%
safe_map/num_args:1/arg_lengths:2 652ns ± 0%
safe_map/num_args:2/arg_lengths:2 850ns ± 4%
safe_map/num_args:5/arg_lengths:2 1.32µs ± 3%
safe_map/num_args:10/arg_lengths:2 2.11µs ± 2%
safe_map/num_args:100/arg_lengths:2 16.0µs ± 1%
safe_map/num_args:0/arg_lengths:3 496ns ± 1%
safe_map/num_args:1/arg_lengths:3 718ns ± 5%
safe_map/num_args:2/arg_lengths:3 919ns ± 4%
safe_map/num_args:5/arg_lengths:3 1.43µs ± 2%
safe_map/num_args:10/arg_lengths:3 2.30µs ± 2%
safe_map/num_args:100/arg_lengths:3 17.3µs ± 1%
```
PiperOrigin-RevId: 523263207
2023-04-10 18:09:56 -07:00
Jake VanderPlas
4a6bbde409
Move jax.linear_util to jax._src.linear_util
2022-12-20 14:49:27 -08:00
Peter Hawkins
320d531521
Increase the minimum jaxlib version to 0.3.22.
...
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Parker Schuh
35a5012eea
Fix and add test for weakref_lru_cache asan issue.
...
PiperOrigin-RevId: 474684516
2022-09-15 16:31:14 -07:00
Parker Schuh
df1c478ec5
Fix race condition for weakref destructor by catching rare exceptions.
2022-04-01 12:04:36 -07:00
Peter Hawkins
db2e91eba2
Move jax.test_util to jax._src.test_util.
...
Add forwarding shims for names used by external clients of JAX in practice.
PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Jake VanderPlas
afce718eb1
Add ability to specify individual test targets
2020-06-29 11:08:57 -07:00
Jake Vanderplas
9ee4ef1107
Cleanup: de-lint tests directory & add flake8 to travis ( #3304 )
...
* Cleanup: fix lint errors in tests/*.py
* Add flake8 step to travis
* add setup.cfg
2020-06-02 19:25:47 -07:00
Jake Vanderplas
bc30597780
Cleanup: remove unused imports in tests ( #3276 )
2020-06-01 11:49:35 -07:00
Peter Hawkins
b543652332
Replace np -> jnp, onp -> np in tests. ( #2969 )
2020-05-05 14:59:16 -04:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. ( #2117 )
...
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
George Necula
4b2a5a1f1b
Attempt a fix for PY2
2020-01-05 10:12:31 -08:00
George Necula
df374fa3a2
Removed unused imports
2020-01-05 16:46:29 +01:00
George Necula
528a69f32e
Added some more documentation to the linear_util module
...
Also cleaned up the inconsistent way of importing the module.
Prefer importing with qualified name 'lu.transformation' rather
than just 'transformation'.
2020-01-05 16:40:26 +01:00