Yash Katariya
6ba9fb699d
Upgrade the bazel version to 5.1.1
...
PiperOrigin-RevId: 441338363
jax-v0.3.6
2022-04-12 17:48:09 -07:00
jax authors
c06eff8cd8
Merge pull request #10245 from google:yashk2810-patch-7
...
PiperOrigin-RevId: 441265709
2022-04-12 12:54:22 -07:00
Yash Katariya
5fd78eaf02
Bump the libtpu version to prepare for JAX release
2022-04-12 11:41:07 -07:00
Peter Hawkins
9455254b9f
[MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
...
This is a second attempt at this change. In this version, check for and report an error on jit(pjit(...)), which was the root cause of the failure that led to the previous version being reverted.
PiperOrigin-RevId: 441214076
2022-04-12 10:30:52 -07:00
Yash Katariya
3136004c62
Fix the pytype error. PyType is looking for a __init__ method. This does not change the behavior of the class.
...
```
Function PartitionSpec.__init__ expects 1 arg(s), got 3 [wrong-arg-count]
Expected: (self)
Actually passed: (self, _, _)
```
PiperOrigin-RevId: 441211351
2022-04-12 09:36:28 -07:00
jax authors
a2c2d9af91
[JAX] Adds the approx_top_k_p bridge.
...
PiperOrigin-RevId: 441172779
2022-04-12 06:52:47 -07:00
jax authors
7be37ab5d6
Merge pull request #10120 from mattjj:djax-latest
...
PiperOrigin-RevId: 441088673
2022-04-11 22:42:07 -07:00
Matthew Johnson
4354f355a8
prototyping dynamic shapes
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
jax authors
fb6a143d4d
Merge pull request #9723 from sharadmv:jaxpr-effects
...
PiperOrigin-RevId: 441083960
2022-04-11 22:05:10 -07:00
Peter Hawkins
b051be4a7a
[MHLO] Switch sharded_jit dispatch path to use MHLO lowering.
...
This change runs some chance of breaking sharded_jit users due to a lack of testing, but the plan is to delete sharded_jit very soon anyway.
PiperOrigin-RevId: 440999535
2022-04-11 14:41:51 -07:00
Sharad Vikram
0fa1eddd25
Adds simple effect types to jaxprs
2022-04-11 11:50:41 -07:00
Matthew Johnson
902fc0c3d2
Remove invertible_ad since it's not in use.
...
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
jax authors
35b32eef96
Merge pull request #10215 from mattjj:dispatch-tweaks
...
PiperOrigin-RevId: 440610850
2022-04-09 14:25:01 -07:00
Matthew Johnson
9f1fab2519
dispatch.py: type annotations, other minor tweaks
2022-04-09 14:07:19 -07:00
Yash Katariya
5fdad0ebf5
Roll forward manylinux2014 builds after fixes.
...
PiperOrigin-RevId: 440589273
2022-04-09 10:19:44 -07:00
Tianjian Lu
a11b41f581
[sparse] Use sorted indices instead of sorted rows only.
...
PiperOrigin-RevId: 440579642
2022-04-09 08:33:48 -07:00
Yash Katariya
e9f95fa5fa
Make jaxlib builds manylinux2014 compliant.
...
PiperOrigin-RevId: 440497401
2022-04-08 18:51:46 -07:00
Yash Katariya
506a85b7ff
Make jaxlib builds manylinux2014 compliant.
...
PiperOrigin-RevId: 440476417
2022-04-08 16:21:56 -07:00
Peter Hawkins
94307a02c8
Revert: [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
...
PiperOrigin-RevId: 440452521
2022-04-08 14:22:15 -07:00
Matthew Johnson
272ed95858
remove experimental/djax
...
PiperOrigin-RevId: 440445082
2022-04-08 13:55:22 -07:00
jax authors
0bfb3efcd7
[JAX] Fix batch logic for approx_min/max_k
...
Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.
PiperOrigin-RevId: 440445041
2022-04-08 13:50:36 -07:00
jax authors
6cb7526390
Merge pull request #10176 from lgeiger:simplify-diagonal
...
PiperOrigin-RevId: 440437605
2022-04-08 13:19:09 -07:00
Peter Hawkins
0f15fa3b10
[MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
...
PiperOrigin-RevId: 440433044
2022-04-08 12:57:59 -07:00
jax authors
b8602d018b
Merge pull request #10198 from jakevdp:bcoo-duplicates
...
PiperOrigin-RevId: 440423326
2022-04-08 12:12:12 -07:00
Yash Katariya
654e5bd922
Roll forward again after the fix in the auto sharding pass.
...
PiperOrigin-RevId: 440412218
2022-04-08 11:25:07 -07:00
Jake VanderPlas
8b9efe79e7
[sparse] fix autodiff bug in spdot_general
2022-04-08 11:04:26 -07:00
Lukas Geiger
60c828a78a
Simplify jnp.trace
implementation
2022-04-08 18:10:01 +01:00
Lukas Geiger
084adc7b79
Simplify jnp.diagonal
implementation
2022-04-08 18:07:32 +01:00
Peter Hawkins
8b6b736ef3
Revert: Pin a newer @platforms in the Bazel workspace to fix Mac ARM builds.
...
The @platforms repository has been updated in the @tf_runtime repository, which was pulling in the old version of @platforms. We no longer need to override @platforms in the JAX WORKSPACE.
PiperOrigin-RevId: 440375016
2022-04-08 08:49:00 -07:00
Peter Hawkins
648a512488
[MHLO] Add direct MHLO lowerings for sparse primitives.
...
PiperOrigin-RevId: 440374054
2022-04-08 08:43:57 -07:00
jax authors
1cb4fccd1d
Merge pull request #10187 from hawkinsp:jaxlib
...
PiperOrigin-RevId: 440362035
2022-04-08 07:41:33 -07:00
Joan Puigcerver
0c02f7935a
Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests.
...
PiperOrigin-RevId: 440300882
2022-04-08 01:08:55 -07:00
Peter Hawkins
4dc69034b0
Update version numbers after jax/jaxlib release.
2022-04-07 16:40:19 -04:00
jax authors
58bdcb89e8
Merge pull request #10185 from hawkinsp:jaxlib
...
PiperOrigin-RevId: 440184287
2022-04-07 13:20:36 -07:00
Peter Hawkins
7f751c5523
Update libtpu version for jax 0.3.5 release.
jax-v0.3.5
2022-04-07 16:14:13 -04:00
jax authors
28cb44e8f6
Merge pull request #10184 from jakevdp:merge-bcoo-dot-general
...
PiperOrigin-RevId: 440178509
2022-04-07 12:57:19 -07:00
Yash Katariya
96af4d5393
Remove sharded_jit usage from jax2tf because sharded_jit is deprecated.
...
PiperOrigin-RevId: 440169129
2022-04-07 12:14:10 -07:00
Jake VanderPlas
01e4fa8a78
[sparse] consolidate flavors of bcoo_dot_general
2022-04-07 11:28:12 -07:00
Lena Martens
5522ed1702
jax2tf: Support uint32 keys in rng_bit_generator.
...
This follows the rng_bit_generator_translation rule in JAX, which allows for
both uint32 and uint64 keys and casts between them. The default rbg prng
implementation in JAX uses a (4,) uint32 key.
PiperOrigin-RevId: 440124048
2022-04-07 09:19:22 -07:00
jax authors
02fd8752bd
Add __init__ to PolyShape.
...
PiperOrigin-RevId: 440120323
2022-04-07 09:06:37 -07:00
jax authors
b713d3ce4b
Minor change to lax to support jax2tf shape polymorphic concatenation.
...
PiperOrigin-RevId: 440113799
2022-04-07 08:34:27 -07:00
Peter Hawkins
cbdcdf7401
[MHLO] Add MHLO lowerings for parallel collectives.
...
PiperOrigin-RevId: 440106753
2022-04-07 07:59:26 -07:00
jax authors
28842151c6
Merge pull request #10167 from ROCmSoftwarePlatform:rocm_solver_api_consolidation
...
PiperOrigin-RevId: 439997492
2022-04-06 20:00:20 -07:00
jax authors
8b3f039252
Merge pull request #10039 from ajcr:add_scipy_linalg_rsf2csf
...
PiperOrigin-RevId: 439997145
2022-04-06 19:55:29 -07:00
Rohit Santhanam
6c560b14a7
Consolidation of hipsolver/cusolver APIs.
2022-04-07 01:46:43 +00:00
jax authors
832d9aa435
Merge pull request #10175 from jakevdp:bcoo-spdot
...
PiperOrigin-RevId: 439980561
2022-04-06 18:01:05 -07:00
jax authors
700d1e1e63
Merge pull request #10177 from hawkinsp:jaxlib
...
PiperOrigin-RevId: 439975370
jaxlib-v0.3.5
2022-04-06 17:27:38 -07:00
Peter Hawkins
96ba290faf
Jax 0.3.5 and jaxlib 0.3.5 release.
2022-04-06 23:56:41 +00:00
jax authors
7ee6adb1a5
Merge pull request #10173 from jakevdp:bcoo-add-batchdim
...
PiperOrigin-RevId: 439955276
2022-04-06 15:50:10 -07:00
Jake VanderPlas
aa0da8e8e7
[sparse] make bcoo_spdot_general return a BCOO array, not raw buffers
2022-04-06 15:40:26 -07:00