562 Commits

Author SHA1 Message Date
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
jax authors
ced2cbe64b Merge pull request #10097 from lgeiger:expand-dims
PiperOrigin-RevId: 438649114
2022-03-31 13:39:34 -07:00
jax authors
dcd3006ae2 Merge pull request #10027 from jakevdp:fix-vmap-weaktype
PiperOrigin-RevId: 438565124
2022-03-31 07:41:13 -07:00
Lukas Geiger
50e8bc4514 Replace reshape with expand_dims if possible 2022-03-31 01:34:26 +01:00
Peter Hawkins
ade9f1a294 Share compare_mhlo function between lax.py and mlir.py.
Use the .shape property on RankedTensorType.
2022-03-30 17:02:19 -04:00
Jake VanderPlas
34f116c0e0 vmap: preserve weak_type in batching tracer 2022-03-30 11:06:56 -07:00
Benjamin Kramer
a04b777c54 [mhlo] Clean up ops that can use InferTensorTypeWithReify
This means we can get rid of custom builders and return type inference. This
all goes through inferReturnTypeComponents now, so fix obvious bugs in those
implementations.

There should be no behaviorial change. However, python bindings no longer
generate a result type builder for mhlo.CompareOp, which is unfortunate.

PiperOrigin-RevId: 438341237
2022-03-30 10:44:16 -07:00
Ayaka Mikazuki
2799bb3cde
[doc] Fix typo 2022-03-29 21:29:51 +08:00
Roy Frostig
a6a43e2715 allow for recursive uses of custom_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-26 12:09:15 -07:00
Reza Rahimi
8cd02946b5 Fix for hipsparse in ROCm. 2022-03-25 17:41:42 +00:00
Roy Frostig
0ada0a105e avoid batching units in cond partial eval
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-22 17:42:38 -07:00
Sandeep Dasgupta
6cd9804163 Replace (deprecated) StrEnumAttr with EnumAttr.
ref: https://reviews.llvm.org/D120834
PiperOrigin-RevId: 435550738
2022-03-17 23:11:28 -07:00
Thomas Köppe
c3a4a6e63d Revert previous change
PiperOrigin-RevId: 435397906
2022-03-17 11:19:49 -07:00
Lena Martens
1d5833d2f1 Reshape top_k operand to 2D by collapsing the batch dimensions when lowering.
PiperOrigin-RevId: 435374934
2022-03-17 10:00:24 -07:00
Jake VanderPlas
c66f5dda60 DOC: add missing linalg functionality to docs 2022-03-15 09:55:59 -07:00
jax authors
4fba0e787f [JAX] Update ann to use XLA based fallback ApproxTopK.
Other small changes:
* Restricts the operand type to float.
* Add more format annotations to the docstring.

PiperOrigin-RevId: 434749705
2022-03-15 07:50:48 -07:00
Robert Suderman
97ddf986bc Make concatenate allow concatenation on dynamic dimensions
Concatenating two dynamic shapes together along those dynamic dimensions
should be allowed.

