Jake VanderPlas
df1ceaeeb1
Deprecate jax.tree_util.tree_multimap
2022-04-01 14:51:54 -07:00
jax authors
ced2cbe64b
Merge pull request #10097 from lgeiger:expand-dims
...
PiperOrigin-RevId: 438649114
2022-03-31 13:39:34 -07:00
jax authors
dcd3006ae2
Merge pull request #10027 from jakevdp:fix-vmap-weaktype
...
PiperOrigin-RevId: 438565124
2022-03-31 07:41:13 -07:00
Lukas Geiger
50e8bc4514
Replace reshape
with expand_dims
if possible
2022-03-31 01:34:26 +01:00
Peter Hawkins
ade9f1a294
Share compare_mhlo function between lax.py and mlir.py.
...
Use the .shape property on RankedTensorType.
2022-03-30 17:02:19 -04:00
Jake VanderPlas
34f116c0e0
vmap: preserve weak_type in batching tracer
2022-03-30 11:06:56 -07:00
Benjamin Kramer
a04b777c54
[mhlo] Clean up ops that can use InferTensorTypeWithReify
...
This means we can get rid of custom builders and return type inference. This
all goes through inferReturnTypeComponents now, so fix obvious bugs in those
implementations.
There should be no behaviorial change. However, python bindings no longer
generate a result type builder for mhlo.CompareOp, which is unfortunate.
PiperOrigin-RevId: 438341237
2022-03-30 10:44:16 -07:00
Ayaka Mikazuki
2799bb3cde
[doc] Fix typo
2022-03-29 21:29:51 +08:00
Roy Frostig
a6a43e2715
allow for recursive uses of custom_transpose
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-26 12:09:15 -07:00
Reza Rahimi
8cd02946b5
Fix for hipsparse in ROCm.
2022-03-25 17:41:42 +00:00
Roy Frostig
0ada0a105e
avoid batching units in cond partial eval
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-22 17:42:38 -07:00
Sandeep Dasgupta
6cd9804163
Replace (deprecated) StrEnumAttr with EnumAttr.
...
ref: https://reviews.llvm.org/D120834
PiperOrigin-RevId: 435550738
2022-03-17 23:11:28 -07:00
Thomas Köppe
c3a4a6e63d
Revert previous change
...
PiperOrigin-RevId: 435397906
2022-03-17 11:19:49 -07:00
Lena Martens
1d5833d2f1
Reshape top_k operand to 2D by collapsing the batch dimensions when lowering.
...
PiperOrigin-RevId: 435374934
2022-03-17 10:00:24 -07:00
Jake VanderPlas
c66f5dda60
DOC: add missing linalg functionality to docs
2022-03-15 09:55:59 -07:00
jax authors
4fba0e787f
[JAX] Update ann to use XLA based fallback ApproxTopK.
...
Other small changes:
* Restricts the operand type to float.
* Add more format annotations to the docstring.
PiperOrigin-RevId: 434749705
2022-03-15 07:50:48 -07:00
Robert Suderman
97ddf986bc
Make concatenate allow concatenation on dynamic dimensions
...
Concatenating two dynamic shapes together along those dynamic dimensions
should be allowed.
PiperOrigin-RevId: 434577959
2022-03-14 15:06:38 -07:00
Matthew Johnson
39c2f8b051
fixup from 5415306: remove extraneous lines
...
also add test
2022-03-11 15:19:10 -08:00
Peter Hawkins
051f4dd0cf
Suggest eigh() in the eig() not implemented error.
2022-03-10 08:51:13 -05:00
Sharad Vikram
2988901e6c
Refactor Jaxpr pretty-printing to use a JaxprPpSettings
named tuple
...
and thread it into `pp_eqn_rules` so the settings are used recursively
2022-03-09 17:40:05 -08:00
Roy Frostig
f7731bf959
remove _const
from public jax.lax
module
...
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
jax authors
2a3f936ffa
Merge pull request #9576 from nicholasjng:broadcast-validation
...
PiperOrigin-RevId: 432531230
2022-03-04 14:21:17 -08:00
Nicholas Junge
56546d3e73
Validate lax.broadcast_shape
inputs before control flow execution
...
This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` hereafter is defined as an input that
1) is a sequence (i.e., implements for..in iteration) of integers and
2) said integers are all non-negative.
In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
2022-03-04 19:27:52 +01:00
Peter Hawkins
c978df5550
Increase minimum jaxlib version to 0.3.0.
2022-03-04 10:33:03 -05:00
Anselm Levskaya
f6a5f0dca2
Use real Precision type for lax.PrecisionType
...
PiperOrigin-RevId: 432413742
2022-03-04 04:21:25 -08:00
Tianjian Lu
2d17b4d637
[linalg] Fix type promotion in QDWH.
...
PiperOrigin-RevId: 432352839
2022-03-03 20:59:28 -08:00
jax authors
cf9a900d78
Merge pull request #9584 from ROCmSoftwarePlatform:rocm_refactor_jaxlib
...
PiperOrigin-RevId: 432236852
2022-03-03 11:11:02 -08:00
jax authors
98572a696c
Merge pull request #9737 from jakevdp:cleanup-constant-like
...
PiperOrigin-RevId: 431988229
2022-03-02 11:33:07 -08:00
jax authors
f1e71c11d7
[Jax] Format ann docstring.
...
PiperOrigin-RevId: 431968329
2022-03-02 10:11:52 -08:00
Jake VanderPlas
00e040e514
cleanup: remove _constant_like in favor of lax._const
2022-03-02 09:13:58 -08:00
jax authors
d9f82f7b9b
[JAX] Move experimental.ann.approx_*_k
into lax
.
...
Updated docs, tests and the example code snippets.
PiperOrigin-RevId: 431781401
2022-03-01 14:46:33 -08:00
Reza Rahimi
a0d9d81f92
Update JAX to use new math libraries in ROCm-5.0.
2022-03-01 20:02:15 +00:00
Peter Hawkins
cffe9978fb
Handle jaxpr constants correctly in MLIR lowering of conditional branches.
...
Add some dynamic type checks and type annotations to catch this kind of problem sooner.
There's no test case, because I'm not entirely sure how to make a test case for this. In fact, I'm not even sure it's legal for a conditional branch to have non-empty constants. We'll dig into that separately.
PiperOrigin-RevId: 431697808
2022-03-01 08:56:08 -08:00
Sharad Vikram
1b79caa6bd
Add separate mechanism for threading name stacks to the lowering
2022-02-23 09:59:09 -08:00
Jake VanderPlas
e13c847e04
Index update operators: add scatter_apply()
2022-02-18 09:44:40 -08:00
Roy Frostig
35fab1a95a
err on repeated axes to expand_dims
, as numpy does
2022-02-17 11:27:20 -08:00
Roy Frostig
0f7904f883
implement jnp.expand_dims
and jnp.stack
for PRNGKeyArrays
...
Also:
* fix `jnp.concatenate` and `jnp.append` for PRNGKeyArrays
* add `ndim` property to PRNGKeyArrays
* minor fix to `lax.expand_dims` with duplicate dimensions
2022-02-16 20:47:27 -08:00
Parker Schuh
662c4416a3
Merge branch 'main' into opt-barrier
2022-02-15 14:16:20 -08:00
Lena Martens
b15c7f609a
Checkify: fix check_error of nd-error.
...
PiperOrigin-RevId: 428857813
2022-02-15 13:12:53 -08:00
Parker Schuh
7ce911b8d1
Add translation rule for optimization barrier.
...
Also adds a translation rule for remat that uses the new optimization barrier
op. If you find errors, consider disabling the remat lowering using
`jax_remat_opt_barrier` config flag.
2022-02-14 12:21:16 -08:00
Peter Hawkins
29c8a04527
Fix incorrect binary search comparison in lax.select_n lowering.
...
Fixes issue in https://github.com/google/jax/discussions/9556#discussioncomment-2175113
2022-02-14 14:29:38 -05:00
jax authors
0566ea4ccd
Merge pull request #9456 from mattjj:jaxpr-pprint-color-flag-and-default
...
PiperOrigin-RevId: 428247626
2022-02-12 14:49:02 -08:00
Matthew Johnson
004bb684ea
add flag for jaxpr colorful syntax highlighting
...
set it on by default
2022-02-12 14:15:28 -08:00
Peter Hawkins
8ca6622c0b
Change lax.select_p to be an n-ary predicate, 'lax.select_n_p'. Change lax.select()
to be a thin shim around the new n-ary version.
...
Note that one key difference between `lax.select_p` and `lax.select_n_p` is that the order of the cases is reversed for boolean predicates. This merited a new name to minimize confusion.
Use lax.select_n() in conditional batching. This means that we only produce one `select_n()` primitive for each conditional output, rather than a tree. While this has no effect on the number of HLO operators we generate, it can reduces the number of jaxpr equations significantly.
PiperOrigin-RevId: 427517899
2022-02-09 11:03:09 -08:00
Peter Hawkins
f539c9b9bd
Hoist construction of predicates out of cond batching rule.
...
Avoids building the "which path are we following" predicate once for each input.
PiperOrigin-RevId: 427012972
2022-02-07 14:13:03 -08:00
Peter Hawkins
f94c42b271
Fix rendering of cube root in docs.
2022-02-04 11:30:10 -05:00
jax authors
45d96c490e
Merge pull request #4671 from romanngg:conv_local
...
PiperOrigin-RevId: 426282505
2022-02-03 18:03:33 -08:00
Tianjian Lu
5a012d5e7b
[JAX] Added jit-able singular value decomposition.
...
PiperOrigin-RevId: 426193395
2022-02-03 11:16:55 -08:00
jax authors
d04dce3fa2
Merge pull request #9417 from hawkinsp:fft2
...
PiperOrigin-RevId: 426163984
2022-02-03 09:24:13 -08:00
Peter Hawkins
84bccb2420
Support string fft_type values in lax.fft.
2022-02-03 08:52:38 -05:00