Peter Hawkins
0150d15cb2
Increase minimum jaxlib version to 0.3.7.
...
Drop backwards compatibility with older jaxlib versions.
2022-04-18 08:09:50 -04:00
Peter Hawkins
c4ba450867
[MHLO] Add explicit XLA translation rules for primitives that lack MHLO lowerings that rely on standard_primitive registering a translation rule.
...
At the moment this change does nothing since standard_primitive already registers these same translation rules. The change is in preparation for removing the behavior of standard_primitive of registering an XLA translation rule.
PiperOrigin-RevId: 442222533
2022-04-16 07:01:19 -07:00
Eugene Burmako
b267fd4336
[MHLO] Add MHLO lowering for cosh
...
Here's the corresponding old bridge lowering:
https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1294-1309;drc=dca7bec824dceaae1d28bd4bce7addb4444e0d3e
getConstantLike wasn't supporting complex numbers, and proper support required
non-trivial work, so in the meanwhile I've hacked up something that works
for static shapes to unblock the JAX use case (which currently only uses static shapes).
PiperOrigin-RevId: 442083258
2022-04-15 13:26:20 -07:00
Parker Schuh
58ef6cb9e1
Merge branch 'main' into opt-barrier-gpu
2022-04-14 19:31:12 -07:00
Peter Hawkins
c2fe97ae01
Improve precision of chlo.sinh.
...
Update chlo.sinh lowering to match xla::Sinh(), see https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/client/lib/math.cc;l=1311?q=xla%20sinh
[JAX] Use chlo.sinh instead of the XLA client library HLO lowering.
PiperOrigin-RevId: 441851170
2022-04-14 14:10:26 -07:00
jax authors
6914e35af1
Merge pull request #10270 from mattjj:djax-iree
...
PiperOrigin-RevId: 441812895
2022-04-14 11:33:10 -07:00
Matthew Johnson
d21b958f30
add some simple iree tests
...
This passes, though two of the interesting tests fail with what might be IREE
bugs (and so are currently skipped):
```shell
JAX_PLATFORMS='iree' pytest -n auto tests/core_test.py tests/api_test.py -k Dynamic
```
2022-04-14 10:55:00 -07:00
Peter Hawkins
6c1461b52b
[MHLO] Add MHLO lowerings for triangular_solve, cholesky, and schur.
...
PiperOrigin-RevId: 441769591
2022-04-14 08:38:21 -07:00
Peter Hawkins
4806c29bf7
[MHLO] Add MHLO lowerings for FFT ops.
...
PiperOrigin-RevId: 441768017
2022-04-14 08:31:17 -07:00
Parker Schuh
b0c41b9023
Switch gpu, cpu and jax2tf to use the new OptimizationBarrier op.
...
This change should unblock the removal of the cond and while widgets.
2022-04-13 22:51:36 -07:00
Peter Hawkins
665df8dfaa
[MHLO] Add an MHLO lowering for rng_bit_generator.
...
PiperOrigin-RevId: 441628987
2022-04-13 18:09:36 -07:00
Peter Hawkins
21f95d531b
Remove use of xla.lower_fun in SVD translation rule.
...
This is the only use of xla.lower_fun that is still needed (as a fallback) when the non-MHLO path is removed.
PiperOrigin-RevId: 441538472
2022-04-13 11:44:45 -07:00
Jake VanderPlas
1a8c57d272
better errors: check for callability of lax.control_flow arguments
2022-04-13 10:39:01 -07:00
Jake VanderPlas
7bfc86e17f
Fix arguments to schur translation rule
2022-04-13 09:50:33 -07:00
Peter Hawkins
cb4abe754a
[MHLO] Separate registrations for collective and initial_style primitives from the XLA translation rule registration.
...
Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.
PiperOrigin-RevId: 441474701
2022-04-13 07:26:26 -07:00
Peter Hawkins
ad8e6ada4e
[MHLO] Change jax.xla_computation() to use MHLO lowering internally.
...
Change in preparation for removing the non-MHLO lowering path.
PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Sharad Vikram
4392b07022
Add tests for higher order primitives
2022-04-12 18:12:44 -07:00
Matthew Johnson
4354f355a8
prototyping dynamic shapes
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Sharad Vikram
0fa1eddd25
Adds simple effect types to jaxprs
2022-04-11 11:50:41 -07:00
Matthew Johnson
902fc0c3d2
Remove invertible_ad since it's not in use.
...
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
jax authors
0bfb3efcd7
[JAX] Fix batch logic for approx_min/max_k
...
Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.
PiperOrigin-RevId: 440445041
2022-04-08 13:50:36 -07:00
jax authors
b713d3ce4b
Minor change to lax to support jax2tf shape polymorphic concatenation.
...
PiperOrigin-RevId: 440113799
2022-04-07 08:34:27 -07:00
Peter Hawkins
cbdcdf7401
[MHLO] Add MHLO lowerings for parallel collectives.
...
PiperOrigin-RevId: 440106753
2022-04-07 07:59:26 -07:00
Rohit Santhanam
6c560b14a7
Consolidation of hipsolver/cusolver APIs.
2022-04-07 01:46:43 +00:00
Peter Hawkins
bc658e7456
[MHLO] Add direct MHLO lowerings for most linear algebra kernels.
...
PiperOrigin-RevId: 439927594
2022-04-06 13:59:09 -07:00
Peter Hawkins
b9bb61322c
[MHLO] Prefer backend-specific HLO lowerings instead of non-backend-specific MHLO lowerings.
...
This allows (in subsequent changes) to switch the generic case for translating a primitive to MHLO, even if we can't yet use an MHLO lowering for a backend-specific case yet.
Add a handful of direct MLIR lowerings for primitives that lacked them.
PiperOrigin-RevId: 439912093
2022-04-06 12:53:56 -07:00
Jake VanderPlas
df1ceaeeb1
Deprecate jax.tree_util.tree_multimap
2022-04-01 14:51:54 -07:00
jax authors
ced2cbe64b
Merge pull request #10097 from lgeiger:expand-dims
...
PiperOrigin-RevId: 438649114
2022-03-31 13:39:34 -07:00
jax authors
dcd3006ae2
Merge pull request #10027 from jakevdp:fix-vmap-weaktype
...
PiperOrigin-RevId: 438565124
2022-03-31 07:41:13 -07:00
Lukas Geiger
50e8bc4514
Replace reshape
with expand_dims
if possible
2022-03-31 01:34:26 +01:00
Peter Hawkins
ade9f1a294
Share compare_mhlo function between lax.py and mlir.py.
...
Use the .shape property on RankedTensorType.
2022-03-30 17:02:19 -04:00
Jake VanderPlas
34f116c0e0
vmap: preserve weak_type in batching tracer
2022-03-30 11:06:56 -07:00
Benjamin Kramer
a04b777c54
[mhlo] Clean up ops that can use InferTensorTypeWithReify
...
This means we can get rid of custom builders and return type inference. This
all goes through inferReturnTypeComponents now, so fix obvious bugs in those
implementations.
There should be no behaviorial change. However, python bindings no longer
generate a result type builder for mhlo.CompareOp, which is unfortunate.
PiperOrigin-RevId: 438341237
2022-03-30 10:44:16 -07:00
Ayaka Mikazuki
2799bb3cde
[doc] Fix typo
2022-03-29 21:29:51 +08:00
Roy Frostig
a6a43e2715
allow for recursive uses of custom_transpose
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-26 12:09:15 -07:00
Reza Rahimi
8cd02946b5
Fix for hipsparse in ROCm.
2022-03-25 17:41:42 +00:00
Roy Frostig
0ada0a105e
avoid batching units in cond partial eval
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-03-22 17:42:38 -07:00
Sandeep Dasgupta
6cd9804163
Replace (deprecated) StrEnumAttr with EnumAttr.
...
ref: https://reviews.llvm.org/D120834
PiperOrigin-RevId: 435550738
2022-03-17 23:11:28 -07:00
Thomas Köppe
c3a4a6e63d
Revert previous change
...
PiperOrigin-RevId: 435397906
2022-03-17 11:19:49 -07:00
Lena Martens
1d5833d2f1
Reshape top_k operand to 2D by collapsing the batch dimensions when lowering.
...
PiperOrigin-RevId: 435374934
2022-03-17 10:00:24 -07:00
Jake VanderPlas
c66f5dda60
DOC: add missing linalg functionality to docs
2022-03-15 09:55:59 -07:00
jax authors
4fba0e787f
[JAX] Update ann to use XLA based fallback ApproxTopK.
...
Other small changes:
* Restricts the operand type to float.
* Add more format annotations to the docstring.
PiperOrigin-RevId: 434749705
2022-03-15 07:50:48 -07:00
Robert Suderman
97ddf986bc
Make concatenate allow concatenation on dynamic dimensions
...
Concatenating two dynamic shapes together along those dynamic dimensions
should be allowed.
PiperOrigin-RevId: 434577959
2022-03-14 15:06:38 -07:00
Matthew Johnson
39c2f8b051
fixup from 5415306: remove extraneous lines
...
also add test
2022-03-11 15:19:10 -08:00
Peter Hawkins
051f4dd0cf
Suggest eigh() in the eig() not implemented error.
2022-03-10 08:51:13 -05:00
Sharad Vikram
2988901e6c
Refactor Jaxpr pretty-printing to use a JaxprPpSettings
named tuple
...
and thread it into `pp_eqn_rules` so the settings are used recursively
2022-03-09 17:40:05 -08:00
Roy Frostig
f7731bf959
remove _const
from public jax.lax
module
...
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
jax authors
2a3f936ffa
Merge pull request #9576 from nicholasjng:broadcast-validation
...
PiperOrigin-RevId: 432531230
2022-03-04 14:21:17 -08:00
Nicholas Junge
56546d3e73
Validate lax.broadcast_shape
inputs before control flow execution
...
This commit addresses previously unvalidated inputs to `jax.lax.broadcast_shapes` by adding a small validation check before control flow execution. A legal input to `lax.broadcast_shapes` hereafter is defined as an input that
1) is a sequence (i.e., implements for..in iteration) of integers and
2) said integers are all non-negative.
In addition, two tests were added to `tests.lax_vmap_test` to check that proper errors are raised when attempting to use illegal inputs with `lax.broadcast_shapes`.
2022-03-04 19:27:52 +01:00
Peter Hawkins
c978df5550
Increase minimum jaxlib version to 0.3.0.
2022-03-04 10:33:03 -05:00