3547 Commits

Author SHA1 Message Date
Peter Hawkins
4679f455f9 Change the default out-of-bounds behavior for jax.ops.segment_... to FILL_OR_DROP.
This matches the documented behavior.

Fixes https://github.com/google/jax/issues/8634

PiperOrigin-RevId: 411635687
2021-11-22 13:32:58 -08:00
Peter Hawkins
5415306257 Make lax.reduce_window variadic.
This is similar to the support in lax.reduce(), where the operands and init_values become pytrees. This is a strict superset of the current API, so users should not need updates.

Variadic lax.reduce_window() is only supported on CPU and TPU at the moment, not GPU.

PiperOrigin-RevId: 411632993
2021-11-22 13:21:37 -08:00
Peter Hawkins
ad6ce74d67 Skip some polar decomposition tests that fail on A100.
Works around https://github.com/google/jax/issues/8628

PiperOrigin-RevId: 411604717
2021-11-22 11:18:22 -08:00
Peter Hawkins
f4351e8419 Disable QDWH tests that fail on GPU and TPU.
PiperOrigin-RevId: 411591003
2021-11-22 10:21:41 -08:00
Peter Hawkins
dcded6a8f9 Fix incorrect gradient for base-dilated reduce window.
https://github.com/google/jax/pull/8606 introduced a runtime error where as a consequence of the move, a reference to `slice` became a reference to the builtin slice operator instead of `lax.slice`.

After fixing that and while added a test, I noticed that the gradient was wrong before: we should have been slicing the result, not the operand in the transpose rule's handling of base dilation.

Also enable some TPU tests that now pass since we have variadic reduce-window support on TPU.

PiperOrigin-RevId: 411579650
2021-11-22 09:34:10 -08:00
Peter Hawkins
d262bae88b Split jax.interpreters.xla up into three pieces:
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).

The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)

PiperOrigin-RevId: 411565432
2021-11-22 08:22:43 -08:00
Yash Katariya
65a99dba7c Add local_data API for GSDA
PiperOrigin-RevId: 411188164
2021-11-19 18:56:17 -08:00
Yash Katariya
d1de309410 Adding support for a special value for in_axis_resources (pjit.FROM_GSDA) when GSDA is an input.
PiperOrigin-RevId: 411148899
2021-11-19 14:48:11 -08:00
jax authors
34a2ffcfb6 Merge pull request #8624 from jakevdp:quantile-complex
PiperOrigin-RevId: 411148049
2021-11-19 14:43:38 -08:00
jax authors
cad3fd808d Make it easier to see what are the difference between two structures.
When structures are very large, users can end up with pages and pages describing the two structures, and finding exactly where they differ can be tricky. This change makes these differences more obvious.

PiperOrigin-RevId: 411131921
2021-11-19 13:33:30 -08:00
Jake VanderPlas
72276366a9 jnp.quantile: explicitly raise error for complex input 2021-11-19 10:54:09 -08:00
jax authors
f08a5a07a8 Merge pull request #8552 from mattjj:elide-more-convert-element-types
PiperOrigin-RevId: 411082070
2021-11-19 09:44:30 -08:00
Matthew Johnson
abbf78b5c3 generalize jaxpr simplification machinery
also:
* fix jit invariance bug around weak types
* elide trivial broadcasts

This started as an attempt to simplify some jaxpr pretty-prints, by (1)
eliding some convert_element_type applications that I thought were
unnecessary and (2) eliding some trivial broadcasts.

But it turned out that we were actually pruning more
convert_element_types than we should! In particular, see
test_weak_type_jit_invariance; that test fails on the main branch even
if we add the fixes in DynamicJaxprTrace.new_const, because [this
logic](b53a174042/jax/interpreters/partial_eval.py (L1225))
was not paying attention to weak types and hence clobbered them.

In addition to fixing those bugs that turned up (the changes in
DynamicJaxprTrace, and in what is now _convert_elt_type_fwd_rule), this
PR generalizes the jaxpr simplification machinery so as not to be a
couple special cases on convert_element_type_p. Insetad, we have tables
of rules! How we love them.

