Lena Martens
8ea85769ea
Checkify: add way to disable categories of errors.
...
By default only user_asserts are lifted into the checked function.
2022-01-18 17:59:50 +00:00
jax authors
6411f8a033
Merge pull request #9184 from jakevdp:unique-nan
...
PiperOrigin-RevId: 422287302
2022-01-16 23:57:40 -08:00
jax authors
bebe9845a8
Merge pull request #9205 from jakevdp:einsum-tuple
...
PiperOrigin-RevId: 422013671
2022-01-15 02:46:56 -08:00
Yash Katariya
b92db58eaf
Canonicalize parsed partition spec before passing to lower_mesh_computation. Creates a new data structure, CanonicalizedParsedPartitionSpec,
which strips empty tuples from the end of parsed partitions to canonicalize the specs, so that, for example, P(None) and None in in_axis_resources are equivalent.
...
I have been bitten by this 3 times and it's about time I fixed it. This also fixes a bug with non-contiguous meshes, where fully replicated values are allowed but P(None) and None were not treated as equal.
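A minimal pure-Python sketch of the canonicalization described above, not the actual JAX implementation. The helpers `parse_spec` and `canonicalize` are hypothetical names, and plain tuples stand in for `P(...)`; the only idea taken from the commit message is that empty tuples are stripped from the end of the parsed partitions.

```python
def parse_spec(spec):
    """Parse a spec into a tuple of tuples of axis names (hypothetical helper)."""
    if spec is None:
        return ()
    parsed = []
    for entry in spec:
        if entry is None:
            parsed.append(())            # unsharded dimension
        elif isinstance(entry, str):
            parsed.append((entry,))      # single mesh axis
        else:
            parsed.append(tuple(entry))  # several mesh axes on one dimension
    return tuple(parsed)


def canonicalize(partitions):
    """Strip empty tuples from the end of the parsed partitions."""
    end = len(partitions)
    while end > 0 and partitions[end - 1] == ():
        end -= 1
    return partitions[:end]


# A one-element tuple holding None plays the role of P(None) here.
same = canonicalize(parse_spec((None,))) == canonicalize(parse_spec(None))
```

Only trailing empty entries are dropped, so a spec such as `(None, "y")` keeps its leading empty tuple and still shards the second dimension.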
PiperOrigin-RevId: 421918164
2022-01-14 14:52:43 -08:00
Peter Hawkins
b509aae2a2
Split lax_control_flow_test into three separate tests.
...
Split the custom root and custom linear solve tests into separate test files.
Disable two slow custom linear solve tests.
Add a few jit decorators to slow tests in lax_control_flow_test.
PiperOrigin-RevId: 421901487
2022-01-14 13:36:34 -08:00
jax authors
c9169fa0d5
Merge pull request #9189 from gnecula:tf_reduce_window
...
PiperOrigin-RevId: 421875035
2022-01-14 11:35:16 -08:00
Jake VanderPlas
77d60cf4dd
einsum: clarify use of precision.
2022-01-14 11:08:13 -08:00
jax authors
7f07f1b86b
Merge pull request #9200 from google:LenaMartens-patch-1
...
PiperOrigin-RevId: 421845558
2022-01-14 09:30:24 -08:00
Lena Martens
f591d0b2e9
Add ensure_compile_time_eval docstring to docs
2022-01-14 11:18:40 +00:00
Jake VanderPlas
bd157cf056
jnp.unique: properly handle NaN values
2022-01-13 15:54:07 -08:00
jax authors
3b374e7dd9
Merge pull request #9127 from jakevdp:searchsorted-complex
...
PiperOrigin-RevId: 421672606
2022-01-13 15:25:41 -08:00
Jake VanderPlas
8ca10ea53f
searchsorted: use correct ordering for complex inputs
2022-01-13 13:45:59 -08:00
jax authors
cd73a4195f
Merge pull request #9178 from jakevdp:sort-corner-cases
...
PiperOrigin-RevId: 421646783
2022-01-13 13:33:02 -08:00
Jake VanderPlas
d8bdd9a19d
lax.sort: regularize handling of -0.0 and -NaN
2022-01-13 13:03:41 -08:00
Yash Katariya
0532a63261
Optimizations for GDA to make creating GDA faster.
...
* Use math to figure out the replica id. Using `_hashed_index` (note that this is a function, not the `_HashableIndex` class, which no longer exists) is 1.5 to 2 times slower than using math. markdaoust@ helped with the math here (going to the office has its own perks :) )
* Get rid of the `_HashableIndex` class and replace it with a function, `_hashed_index`. Dataclasses are extremely slow.
* Only calculate global_mesh.local_devices once, even though it is a cached property (but only on Python 3.8 and later).
```
name                                           old time/op  new time/op  delta
gda_construction_callback_(4, 2)_['x', 'y']    4.77ms ± 5%  4.74ms ± 5%     ~     (p=0.316 n=14+17)
gda_construction_raw_(256, 8)_['x', 'y']       17.9ms ± 5%   9.0ms ± 2%  -49.92%  (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x', 'y']    11.4ms ± 2%   2.9ms ± 2%  -74.52%  (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[None]        34.0ms ±20%  30.5ms ± 2%     ~     (p=0.413 n=5+4)
gda_construction_raw_(256, 8)_[None]           15.9ms ± 2%   7.7ms ± 3%  -51.56%  (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[None]        9.39ms ± 3%  1.74ms ± 2%  -81.44%  (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['x']         8.87ms ± 2%  8.92ms ± 3%     ~     (p=0.841 n=5+5)
gda_construction_raw_(256, 8)_['x']            16.4ms ± 2%   7.7ms ± 1%  -52.66%  (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['x']         9.85ms ± 1%  1.90ms ± 2%  -80.68%  (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_['y']         15.9ms ± 3%  16.0ms ± 5%     ~     (p=0.690 n=5+5)
gda_construction_raw_(256, 8)_['y']            15.8ms ± 3%   7.6ms ± 1%  -52.04%  (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_['y']         9.29ms ± 1%  1.78ms ± 1%  -80.79%  (p=0.008 n=5+5)
gda_construction_callback_(4, 2)_[('x', 'y')]  4.65ms ± 2%  4.62ms ± 3%     ~     (p=0.440 n=5+10)
gda_construction_raw_(256, 8)_[('x', 'y')]     18.6ms ± 3%   9.7ms ± 5%  -47.76%  (p=0.008 n=5+5)
indices_replica_id_calc_(256, 8)_[('x', 'y')]  11.8ms ± 4%   3.5ms ± 2%  -70.28%  (p=0.008 n=5+5)
gda_construction_raw_(128, 8)_['x', 'y']       8.54ms ± 1%  4.03ms ± 2%  -52.84%  (p=0.008 n=5+5)
indices_replica_id_calc_(128, 8)_['x', 'y']    5.40ms ± 4%  1.10ms ± 1%  -79.69%  (p=0.008 n=5+5)
gda_construction_raw_(4, 2)_['x', 'y']          173µs ± 1%   193µs ± 3%  +11.63%  (p=0.008 n=5+5)
indices_replica_id_calc_(4, 2)_['x', 'y']       127µs ± 1%   147µs ± 1%  +15.57%  (p=0.008 n=5+5)
```
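The second bullet above (a plain function in place of a hashable dataclass) can be sketched in pure Python as follows. This is an illustration under assumptions, not the GDA code: `hashed_index` and `replica_ids` are hypothetical names, and the counting-based replica id shown here is the hash-based approach that the first bullet reports as slower than arithmetic. The arithmetic itself is not reproduced because the commit message does not spell it out.

```python
from collections import Counter


def hashed_index(index):
    """Hash a tuple of slices by reducing each slice to its (start, stop) bounds."""
    return hash(tuple((s.start, s.stop) for s in index))


def replica_ids(device_indices):
    """Give the n-th device that holds an identical shard the replica id n."""
    seen = Counter()
    ids = []
    for index in device_indices:
        key = hashed_index(index)
        ids.append(seen[key])  # number of earlier devices with the same shard
        seen[key] += 1
    return ids


# Hypothetical layout: devices 0 and 2 hold the same shard of a (4, 4) array.
indices = [
    (slice(0, 2), slice(0, 4)),
    (slice(2, 4), slice(0, 4)),
    (slice(0, 2), slice(0, 4)),
]
ids = replica_ids(indices)
```

Reducing each slice to a tuple of integers avoids constructing a wrapper object per index, which is the cost the commit attributes to the dataclass.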
PiperOrigin-RevId: 421623147
2022-01-13 11:53:13 -08:00
Yuanzhong Xu
2e4687a62e
PartitionSpec: allow partially specified sharding
...
PiperOrigin-RevId: 421603194
2022-01-13 10:35:15 -08:00
George Necula
5bfe1852a4
[jax2tf] Add jax2tf_associative_scan_reductions flag
...
This flag allows users to match the JAX performance for
associative reductions in CPU.
See README.md for details.
2022-01-13 15:52:18 +02:00
jax authors
f0e4f0472d
Merge pull request #9186 from froystig:get-default-rng
...
PiperOrigin-RevId: 421454645
2022-01-12 19:48:04 -08:00
Roy Frostig
026b91b85d
add random.default_prng_impl to retrieve the default PRNG implementation
2022-01-12 19:13:14 -08:00
jax authors
436ce7904c
Merge pull request #9175 from froystig:custom-xform-wrappers-forward-attrs
...
PiperOrigin-RevId: 421449851
2022-01-12 19:11:22 -08:00
jax authors
f08bb50bfa
Merge pull request #8869 from mbmccoy:issue8744
...
PiperOrigin-RevId: 421404424
2022-01-12 15:07:47 -08:00
Peter Hawkins
dbf03d2ee6
[JAX] Add support for generating "equation profiles" in JAX.
...
An "equation profile" is a pprof profile that maps equations in a jaxpr to the Python stack traces at which they were generated. pprof can be used in a number of ways to analyze and visualize the result.
For example, for a profile from a Resnet-50 training step from Flax, we can identify the most common primitives:
```
$ pprof --tags /tmp/myprof
Main binary filename not available.
primitive: Total 6062.0
1509.0 (24.89%): mul
936.0 (15.44%): add
589.0 ( 9.72%): reshape
492.0 ( 8.12%): div
485.0 ( 8.00%): broadcast_in_dim
330.0 ( 5.44%): reduce_sum
322.0 ( 5.31%): integer_pow
230.0 ( 3.79%): add_any
174.0 ( 2.87%): convert_element_type
160.0 ( 2.64%): select
158.0 ( 2.61%): conv_general_dilated
116.0 ( 1.91%): sub
110.0 ( 1.81%): eq
110.0 ( 1.81%): neg
104.0 ( 1.72%): max
53.0 ( 0.87%): rsqrt
52.0 ( 0.86%): rev
49.0 ( 0.81%): custom_jvp_call_jaxpr
49.0 ( 0.81%): gt
5.0 (0.082%): xla_call
4.0 (0.066%): min
3.0 (0.049%): dot_general
3.0 (0.049%): lt
2.0 (0.033%): cos
2.0 (0.033%): exp
2.0 (0.033%): iota
2.0 (0.033%): log
2.0 (0.033%): psum
2.0 (0.033%): reduce_max
2.0 (0.033%): stop_gradient
1.0 (0.016%): argmax
1.0 (0.016%): reduce_window_max
1.0 (0.016%): select_and_scatter_add
1.0 (0.016%): transpose
1.0 (0.016%): xla_pmap
```
Or the lines of code that generated the most equations:
```
$ pprof --text /tmp/myprof
Main binary filename not available.
Type: equations
Showing nodes accounting for 6038, 99.60% of 6062 total
Dropped 5 nodes (cum <= 30)
flat flat% sum% cum cum%
1537 25.35% 25.35% 1537 25.35% _compute_stats
1484 24.48% 49.84% 1484 24.48% _normalize
849 14.01% 63.84% 6062 100% __call__
644 10.62% 74.46% 644 10.62% <unknown>
483 7.97% 82.43% 483 7.97% <unknown>
392 6.47% 88.90% 6061 100% train_step
324 5.34% 94.24% 324 5.34% <unknown>
161 2.66% 96.90% 161 2.66% <unknown>
57 0.94% 97.84% 4292 70.80% loss_fn
52 0.86% 98.70% 52 0.86% schedule
39 0.64% 99.34% 39 0.64% softmax_cross_entropy
8 0.13% 99.47% 30 0.49% compute_metrics
6 0.099% 99.57% 61 1.01% cross_entropy_loss
1 0.016% 99.59% 1321 21.79% apply_gradients
1 0.016% 99.60% 6062 100% train_and_evaluate
0 0% 99.60% 6062 100% <unknown>
0 0% 99.60% 6062 100% __init__
0 0% 99.60% 3872 63.87% _call_wrapped_method
0 0% 99.60% 6062 100% _run_and_get_tests_result
0 0% 99.60% 6062 100% _run_code_in_main
0 0% 99.60% 6062 100% _run_in_app
0 0% 99.60% 6062 100% _run_main
0 0% 99.60% 3872 63.87% apply
0 0% 99.60% 161 2.66% apply_updates
0 0% 99.60% 6062 100% main
0 0% 99.60% 6062 100% main_function
0 0% 99.60% 6062 100% run
0 0% 99.60% 6062 100% runTests
0 0% 99.60% 6062 100% run_filename_as_main
0 0% 99.60% 6062 100% run_tests
0 0% 99.60% 3872 63.87% scope_fn
0 0% 99.60% 6062 100% test_train_and_evaluate
0 0% 99.60% 1159 19.12% update_fn
0 0% 99.60% 3872 63.87% wrapped_fn
0 0% 99.60% 3872 63.87% wrapped_module_method
0 0% 99.60% 3872 63.87% wrapper
```
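To make the `flat` and `cum` columns above concrete, here is a small pure-Python sketch that aggregates a handful of made-up samples the same way. The sample data and the `summarize` helper are hypothetical; the sketch does not use pprof or JAX. It assumes one sample per equation, tagged with its primitive and carrying a stack ordered from outermost to innermost frame.

```python
from collections import Counter

# Hypothetical samples: (primitive, stack from outermost to innermost frame).
samples = [
    ("mul", ["train_step", "loss_fn", "_normalize"]),
    ("add", ["train_step", "loss_fn", "_normalize"]),
    ("mul", ["train_step", "loss_fn"]),
    ("add", ["train_step", "apply_gradients"]),
]


def summarize(samples):
    """Count equations per primitive, per innermost frame, and per frame on the stack."""
    by_primitive, flat, cum = Counter(), Counter(), Counter()
    for primitive, stack in samples:
        by_primitive[primitive] += 1  # what the --tags view groups by
        flat[stack[-1]] += 1          # equation created directly in this frame
        for frame in set(stack):      # count each frame once per sample
            cum[frame] += 1
    return by_primitive, flat, cum


by_primitive, flat, cum = summarize(samples)
```

With this data `train_step` has a flat count of zero and a cumulative count equal to the total, which mirrors the rows near the bottom of the listing above where `flat` is 0 and `cum` is 100%.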
I highly recommend the pprof HTTP visualization, using --tagleaf to introduce pseudoframes for each primitive, together with the "flame" visualization.
```
pprof --tagleaf=primitive --http=:8080 myprof
```
[XLA:Python] Add helpers to Traceback and for working with pprof profiles.
* Define hash and equality operators on Tracebacks.
* Add functions for converting JSON to and from pprof profile protocol buffers.
* Add a helper method that exposes PyCode_Addr2Line to Python.
PiperOrigin-RevId: 421395346
2022-01-12 14:27:57 -08:00
Mike McCoy
2e5ab11652
Resolves issue 8744
2022-01-12 21:10:45 +00:00
jax authors
2e375f04a6
Merge pull request #8915 from mattjj:post-process-revisions
...
PiperOrigin-RevId: 421311438
2022-01-12 08:57:53 -08:00
Matthew Johnson
08aec823fd
fix a custom_vjp post_process bug, related cleanups
...
related to #8783, doesn't completely fix it
2022-01-12 07:51:50 -08:00
Roy Frostig
ddc1c3e9bd
enable custom transformation "stacking"
...
Make custom transformation wrappers such as `custom_jvp` behave
interchangeably when directly composed. For example, enable the
following usage:
```
@jax.custom_jvp
@jax.custom_transpose
def f(x): ...
@f.def_transpose
def f_t(y): ...
@f.defjvp
def f_jvp(x, tx): ...
```
In particular:
* Forward `def*` methods on custom transformations.
* Have unary `def*` methods return their argument so that, when used
as decorators, they do not replace their target with `None`.
* Fix a bug in the use of `functools.update_wrapper`: previously a
wrapper would overwrite its own attributes with those of the target
callable (including its reference to the target callable).
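The last two bullets can be illustrated with a generic wrapper class in plain Python. This is a sketch under assumptions, not the code from the commit: `Wrapper` and `def_rule` are hypothetical stand-ins for a custom transformation and one of its `def*` methods, and calling `functools.update_wrapper` before assigning the wrapper's own attributes is just one way to avoid the overwrite described above.

```python
import functools


class Wrapper:
    """Hypothetical stand-in for a custom-transformation wrapper."""

    def __init__(self, fun):
        # update_wrapper copies fun.__dict__ into self.__dict__. Calling it
        # first means the attributes assigned below cannot be overwritten by
        # same-named attributes of the target (for example another Wrapper).
        functools.update_wrapper(self, fun)
        self.fun = fun
        self.rule = None

    def def_rule(self, rule):
        self.rule = rule
        return rule  # returning the argument keeps the decorated name bound

    def __call__(self, *args):
        return self.fun(*args)


@Wrapper
@Wrapper
def f(x):
    return x + 1


@f.def_rule
def f_rule(x):
    return 2 * x
```

If `self.fun = fun` were assigned before `update_wrapper`, the outer wrapper's `fun` would be replaced by the inner wrapper's `fun`, so the outer wrapper would silently skip the inner one. Likewise, if `def_rule` returned nothing, the name `f_rule` would be bound to `None` after decoration.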
2022-01-11 17:55:08 -08:00
jax authors
4dd1f001c6
Merge pull request #9162 from mattjj:simplify-memoize
...
PiperOrigin-RevId: 421154196
2022-01-11 16:34:20 -08:00
Yash Katariya
fbb8b9f8c6
Benchmarks for GDA. Also move create_global_mesh to test_utils since it was replicated in a lot of places.
...
PiperOrigin-RevId: 421142813
2022-01-11 15:43:05 -08:00
jax authors
f235edbf5c
Merge pull request #9134 from froystig:custom-transpose
...
PiperOrigin-RevId: 421119056
2022-01-11 14:00:45 -08:00
Roy Frostig
1709e06800
introduce custom_transpose and a corresponding primitive
...
Includes rules for impl, transpose, abstract eval, and xla/mlir
translation.
2022-01-11 12:51:17 -08:00
jax authors
a30ec029ee
Merge pull request #9164 from mattjj:checkify-tweaks
...
PiperOrigin-RevId: 421098252
2022-01-11 12:36:31 -08:00
Matthew Johnson
6850833c3a
checkify: tweak some organization and names
2022-01-10 21:29:12 -08:00
Matthew Johnson
1cf7d4ab5d
Copybara import of the project:
...
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:
add jax.ensure_compile_time_eval to public api
aka jax.core.eval_context
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
2022-01-10 20:58:26 -08:00
Yash Katariya
7bc51879d4
Add tests for 0d fully replicated scalar input to pjit.
...
PiperOrigin-RevId: 420884601
2022-01-10 16:24:31 -08:00
jax authors
67723da38b
Merge pull request #9143 from mattjj:fix-jaxpr-checking-error-messages
...
PiperOrigin-RevId: 420866862
2022-01-10 15:09:04 -08:00
Peter Hawkins
9bc6d1103e
[JAX] Fix spurious inequality for two apparently equal PyTreeDefs.
...
When constructed via one path we were filling in the .custom field of nodes that weren't custom types.
Fixes https://github.com/google/jax/issues/9066
PiperOrigin-RevId: 420858917
2022-01-10 14:35:56 -08:00
Matthew Johnson
e321964245
de-duplicate util.memoize and util.cache
...
The only difference between the two was that
jax.config.jax_check_tracer_leaks disables the caching under util.cache
but not under util.memoize.
We could add that as an option on the same function if it turns out to
be important, but it seems unnecessary. Moreover there are only two
callers (in dtypes.py and in batching.py).
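A rough standard-library sketch of the difference described above: a cache decorator that a debugging flag can bypass. The `config` dictionary is a hypothetical stand-in for `jax.config.jax_check_tracer_leaks`, and the decorator is an illustration rather than the code in `jax._src.util`.

```python
import functools

# Hypothetical stand-in for jax.config.jax_check_tracer_leaks.
config = {"check_tracer_leaks": False}


def cache(max_size=4096):
    """Memoize f with an LRU cache unless leak checking is switched on."""
    def wrap(f):
        cached = functools.lru_cache(maxsize=max_size)(f)

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if config["check_tracer_leaks"]:
                return f(*args, **kwargs)  # bypass the cache entirely
            return cached(*args, **kwargs)

        return wrapper
    return wrap


calls = []


@cache()
def square(x):
    calls.append(x)  # record every real evaluation
    return x * x


square(3)
square(3)  # served from the cache, so calls still holds a single entry
```

Dropping the flag check from `wrapper` leaves a plain `lru_cache`, which is the behavior the commit message attributes to `util.memoize`.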
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2022-01-10 14:28:28 -08:00
Yash Katariya
e5e343c13e
Speed up deserialization by not doing the shard indices calculation twice. This calculation is expensive because it scales with the size of the global mesh.
...
PiperOrigin-RevId: 420855444
2022-01-10 14:26:11 -08:00
Yash Katariya
93b4554101
Do shard indices calculations only once in callbacks by using the GDA fast path, because the indices calculation is expensive: it happens on the global mesh, which may be large. The second calculation happens in _create_local_shards and is avoided if the GDA fast path is enabled.
...
PiperOrigin-RevId: 420855232
2022-01-10 14:21:54 -08:00
Peter Hawkins
f7de16c818
Disable a pjit test that is failing on GPU.
...
PiperOrigin-RevId: 420854561
2022-01-10 14:17:17 -08:00
jax authors
a41d3e2892
Merge pull request #9123 from prazek:patch-1
...
PiperOrigin-RevId: 420783231
2022-01-10 09:51:23 -08:00
Peter Hawkins
0e940b66dd
Fix spurious "donated buffers were not usable" warning when using MLIR.
...
PiperOrigin-RevId: 420782796
2022-01-10 09:47:00 -08:00
jax authors
977e142d55
Merge pull request #9154 from che-shr-cat:patch-1
...
PiperOrigin-RevId: 420782381
2022-01-10 09:42:36 -08:00
jax authors
8016026a06
Merge pull request #9128 from jakevdp:bcoo-metadata
...
PiperOrigin-RevId: 420782354
2022-01-10 09:37:48 -08:00
Peter Hawkins
5801079a4b
Enable JAX->MLIR lowering by default.
...
Before this change, JAX produces HLO using the XLA:Python builder APIs. After this change JAX produces MHLO using MLIR:Python APIs, and converts the MHLO to HLO for compilation with XLA. This is a lateral shift that should have little immediate impact, but unlocks a number of interesting opportunities in the future (e.g., mixing MLIR dialects within a JAX program).
[XLA:Python] Pass MLIR input as a std::string to work around https://github.com/pybind/pybind11/issues/2765 . A better fix would be to update pybind11 but that is hitting Windows-related hurdles; for now, just avoid relying on reference lifetime extension.
Brax: update test seeds to avoid test failures. Additional constant folding (canonicalization) in the MHLO lowering path seems to cause small numerical differences.
PiperOrigin-RevId: 420755696
2022-01-10 07:26:23 -08:00
jax authors
cdeced943b
Merge pull request #9113 from che-shr-cat:main
...
PiperOrigin-RevId: 420747734
2022-01-10 06:41:31 -08:00
che-shr-cat
d2c6c06546
Fix DeviceArray class reference
2022-01-10 17:34:09 +03:00
che-shr-cat
78977d6f5a
fix broken links and update texts in thinking_in_jax.ipynb
2022-01-10 16:19:57 +03:00
jax authors
ac5f1c4b24
Merge pull request #9126 from LenaMartens:changelist/418004472
...
PiperOrigin-RevId: 420716471
2022-01-10 03:25:27 -08:00
Matthew Johnson
3548e023ec
fix jaxpr type checking error messages
...
The pretty-printing changes a few months ago defined variable names
based on the state in JaxprPpContext instances. But that meant incorrect
variable names could be printed in jaxpr type checking error messages.
This commit correctly threads through the context so as to provide
error messages with coherent variable names.
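The naming problem can be shown with a toy pretty-printer in plain Python. Everything here is hypothetical (`NameContext`, `Var`, and `pp_eqn` are illustrative names, not JAX internals); the sketch assumes only what the message above states, namely that variable names are derived from state held in a context object.

```python
import itertools


class Var:
    """Placeholder for a jaxpr variable."""


class NameContext:
    """Assigns short names to variables in first-seen order (supports up to 26)."""

    def __init__(self):
        self._names = {}
        self._counter = itertools.count()

    def name(self, var):
        key = id(var)
        if key not in self._names:
            self._names[key] = chr(ord("a") + next(self._counter))
        return self._names[key]


def pp_eqn(prim, invars, outvar, ctx):
    """Render one equation, naming inputs before the output."""
    ins = " ".join(ctx.name(v) for v in invars)
    return f"{ctx.name(outvar)} = {prim} {ins}"


x, y, z = Var(), Var(), Var()
shared = NameContext()
line = pp_eqn("add", [x, y], z, shared)

# Reusing the same context gives the name that appears in the printed line;
# a fresh context restarts numbering and yields a different, misleading name.
name_in_error = shared.name(z)
name_from_fresh_context = NameContext().name(z)
```

An error message built with `name_from_fresh_context` would point at a name that refers to a different variable in the printed equation, which is the kind of mismatch that threading one context through avoids.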
2022-01-09 20:07:58 -08:00