588 Commits

Author SHA1 Message Date
Peter Hawkins
0150d15cb2 Increase minimum jaxlib version to 0.3.7.
Drop backwards compatibility with older jaxlib versions.
2022-04-18 08:09:50 -04:00
Peter Hawkins
c4ba450867 [MHLO] Add explicit XLA translation rules for primitives that lack MHLO lowerings that rely on standard_primitive registering a translation rule.
At the moment this change does nothing since standard_primitive already registers these same translation rules. The change is in preparation for removing the behavior of standard_primitive of registering an XLA translation rule.

PiperOrigin-RevId: 442222533
2022-04-16 07:01:19 -07:00
Eugene Burmako
b267fd4336 [MHLO] Add MHLO lowering for cosh
Here's the corresponding old bridge lowering:
https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1294-1309;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e

getConstantLike wasn't supporting complex numbers, and proper support required
non-trivial work, so in the meanwhile I've hacked up something that works
for static shapes to unblock the JAX use case (which currently only uses static shapes).

PiperOrigin-RevId: 442083258
2022-04-15 13:26:20 -07:00
Parker Schuh
58ef6cb9e1
Merge branch 'main' into opt-barrier-gpu 2022-04-14 19:31:12 -07:00
Peter Hawkins
c2fe97ae01 Improve precision of chlo.sinh.
Update chlo.sinh lowering to match xla::Sinh(), see https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1311?q=xla%20sinh

[JAX] Use chlo.sinh instead of the XLA client library HLO lowering.

PiperOrigin-RevId: 441851170
2022-04-14 14:10:26 -07:00
jax authors
6914e35af1 Merge pull request #10270 from mattjj:djax-iree
PiperOrigin-RevId: 441812895
2022-04-14 11:33:10 -07:00
Matthew Johnson
d21b958f30 add some simple iree tests
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):