PiperOrigin-RevId: 434577959
2022-03-14 15:06:38 -07:00
Matthew Johnson
39c2f8b051 fixup from 5415306: remove extraneous lines
also add test
2022-03-11 15:19:10 -08:00
Peter Hawkins
051f4dd0cf Suggest eigh() in the eig() not implemented error. 2022-03-10 08:51:13 -05:00
Sharad Vikram
2988901e6c Refactor Jaxpr pretty-printing to use a JaxprPpSettings named tuple
and thread it into `pp_eqn_rules` so the settings are used recursively
2022-03-09 17:40:05 -08:00
Roy Frostig
f7731bf959 remove _const from public jax.lax module
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
jax authors
2a3f936ffa Merge pull request #9576 from nicholasjng:broadcast-validation
PiperOrigin-RevId: 432531230
2022-03-04 14:21:17 -08:00
Nicholas Junge
56546d3e73 Validate lax.broadcast_shape inputs before control flow execution
This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` hereafter is defined as an input that
1) is a sequence (i.e., implements for..in iteration) of integers and
2) said integers are all non-negative.

In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
2022-03-04 19:27:52 +01:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00
Anselm Levskaya
f6a5f0dca2 Use real Precision type for lax.PrecisionType
PiperOrigin-RevId: 432413742
2022-03-04 04:21:25 -08:00
Tianjian Lu
2d17b4d637 [linalg] Fix type promotion in QDWH.
PiperOrigin-RevId: 432352839
2022-03-03 20:59:28 -08:00
jax authors
cf9a900d78 Merge pull request #9584 from ROCmSoftwarePlatform:rocm_refactor_jaxlib
PiperOrigin-RevId: 432236852
2022-03-03 11:11:02 -08:00
jax authors
98572a696c Merge pull request #9737 from jakevdp:cleanup-constant-like
PiperOrigin-RevId: 431988229
2022-03-02 11:33:07 -08:00
jax authors
f1e71c11d7 [Jax] Format ann docstring.
PiperOrigin-RevId: 431968329
2022-03-02 10:11:52 -08:00
Jake VanderPlas
00e040e514 cleanup: remove _constant_like in favor of lax._const 2022-03-02 09:13:58 -08:00
jax authors
d9f82f7b9b [JAX] Move experimental.ann.approx_*_k into lax.
Updated docs, tests and the example code snippets.

PiperOrigin-RevId: 431781401
2022-03-01 14:46:33 -08:00
Reza Rahimi
a0d9d81f92 Update JAX to use new math libraries in ROCm-5.0. 2022-03-01 20:02:15 +00:00
Peter Hawkins
cffe9978fb Handle jaxpr constants correctly in MLIR lowering of conditional branches.
Add some dynamic type checks and type annotations to catch this kind of problem sooner.

There's no test case, because I'm not entirely sure how to make a test case for this. In fact, I'm not even sure it's legal for a conditional branch to have non-empty constants. We'll dig into that separately.

PiperOrigin-RevId: 431697808
2022-03-01 08:56:08 -08:00
Sharad Vikram
1b79caa6bd Add separate mechanism for threading name stacks to the lowering 2022-02-23 09:59:09 -08:00
Jake VanderPlas
e13c847e04 Index update operators: add scatter_apply() 2022-02-18 09:44:40 -08:00
Roy Frostig
35fab1a95a err on repeated axes to expand_dims, as numpy does 2022-02-17 11:27:20 -08:00
Roy Frostig
0f7904f883 implement jnp.expand_dims and jnp.stack for PRNGKeyArrays
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
2022-02-16 20:47:27 -08:00
Parker Schuh
662c4416a3
Merge branch 'main' into opt-barrier 2022-02-15 14:16:20 -08:00
Lena Martens
b15c7f609a Checkify: fix check_error of nd-error.
PiperOrigin-RevId: 428857813
2022-02-15 13:12:53 -08:00
Parker Schuh
7ce911b8d1 Add translation rule for optimization barrier.
Also adds a translation rule for remat that uses the new optimization barrier
op. If you find errors, consider disabling the remat lowering using
`jax_remat_opt_barrier` config flag.
2022-02-14 12:21:16 -08:00
Peter Hawkins
29c8a04527 Fix incorrect binary search comparison in lax.select_n lowering.
Fixes issue in https://github.com/google/jax/discussions/9556#discussioncomment-2175113
2022-02-14 14:29:38 -05:00
jax authors
0566ea4ccd Merge pull request #9456 from mattjj:jaxpr-pprint-color-flag-and-default
PiperOrigin-RevId: 428247626
2022-02-12 14:49:02 -08:00
Matthew Johnson
004bb684ea add flag for jaxpr colorful syntax highlighting
set it on by default
2022-02-12 14:15:28 -08:00
Peter Hawkins
8ca6622c0b Change lax.select_p to be an n-ary predicate, 'lax.select_n_p'. Change lax.select() to be a thin shim around the new n-ary version.
Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.

Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduces the number of jaxpr equations significantly.

PiperOrigin-RevId: 427517899
2022-02-09 11:03:09 -08:00
Peter Hawkins
f539c9b9bd Hoist construction of predicates out of cond batching rule.
Avoids building the "which path are we following" predicate once for each input.

PiperOrigin-RevId: 427012972
2022-02-07 14:13:03 -08:00
Peter Hawkins
f94c42b271 Fix rendering of cube root in docs. 2022-02-04 11:30:10 -05:00
jax authors
45d96c490e Merge pull request #4671 from romanngg:conv_local
PiperOrigin-RevId: 426282505
2022-02-03 18:03:33 -08:00
Tianjian Lu
5a012d5e7b [JAX] Added jit-able singular value decomposition.
PiperOrigin-RevId: 426193395
2022-02-03 11:16:55 -08:00
jax authors
d04dce3fa2 Merge pull request #9417 from hawkinsp:fft2
PiperOrigin-RevId: 426163984
2022-02-03 09:24:13 -08:00
Peter Hawkins
84bccb2420 Support string fft_type values in lax.fft. 2022-02-03 08:52:38 -05:00