12214 Commits

Author SHA1 Message Date
Peter Hawkins
64e0b5d801 Increase bazel sharding of GPU tests.
Reduces the maximum time for some test shards to avoid flaky timeouts.
2022-07-11 14:19:43 +00:00
Sharad Vikram
b666f665ec Rollback of HCB GPU custom call due to internal failures
PiperOrigin-RevId: 460079787
2022-07-10 13:05:27 -07:00
Yash Katariya
5910cdc861 Print the repr of device
PiperOrigin-RevId: 459986696
2022-07-09 17:08:44 -07:00
jax authors
ed51c65576 Merge pull request #11405 from mattjj:djax-vmap
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Yash Katariya
09ba51f323 Move _get_array_mapping from gda.py to pxla.py
PiperOrigin-RevId: 459891853
2022-07-08 21:38:06 -07:00
jax authors
df993ea32f Merge pull request #11410 from sharadmv:for-loop
PiperOrigin-RevId: 459879694
2022-07-08 19:37:57 -07:00
Sharad Vikram
bff71b2c4f Add loop-invariant residual optimization for `for` loops 2022-07-08 18:54:51 -07:00
jax authors
66ab792fc0 Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU
PiperOrigin-RevId: 459872063
2022-07-08 18:23:16 -07:00
Anish Tondwalkar
847c04fd8c [mhlo] CosOp -> CosineOp
Aligns the op class name with the mnemonic

PiperOrigin-RevId: 459852934
2022-07-08 16:04:14 -07:00
Anish Tondwalkar
5f7018a62e [mhlo] SinOp -> SineOp
Aligns the op class name with the mnemonic

PiperOrigin-RevId: 459830783
2022-07-08 14:08:12 -07:00
jax authors
1bb1fe0658 Remove workaround for rank-0 zarr chunk layout bug in TensorStore
This has now been fixed in TensorStore.

PiperOrigin-RevId: 459824051
2022-07-08 13:35:29 -07:00
jax authors
dac310c221 Merge pull request #11421 from jakevdp:scalar-meta-nocopy
PiperOrigin-RevId: 459823335
2022-07-08 13:30:20 -07:00
Yash Katariya
bb2c5f111a Resolve TODOs and add some more checks for the jax.Array path.
PiperOrigin-RevId: 459808511
2022-07-08 12:19:19 -07:00
Anish Tondwalkar
a2f2d1fa42 [mhlo] ConvOp -> ConvolutionOp
Aligns the op class name with the mnemonic

PiperOrigin-RevId: 459808502
2022-07-08 12:13:51 -07:00
YouJiacheng
7c707832aa Enable CustomCall implementation on GPU 2022-07-09 02:29:08 +08:00
Jake VanderPlas
e19df1a9bf Use asarray rather than array in ScalarMeta
Why? This ensures that jnp.int32(x) and friends no longer insert
a gratuitous copy_p operation into the jaxpr
2022-07-08 11:16:40 -07:00
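A minimal way to observe the effect is to inspect the jaxpr (a sketch, assuming a JAX build that includes this change):

    import jax
    import jax.numpy as jnp

    # After this change the jaxpr for jnp.int32 should contain only a
    # convert_element_type, with no gratuitous copy operation.
    print(jax.make_jaxpr(lambda x: jnp.int32(x))(1.0))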
jax authors
5285a15de1 Merge pull request #11419 from hawkinsp:jaxlibcleanup
PiperOrigin-RevId: 459780016
2022-07-08 10:04:39 -07:00
Yash Katariya
229ddecc45 * Remove AUTO from MeshPspecSharding and treat it like the _UNSPECIFIED singleton value.
* Support partial mentions of AUTO, which GDA currently supports and which is used in pax. Added tests for all of this.
  * As a consequence, I lifted the restriction on not providing `in_axis_resources` to pjit under `config.jax_array`.

* Made all auto-sharding tests parameterized to test both gda and array.

PiperOrigin-RevId: 459776152
2022-07-08 09:45:23 -07:00
Peter Hawkins
5a7bedca37 Increase shard_count for sparse_test_gpu to 20.
Commit 1918d39765 updated the wrong test!

This test is close to the timeout in the GPU CI and flakes sometimes.

PiperOrigin-RevId: 459762867
2022-07-08 08:30:26 -07:00
Peter Hawkins
41b015ab0c Remove stale code from jax/_src/lib/__init__.py
Remove inaccurate/stale __all__.
Remove unused alias _xla_extension_version.
2022-07-08 11:09:58 -04:00
jax authors
928f22cb6b Merge pull request #11418 from hawkinsp:bzl
PiperOrigin-RevId: 459754677
2022-07-08 07:38:13 -07:00
Peter Hawkins
a48f4e116e Change Bazel test rules to generate per-backend test suites. 2022-07-08 14:19:05 +00:00
jax authors
55dcbec5b5 Merge pull request #11407 from hawkinsp:minver
PiperOrigin-RevId: 459740984
2022-07-08 06:04:47 -07:00
Tamara Norman
bc9c4b77d0 Adjust docs to reflect the actual current RNG behavior
PiperOrigin-RevId: 459712928
2022-07-08 02:55:36 -07:00
jax authors
7ffedb5815 Merge pull request #11400 from jakevdp:deprecate-treeutil
PiperOrigin-RevId: 459681801
2022-07-07 23:05:35 -07:00
jax authors
34fea3d496 Merge pull request #11408 from hawkinsp:sparseshard
PiperOrigin-RevId: 459647331
2022-07-07 18:23:20 -07:00
Peter Hawkins
1918d39765 Increase number of shards for GPU sparse_test to 20. 2022-07-07 21:14:25 -04:00
jax authors
a3e8ae4b1a Merge pull request #11388 from jakevdp:fix-bool-weak
PiperOrigin-RevId: 459641851
2022-07-07 17:43:17 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
jax authors
44bd311ae7 Merge pull request #11403 from jakevdp:sparse-unary
PiperOrigin-RevId: 459634024
2022-07-07 16:57:23 -07:00
Jake VanderPlas
56d61d3f2d BUG: ensure that boolean scalars are never marked weak 2022-07-07 15:41:23 -07:00
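A quick way to check the fixed behavior (a sketch; weak_type is read off the array's aval):

    import jax.numpy as jnp

    x = jnp.asarray(True)
    # Boolean scalars should never be weakly typed:
    print(x.aval.weak_type)  # expected: False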
Yash Katariya
7da733f94b Change the internals of with_sharding_constraint to use the sharding instances.
PiperOrigin-RevId: 459600050
2022-07-07 14:22:10 -07:00
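The user-facing API is unchanged; a minimal sketch of with_sharding_constraint usage (assuming at least 2 devices and the experimental pjit module of this era):

    import numpy as np
    import jax
    from jax.experimental import maps
    from jax.experimental.pjit import pjit, with_sharding_constraint
    from jax.experimental import PartitionSpec as P

    mesh = maps.Mesh(np.array(jax.devices()[:2]), ('x',))

    def f(a):
        # Constrain the intermediate's sharding along mesh axis 'x';
        # internally this now routes through a Sharding instance.
        return with_sharding_constraint(a * 2, P('x'))

    g = pjit(f, in_axis_resources=P('x'), out_axis_resources=P('x'))
    with mesh:
        print(g(np.arange(8.0)))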
Jake VanderPlas
2b4f72b6f4 [sparse] fix unary operations in the presence of duplicate indices 2022-07-07 13:49:50 -07:00
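For context, a sketch of a BCOO matrix with duplicate indices, the case this fix addresses:

    import jax.numpy as jnp
    from jax.experimental import sparse

    # Two stored entries that both refer to position 0.
    data = jnp.array([1.0, 2.0])
    indices = jnp.array([[0], [0]])
    mat = sparse.BCOO((data, indices), shape=(3,))

    # Unary operations must account for duplicates; summing duplicates
    # first yields the canonical form.
    print(mat.sum_duplicates().todense())  # [3., 0., 0.]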
jax authors
fe1bbd59dd Merge pull request #11399 from mattjj:lower-abstracted-axes
PiperOrigin-RevId: 459585916
2022-07-07 13:20:39 -07:00
Matthew Johnson
12a56c3064 [dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...) 2022-07-07 12:48:29 -07:00
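A hypothetical sketch of the new hook, assuming the experimental dynamic-shapes flag and the abstracted_axes argument added here:

    import jax
    import jax.numpy as jnp

    jax.config.update('jax_dynamic_shapes', True)  # experimental flag

    # Treat the single argument's leading axis as a dynamic size 'n'
    # (the spec format here is an assumption for illustration).
    lowered = jax.jit(lambda x: jnp.sum(x),
                      abstracted_axes=('n',)).lower(jnp.ones(4))
    print(lowered.as_text())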
Marc van Zee
9d18f43a01 Do not normalize FFT by a constant "1" if no normalization is provided (i.e., norm is None).
Without this, the compiled graph will still contain a node multiplying a complex number by a constant 1+0j (1 is cast to complex because the other term is complex as well). This is problematic when converting to TFLite using jax2tf, because multiplying complex numbers is not supported in TFLite. With this change, the multiplication is removed from the graph altogether.

PiperOrigin-RevId: 459566727
2022-07-07 11:54:39 -07:00
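One way to check for the removed multiply is to inspect the jaxpr (a sketch; before this change a mul by 1+0j appeared):

    import jax
    import jax.numpy as jnp

    # With norm=None, no normalization node should appear in the graph.
    x = jnp.ones(8)
    print(jax.make_jaxpr(lambda a: jnp.fft.fft(a, norm=None))(x))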
Jake VanderPlas
ce08a9fc5c Deprecate top-level aliases of jax.tree_util functions 2022-07-07 11:41:46 -07:00
Jake VanderPlas
a10f0377db Avoid top-level aliases of jax.tree_util.* 2022-07-07 11:41:02 -07:00
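Migration for the two tree_util changes above is mechanical; a sketch of the preferred spelling:

    import jax

    # Prefer jax.tree_util.tree_map over the deprecated top-level
    # alias jax.tree_map:
    doubled = jax.tree_util.tree_map(lambda x: x * 2, {'a': 1, 'b': (2, 3)})
    print(doubled)  # {'a': 2, 'b': (4, 6)}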
Yash Katariya
57ed5dc3f7 Add a util to fetch the value of a GDA to the host when it's single-controller. Error out in McJAX
PiperOrigin-RevId: 459555907
2022-07-07 11:09:13 -07:00
Yash Katariya
2314951669 Convert everything in pjit to the Sharding interface. The following summarizes what has changed in this CL:
* All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs.

* `out_shardings` are still instances of `MeshPspecSharding` even if `Array` is used. In a follow-up CL, I will change out_axis_resources to accept `Sharding` instances.
  * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled.
  * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used.

* Checking a sharding against an `aval` now goes through a handler system that deals with sharding instances.
  * The reason for creating a `pjit`-specific system, rather than putting this check on the sharding instances, is that each transformation has a different way of checking the sharding. The best example of this is `pjit` vs. `xmap`: they check whether an aval is sharded properly with respect to a given sharding in different ways, because `pjit` and `xmap` have different ways of expressing sharding.

* `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us.

* _pjit_lower still depends on mesh, which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keeping this CL small (LOL), I'll make those changes in a follow-up CL.
  * Also, the private functions in pxla.py are used by pathways and automap, so I'll have to modify those too.
  * Also, it has `pxla.resource_typecheck`, which I haven't yet figured out how to move to the sharding interface.

* `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`.
  * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998

* `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this depends on how each transformation adds the extra dimension, and each sharding instance handles it differently, I added a handler system for this too. Again, `xmap` and `pjit` differ a lot here, which is why I went with the handler approach.
  * MeshPspecSharding handles this via `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments where this is done.

PiperOrigin-RevId: 459548974
2022-07-07 10:41:52 -07:00
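For reference, a minimal sketch of the user-facing pjit API whose internals this CL rewires (assuming 8 available devices; the axis resources are now converted to Sharding instances internally):

    import numpy as np
    import jax
    from jax.experimental import maps
    from jax.experimental.pjit import pjit
    from jax.experimental import PartitionSpec as P

    # 4x2 device mesh; the mesh context manager is still required,
    # as noted above.
    mesh = maps.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
    f = pjit(lambda a: a * 2,
             in_axis_resources=P('x', 'y'),
             out_axis_resources=P('x', 'y'))
    with mesh:
        out = f(np.arange(8.0).reshape(4, 2))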
Peter Hawkins
88c1e7dce2 Flip after_neurips flag to True.
PiperOrigin-RevId: 459541278
2022-07-07 10:12:15 -07:00
jax authors
fb7e39b13e Merge pull request #11390 from hawkinsp:distributed_init
PiperOrigin-RevId: 459518348
2022-07-07 08:23:26 -07:00
jax authors
2b8fbe9fe4 Merge pull request #11367 from apaszke:xmap-tracer-leak
PiperOrigin-RevId: 459456785
2022-07-07 02:01:51 -07:00
jax authors
5270cb1c1f Merge pull request #11387 from mattjj:djax-bint
PiperOrigin-RevId: 459430960
2022-07-06 23:00:59 -07:00
Matthew Johnson
98e71fe31d [dynamic-shapes] revive basic bounded int machinery, add tests 2022-07-06 22:31:26 -07:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
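Illustrative of the callback mechanism involved, a sketch using the host_callback API (not specific to the TFRT TPU backend):

    import numpy as np
    import jax
    from jax.experimental import host_callback as hcb

    def host_sin(x):
        # Runs as ordinary Python on the host.
        return np.sin(x)

    @jax.jit
    def f(x):
        return hcb.call(host_sin, x,
                        result_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))

    print(f(np.arange(4.0)))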
Anish Tondwalkar
5d379bba9e mhlo.rng op with distribution attr
Aligns with the XLA kRng, which takes a distribution as an attribute
instead of having separate ops for each distribution.

PiperOrigin-RevId: 459389874
2022-07-06 18:03:02 -07:00
Peter Hawkins
bdbdecd458 Refactor distributed GPU device initialization.
Avoid reregistering backend factories; instead simply have the usual
factory function support distributed GPU.
2022-07-07 00:45:19 +00:00
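Usage is unchanged; a sketch of a hypothetical two-process GPU setup (the address and counts are illustrative):

    import jax

    # Run once per process before any computation; the regular backend
    # factory now handles the distributed GPU case directly.
    jax.distributed.initialize(coordinator_address='192.168.0.1:1234',
                               num_processes=2,
                               process_id=0)

    print(jax.devices())        # devices across all processes
    print(jax.local_devices())  # only this process's devices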
jax authors
89a6766964 Merge pull request #11313 from mattjj:djax-revive-iree
PiperOrigin-RevId: 459360223
2022-07-06 15:34:05 -07:00