```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
2022-04-14 10:55:00 -07:00
Peter Hawkins
6c1461b52b [MHLO] Add MHLO lowerings for triangular_solve, cholesky, and schur.
PiperOrigin-RevId: 441769591
2022-04-14 08:38:21 -07:00
Peter Hawkins
4806c29bf7 [MHLO] Add MHLO lowerings for FFT ops.
PiperOrigin-RevId: 441768017
2022-04-14 08:31:17 -07:00
Parker Schuh
b0c41b9023 Switch gpu, cpu and jax2tf to use the new OptimizationBarrier op.
This change should unblock the removal of the cond and while widgets.
2022-04-13 22:51:36 -07:00
Peter Hawkins
665df8dfaa [MHLO] Add an MHLO lowering for rng_bit_generator.
PiperOrigin-RevId: 441628987
2022-04-13 18:09:36 -07:00
Peter Hawkins
21f95d531b Remove use of xla.lower_fun in SVD translation rule.
This is the only use of xla.lower_fun that is still needed (as a fallback) when the non-MHLO path is removed.

PiperOrigin-RevId: 441538472
2022-04-13 11:44:45 -07:00
Jake VanderPlas
1a8c57d272 better errors: check for callability of lax.control_flow arguments 2022-04-13 10:39:01 -07:00
Jake VanderPlas
7bfc86e17f Fix arguments to schur translation rule 2022-04-13 09:50:33 -07:00
Peter Hawkins
cb4abe754a [MHLO] Separate registrations for collective and initial_style primitives from the XLA translation rule registration.
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.

PiperOrigin-RevId: 441474701
2022-04-13 07:26:26 -07:00
Peter Hawkins
ad8e6ada4e [MHLO] Change jax.xla_computation() to use MHLO lowering internally.
Change in preparation for removing the non-MHLO lowering path.

PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Sharad Vikram
4392b07022 Add tests for higher order primitives 2022-04-12 18:12:44 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Sharad Vikram
0fa1eddd25 Adds simple effect types to jaxprs 2022-04-11 11:50:41 -07:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
jax authors
0bfb3efcd7 [JAX] Fix batch logic for approx_min/max_k
Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.

PiperOrigin-RevId: 440445041
2022-04-08 13:50:36 -07:00
jax authors
b713d3ce4b Minor change to lax to support jax2tf shape polymorphic concatenation.
PiperOrigin-RevId: 440113799
2022-04-07 08:34:27 -07:00
Peter Hawkins
cbdcdf7401 [MHLO] Add MHLO lowerings for parallel collectives.
PiperOrigin-RevId: 440106753
2022-04-07 07:59:26 -07:00
Rohit Santhanam
6c560b14a7 Consolidation of hipsolver/cusolver APIs. 2022-04-07 01:46:43 +00:00
Peter Hawkins
bc658e7456 [MHLO] Add direct MHLO lowerings for most linear algebra kernels.
PiperOrigin-RevId: 439927594
2022-04-06 13:59:09 -07:00
Peter Hawkins
b9bb61322c [MHLO] Prefer backend-specific HLO lowerings instead of non-backend-specific MHLO lowerings.
This allows (in subsequent changes) to switch the generic case for translating a primitive to MHLO, even if we can't yet use an MHLO lowering for a backend-specific case yet.

Add a handful of direct MLIR lowerings for primitives that lacked them.

PiperOrigin-RevId: 439912093
2022-04-06 12:53:56 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
jax authors
ced2cbe64b Merge pull request #10097 from lgeiger:expand-dims
PiperOrigin-RevId: 438649114
2022-03-31 13:39:34 -07:00
jax authors
dcd3006ae2 Merge pull request #10027 from jakevdp:fix-vmap-weaktype
PiperOrigin-RevId: 438565124
2022-03-31 07:41:13 -07:00
Lukas Geiger
50e8bc4514 Replace reshape with expand_dims if possible 2022-03-31 01:34:26 +01:00
Peter Hawkins
ade9f1a294 Share compare_mhlo function between lax.py and mlir.py.
Use the .shape property on RankedTensorType.
2022-03-30 17:02:19 -04:00
Jake VanderPlas
34f116c0e0 vmap: preserve weak_type in batching tracer 2022-03-30 11:06:56 -07:00
Benjamin Kramer
a04b777c54 [mhlo] Clean up ops that can use InferTensorTypeWithReify
This means we can get rid of custom builders and return type inference. This
all goes through inferReturnTypeComponents now, so fix obvious bugs in those
implementations.

There should be no behaviorial change. However, python bindings no longer
generate a result type builder for mhlo.CompareOp, which is unfortunate.

PiperOrigin-RevId: 438341237
2022-03-30 10:44:16 -07:00
Ayaka Mikazuki
2799bb3cde
[doc] Fix typo 2022-03-29 21:29:51 +08:00
Roy Frostig
a6a43e2715 allow for recursive uses of custom_transpose
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-26 12:09:15 -07:00
Reza Rahimi
8cd02946b5 Fix for hipsparse in ROCm. 2022-03-25 17:41:42 +00:00
Roy Frostig
0ada0a105e avoid batching units in cond partial eval
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-22 17:42:38 -07:00
Sandeep Dasgupta
6cd9804163 Replace (deprecated) StrEnumAttr with EnumAttr.
ref: https://reviews.llvm.org/D120834
PiperOrigin-RevId: 435550738
2022-03-17 23:11:28 -07:00
Thomas Köppe
c3a4a6e63d Revert previous change
PiperOrigin-RevId: 435397906
2022-03-17 11:19:49 -07:00
Lena Martens
1d5833d2f1 Reshape top_k operand to 2D by collapsing the batch dimensions when lowering.
PiperOrigin-RevId: 435374934
2022-03-17 10:00:24 -07:00
Jake VanderPlas
c66f5dda60 DOC: add missing linalg functionality to docs 2022-03-15 09:55:59 -07:00
jax authors
4fba0e787f [JAX] Update ann to use XLA based fallback ApproxTopK.
Other small changes:
* Restricts the operand type to float.
* Add more format annotations to the docstring.

PiperOrigin-RevId: 434749705
2022-03-15 07:50:48 -07:00
Robert Suderman
97ddf986bc Make concatenate allow concatenation on dynamic dimensions
Concatenating two dynamic shapes together along those dynamic dimensions
should be allowed.

PiperOrigin-RevId: 434577959
2022-03-14 15:06:38 -07:00
Matthew Johnson
39c2f8b051 fixup from 5415306: remove extraneous lines
also add test
2022-03-11 15:19:10 -08:00
Peter Hawkins
051f4dd0cf Suggest eigh() in the eig() not implemented error. 2022-03-10 08:51:13 -05:00
Sharad Vikram
2988901e6c Refactor Jaxpr pretty-printing to use a JaxprPpSettings named tuple
and thread it into `pp_eqn_rules` so the settings are used recursively
2022-03-09 17:40:05 -08:00
Roy Frostig
f7731bf959 remove _const from public jax.lax module
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
jax authors
2a3f936ffa Merge pull request #9576 from nicholasjng:broadcast-validation
PiperOrigin-RevId: 432531230
2022-03-04 14:21:17 -08:00
Nicholas Junge
56546d3e73 Validate lax.broadcast_shape inputs before control flow execution
This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` hereafter is defined as an input that
1) is a sequence (i.e., implements for..in iteration) of integers and
2) said integers are all non-negative.

In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
2022-03-04 19:27:52 +01:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00