These rule signatures should let us add simplifications like forwarding
variables through calls and other higher-order primitives. That's all
future work though.
2021-11-19 09:00:59 -08:00
jax authors
f391b5b580 Merge pull request #8565 from jakevdp:spdot-general
PiperOrigin-RevId: 410905194
2021-11-18 14:53:06 -08:00
jax authors
3ee76a8089 Merge pull request #8601 from mattjj:fix-vmap-ppermute
PiperOrigin-RevId: 410900441
2021-11-18 14:36:00 -08:00
Jake VanderPlas
848675df45 [sparse] spdot_general: implement many more cases 2021-11-18 14:32:58 -08:00
jax authors
d42255486b Merge pull request #8584 from jakevdp:fix-sum-duplicates
PiperOrigin-RevId: 410900403
2021-11-18 14:31:31 -08:00
jax authors
ef6fb074c9 Merge pull request #8598 from hawkinsp:jaxlib
PiperOrigin-RevId: 410895340
2021-11-18 14:11:56 -08:00
Jake VanderPlas
acca0bfa74 [sparse] fix batched version of BCOO.sum_duplicates 2021-11-18 14:11:01 -08:00
jax authors
fa5520bc90 Merge pull request #8567 from jakevdp:unique-fill-value
PiperOrigin-RevId: 410893781
2021-11-18 14:07:30 -08:00
jax authors
bf74f2e50c Merge pull request #8581 from jakevdp:sparse-matmul-dtype
PiperOrigin-RevId: 410893387
2021-11-18 14:03:03 -08:00
Matthew Johnson
2cb235809a make vmap ppermute consistent with pmap/docstring
This was a bad bug! Unfortunately our tests didn't catch it, in part
because permutations on size-two axes are either trivial or not. The
simplest test might have a size-three axis.
2021-11-18 14:02:49 -08:00
Peter Hawkins
3fd3c46f20 Increase minimum jaxlib version to 0.1.74. 2021-11-18 15:06:58 -05:00
George Necula
3715fcb930 Added workaround for bug in XLA 2021-11-18 11:01:50 +02:00
jax authors
21d4ca3a4c Merge pull request #8583 from jakevdp:remove-prints
PiperOrigin-RevId: 410715862
2021-11-17 22:56:46 -08:00
Jake VanderPlas
0bee9b3dbc jnp.unique: ensure that output dtype is not affected by fill_value 2021-11-17 16:51:21 -08:00
Jake VanderPlas
7ce5568435 [sparse] Improve type safety of cusparse lowerings
Fixes https://github.com/google/jax/issues/8577

PiperOrigin-RevId: 410624036
2021-11-17 14:05:30 -08:00
Jake VanderPlas
bef54603f3 Remove stray print statements 2021-11-17 13:57:01 -08:00
Jake VanderPlas
50ce1db879 [sparse] support type promotion in CSR/COO matmul 2021-11-17 12:57:18 -08:00
Tom Hennigan
bb3f19891e Ensure that size property of large ShardedDeviceArrays does not overflow.
This tests a fix that landed in XLA commit tensorflow/tensorflow@4216a88.

PiperOrigin-RevId: 410557846
2021-11-17 10:01:51 -08:00
jax authors
1063b7d3b7 Merge pull request #8559 from jakevdp:sparse-shape-tuple
PiperOrigin-RevId: 410506456
2021-11-17 06:10:38 -08:00
jax authors
f067d0d663 Merge pull request #8544 from jakevdp:test-arange
PiperOrigin-RevId: 410506434
2021-11-17 06:05:53 -08:00
jax authors
dad23cea2a Merge pull request #8560 from jakevdp:bcoo-dedupe
PiperOrigin-RevId: 410453183
2021-11-17 00:35:17 -08:00
jax authors
9e09b511f9 Merge pull request #8381 from LenaMartens:changelist/405399581
PiperOrigin-RevId: 410371788
2021-11-16 15:55:44 -08:00
jax authors
6883571c06 Merge pull request #8561 from mattjj:add-donated-invars-to-xlacomputation
PiperOrigin-RevId: 410368194
2021-11-16 15:40:25 -08:00
Lena Martens
e14fea3b63 Overload jnp ops which are polymorphic to an array's value and support PRNGKeys. 2021-11-16 23:00:32 +00:00
Matthew Johnson
5d35b8a119 add donated_invars to xla.XlaComputation
Co-authored-by: Brennan Saeta <saeta@google.com>
2021-11-16 13:41:21 -08:00
Jake VanderPlas
cd531f2521 [sparse] add BCOO.sum_duplicates() with nse option 2021-11-16 11:22:03 -08:00
Jake VanderPlas
5c31e6ddb5 [sparse] ensure shapes are represented as tuples 2021-11-16 09:05:30 -08:00
Jake VanderPlas
1137aa11bf Properly handle bfloat16 in jnp.load() 2021-11-16 09:04:35 -08:00
jax authors
476ca94379 Merge pull request #8549 from jurahul:master
PiperOrigin-RevId: 410141717
2021-11-15 19:39:09 -08:00
Rahul Joshi
0776c4e628 Enable testWhileLoopBatchedWithConstBody for GPU
The XLA:GPU issue causing the internal error has been fixed.
2021-11-15 16:48:52 -08:00
Jake VanderPlas
fbd9009c54 Add test of jnp.arange() corner case 2021-11-15 16:22:04 -08:00
Jake VanderPlas
e4291e0b49 [sparse] re-enable bcoo_spdot_general test 2021-11-15 16:21:42 -08:00
jax authors
be751d1dd6 Merge pull request #8534 from jakevdp:array-dtype
PiperOrigin-RevId: 410106997
2021-11-15 16:21:00 -08:00
jax authors
6fa860d5ac Internal change
PiperOrigin-RevId: 409591497
2021-11-12 22:42:15 -08:00
jax authors
e94cc97d70 Output GSDAs from pjit if jax_gsda_out flag is enabled.
PiperOrigin-RevId: 409585439
2021-11-12 21:47:10 -08:00
Yash Katariya
155475de6f Output GSDAs from pjit if jax_gsda_out flag is enabled.
PiperOrigin-RevId: 409573181
2021-11-12 19:59:51 -08:00
Jake VanderPlas
960f2c1372 [x64] jnp.array: improve type inference testing 2021-11-12 15:34:45 -08:00
jax authors
eeb9bf7a47 Merge pull request #8520 from jakevdp:fix-percentile
PiperOrigin-RevId: 409495454
2021-11-12 14:00:29 -08:00