11081 Commits

Author SHA1 Message Date
Yash Katariya
6ba9fb699d Upgrade the bazel version to 5.1.1
PiperOrigin-RevId: 441338363
jax-v0.3.6
2022-04-12 17:48:09 -07:00
jax authors
c06eff8cd8 Merge pull request #10245 from google:yashk2810-patch-7
PiperOrigin-RevId: 441265709
2022-04-12 12:54:22 -07:00
Yash Katariya
5fd78eaf02
Bump the libtpu version to prepare for JAX release 2022-04-12 11:41:07 -07:00
Peter Hawkins
9455254b9f [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
This is a second attempt at this change. In this version, check for and report an error on jit(pjit(...)), which was the root cause of the failure that led to the previous version being reverted.

PiperOrigin-RevId: 441214076
2022-04-12 10:30:52 -07:00
Yash Katariya
3136004c62 Fix the pytype error. PyType is looking for a __init__ method. This does not change the behavior of the class.
```
Function PartitionSpec.__init__ expects 1 arg(s), got 3 [wrong-arg-count]
         Expected: (self)
  Actually passed: (self, _, _)
```

PiperOrigin-RevId: 441211351
2022-04-12 09:36:28 -07:00
jax authors
a2c2d9af91 [JAX] Adds the approx_top_k_p bridge.
PiperOrigin-RevId: 441172779
2022-04-12 06:52:47 -07:00
jax authors
7be37ab5d6 Merge pull request #10120 from mattjj:djax-latest
PiperOrigin-RevId: 441088673
2022-04-11 22:42:07 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
jax authors
fb6a143d4d Merge pull request #9723 from sharadmv:jaxpr-effects
PiperOrigin-RevId: 441083960
2022-04-11 22:05:10 -07:00
Peter Hawkins
b051be4a7a [MHLO] Switch sharded_jit dispatch path to use MHLO lowering.
This change runs some chance of breaking sharded_jit users due to a lack of testing, but the plan is to delete sharded_jit very soon anyway.

PiperOrigin-RevId: 440999535
2022-04-11 14:41:51 -07:00
Sharad Vikram
0fa1eddd25 Adds simple effect types to jaxprs 2022-04-11 11:50:41 -07:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
jax authors
35b32eef96 Merge pull request #10215 from mattjj:dispatch-tweaks
PiperOrigin-RevId: 440610850
2022-04-09 14:25:01 -07:00
Matthew Johnson
9f1fab2519 dispatch.py: type annotations, other minor tweaks 2022-04-09 14:07:19 -07:00
Yash Katariya
5fdad0ebf5 Roll forward manylinux2014 builds after fixes.
PiperOrigin-RevId: 440589273
2022-04-09 10:19:44 -07:00
Tianjian Lu
a11b41f581 [sparse] Use sorted indices instead of sorted rows only.
PiperOrigin-RevId: 440579642
2022-04-09 08:33:48 -07:00
Yash Katariya
e9f95fa5fa Make jaxlib builds manylinux2014 compliant.
PiperOrigin-RevId: 440497401
2022-04-08 18:51:46 -07:00
Yash Katariya
506a85b7ff Make jaxlib builds manylinux2014 compliant.
PiperOrigin-RevId: 440476417
2022-04-08 16:21:56 -07:00
Peter Hawkins
94307a02c8 Revert: [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
PiperOrigin-RevId: 440452521
2022-04-08 14:22:15 -07:00
Matthew Johnson
272ed95858 remove experimental/djax
PiperOrigin-RevId: 440445082
2022-04-08 13:55:22 -07:00
jax authors
0bfb3efcd7 [JAX] Fix batch logic for approx_min/max_k
Previous logic was copied from lax.sort and was incorrect.
Since approx_top_k can handle multi-rank tensors, the only mapping we need
is to set the reduction_dim correctly.

PiperOrigin-RevId: 440445041
2022-04-08 13:50:36 -07:00
jax authors
6cb7526390 Merge pull request #10176 from lgeiger:simplify-diagonal
PiperOrigin-RevId: 440437605
2022-04-08 13:19:09 -07:00
Peter Hawkins
0f15fa3b10 [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one.
PiperOrigin-RevId: 440433044
2022-04-08 12:57:59 -07:00
jax authors
b8602d018b Merge pull request #10198 from jakevdp:bcoo-duplicates
PiperOrigin-RevId: 440423326
2022-04-08 12:12:12 -07:00
Yash Katariya
654e5bd922 Roll forward again after the fix in the auto sharding pass.
PiperOrigin-RevId: 440412218
2022-04-08 11:25:07 -07:00
Jake VanderPlas
8b9efe79e7 [sparse] fix autodiff bug in spdot_general 2022-04-08 11:04:26 -07:00
Lukas Geiger
60c828a78a Simplify jnp.trace implementation 2022-04-08 18:10:01 +01:00
Lukas Geiger
084adc7b79 Simplify jnp.diagonal implementation 2022-04-08 18:07:32 +01:00
Peter Hawkins
8b6b736ef3 Revert: Pin a newer @platforms in the Bazel workspace to fix Mac ARM builds.
The @platforms repository has been updated in the @tf_runtime repository, which was pulling in the old version of @platforms. We no longer need to override @platforms in the JAX WORKSPACE.

PiperOrigin-RevId: 440375016
2022-04-08 08:49:00 -07:00
Peter Hawkins
648a512488 [MHLO] Add direct MHLO lowerings for sparse primitives.
PiperOrigin-RevId: 440374054
2022-04-08 08:43:57 -07:00
jax authors
1cb4fccd1d Merge pull request #10187 from hawkinsp:jaxlib
PiperOrigin-RevId: 440362035
2022-04-08 07:41:33 -07:00
Joan Puigcerver
0c02f7935a Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests.
PiperOrigin-RevId: 440300882
2022-04-08 01:08:55 -07:00
Peter Hawkins
4dc69034b0 Update version numbers after jax/jaxlib release. 2022-04-07 16:40:19 -04:00
jax authors
58bdcb89e8 Merge pull request #10185 from hawkinsp:jaxlib
PiperOrigin-RevId: 440184287
2022-04-07 13:20:36 -07:00
Peter Hawkins
7f751c5523 Update libtpu version for jax 0.3.5 release. jax-v0.3.5 2022-04-07 16:14:13 -04:00
jax authors
28cb44e8f6 Merge pull request #10184 from jakevdp:merge-bcoo-dot-general
PiperOrigin-RevId: 440178509
2022-04-07 12:57:19 -07:00
Yash Katariya
96af4d5393 Remove sharded_jit usage from jax2tf because sharded_jit is deprecated.
PiperOrigin-RevId: 440169129
2022-04-07 12:14:10 -07:00
Jake VanderPlas
01e4fa8a78 [sparse] consolidate flavors of bcoo_dot_general 2022-04-07 11:28:12 -07:00
Lena Martens
5522ed1702 jax2tf: Support uint32 keys in rng_bit_generator.
This follows the rng_bit_generator_translation rule in JAX, which allows for
both uint32 and uint64 keys and casts between them. The default rbg prng
implementation in JAX uses a (4,) uint32 key.

PiperOrigin-RevId: 440124048
2022-04-07 09:19:22 -07:00
jax authors
02fd8752bd Add __init__ to PolyShape.
PiperOrigin-RevId: 440120323
2022-04-07 09:06:37 -07:00
jax authors
b713d3ce4b Minor change to lax to support jax2tf shape polymorphic concatenation.
PiperOrigin-RevId: 440113799
2022-04-07 08:34:27 -07:00
Peter Hawkins
cbdcdf7401 [MHLO] Add MHLO lowerings for parallel collectives.
PiperOrigin-RevId: 440106753
2022-04-07 07:59:26 -07:00
jax authors
28842151c6 Merge pull request #10167 from ROCmSoftwarePlatform:rocm_solver_api_consolidation
PiperOrigin-RevId: 439997492
2022-04-06 20:00:20 -07:00
jax authors
8b3f039252 Merge pull request #10039 from ajcr:add_scipy_linalg_rsf2csf
PiperOrigin-RevId: 439997145
2022-04-06 19:55:29 -07:00
Rohit Santhanam
6c560b14a7 Consolidation of hipsolver/cusolver APIs. 2022-04-07 01:46:43 +00:00
jax authors
832d9aa435 Merge pull request #10175 from jakevdp:bcoo-spdot
PiperOrigin-RevId: 439980561
2022-04-06 18:01:05 -07:00
jax authors
700d1e1e63 Merge pull request #10177 from hawkinsp:jaxlib
PiperOrigin-RevId: 439975370
jaxlib-v0.3.5
2022-04-06 17:27:38 -07:00
Peter Hawkins
96ba290faf Jax 0.3.5 and jaxlib 0.3.5 release. 2022-04-06 23:56:41 +00:00
jax authors
7ee6adb1a5 Merge pull request #10173 from jakevdp:bcoo-add-batchdim
PiperOrigin-RevId: 439955276
2022-04-06 15:50:10 -07:00
Jake VanderPlas
aa0da8e8e7 [sparse] make bcoo_spdot_general return a BCOO array, not raw buffers 2022-04-06 15:40:26 -07:00