4496 Commits

Author SHA1 Message Date
Jake VanderPlas
9090dd179d jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0 2022-07-19 13:57:49 -07:00
Rohit Santhanam
235ea7059c [ROCm] Disable new array_interoperability dlpack tests. 2022-07-17 04:48:11 +00:00
Jake VanderPlas
2f4c485a54 Add dlpack support to device_array and jax.numpy 2022-07-15 17:31:11 -07:00
Yash Katariya
90687cc1ff Make lower_mesh_computation accept sharding instances. The new path is tested as everything in pjit goes through the new lower_sharding_computation except of AUTO and UNSPECIFIED (see below for these 2).
* Split `lower_mesh_computation` into `lower_mesh_computation` and `lower_sharding_computation`. This is because `lower_mesh_computation` handles 3 paths; `spmd lowering path`, `non-spmd lowering path` and `xmap spmd lowering path`. I didn't want to add a 4th path to it for general shardings.
  * `lower_sharding_computation` works in SPMD mode since its only used in pjit. Majority of the logic is the same. The only difference is that `mesh` does not exist in this function.

* `MeshComputation` is the point where `lower_mesh_computation` and `lower_sharding_computation` merge.

* `AUTO` and `UNSPECIFIED` cannot be used without mesh right now but I have a CL to fix this.

* Rest of the changes are to make all other functions play nicely with sharding instances.

PiperOrigin-RevId: 461260553
2022-07-15 16:16:23 -07:00
Jake VanderPlas
c1549a0a16 [sparse] make sparse objects compatible with jax.jit.lower() 2022-07-15 09:58:31 -07:00
Tom Hennigan
10720258ea Reduce the verbosity of treedef printing for custom nodes.
For very large trees of custom nodes this printing can be very verbose with a
lot or repetition. Our internal repository also encourages very deep package
names which exacerbates this issue.

Users encounter treedef printing when interacting with some staging APIs in JAX,
for example:

    >>> params = { .. some params .. }
    >>> f = jax.jit(..).lower(params).compile()
    >>> f(params)  # fine
    >>> params['some_new_thing'] = something
    >>> f(params)
    TypeError: function compiled for {treedef}, called with {treedef}.

PiperOrigin-RevId: 461190971
2022-07-15 07:14:28 -07:00
George Necula
777c129dfb [dynamic-shapes] Split dynamic_api_test.py
PiperOrigin-RevId: 461109288
2022-07-14 20:18:53 -07:00
Jake VanderPlas
0f14943524 lax_numpy_test: make compatible with numpy 1.24-dev 2022-07-14 14:35:10 -07:00
jax authors
4d1c6dfefa Merge pull request #11469 from jakevdp:fix-rem-jvp
PiperOrigin-RevId: 460517781
2022-07-12 11:53:27 -07:00
Jake VanderPlas
daf6e3b065 BUG: fix jvp rule for lax.rem 2022-07-12 09:50:42 -07:00
jax authors
3eff9d11d2 Internal change
PiperOrigin-RevId: 460434859
2022-07-12 05:21:20 -07:00
Yash Katariya
0bc8f8abeb * Check if the device assignment is the same across input and output shardings.
* Allow mixed inputs only if the sharding matches with what is specified in in_axis_resources.

PiperOrigin-RevId: 460326054
2022-07-11 16:27:11 -07:00
jax authors
11896b68a2 Merge pull request #11429 from sharadmv:for-loop
PiperOrigin-RevId: 460318883
2022-07-11 15:52:42 -07:00
Benjamin Kramer
9e16efa98a Integrate LLVM at llvm/llvm-project@71c9757474
Updates LLVM usage to match
[71c9757474c3](https://github.com/llvm/llvm-project/commit/71c9757474c3)

PiperOrigin-RevId: 460299215
2022-07-11 14:21:09 -07:00
Sharad Vikram
9d610e2de6 Add loop invariant residual fixpoint test 2022-07-11 13:10:03 -07:00
jax authors
cc42c8091d Merge pull request #11406 from jakevdp:bcoo-add-batchdim
PiperOrigin-RevId: 460226570
2022-07-11 09:06:42 -07:00
Peter Hawkins
64e0b5d801 Increase bazel sharding of GPU tests.
Reduces the maximum time for some test shards to avoid flaky timeouts.
2022-07-11 14:19:43 +00:00
Sharad Vikram
b666f665ec Rollback of HCB GPU custom call due to internal failures
PiperOrigin-RevId: 460079787
2022-07-10 13:05:27 -07:00
jax authors
ed51c65576 Merge pull request #11405 from mattjj:djax-vmap
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Yash Katariya
09ba51f323 Move _get_array_mapping from gda.py to pxla.py
PiperOrigin-RevId: 459891853
2022-07-08 21:38:06 -07:00
jax authors
df993ea32f Merge pull request #11410 from sharadmv:for-loop
PiperOrigin-RevId: 459879694
2022-07-08 19:37:57 -07:00
Sharad Vikram
bff71b2c4f Add loop-invariant residual optimization for for 2022-07-08 18:54:51 -07:00
jax authors
66ab792fc0 Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU
PiperOrigin-RevId: 459872063
2022-07-08 18:23:16 -07:00
jax authors
dac310c221 Merge pull request #11421 from jakevdp:scalar-meta-nocopy
PiperOrigin-RevId: 459823335
2022-07-08 13:30:20 -07:00
Yash Katariya
bb2c5f111a Resolve TODOs and add some more checks for the jax.Array path.
PiperOrigin-RevId: 459808511
2022-07-08 12:19:19 -07:00
YouJiacheng
7c707832aa Enable CustomCall implementation on GPU 2022-07-09 02:29:08 +08:00
Jake VanderPlas
e19df1a9bf Use asarray rather than array in ScalarMeta
Why? This will make it so that jnp.int32(x) and friends no longer insert
a gratuitous copy_p operation in the jaxpr
2022-07-08 11:16:40 -07:00
Yash Katariya
229ddecc45 * Remove AUTO from MeshPspecSharding and treat it like _UNSPECIFIED singleton value.
* Support partial mentions of AUTO which is supported by GDA currently and used in pax. Added tests for all of this.
  * As a consequence of this, I lifted the restriction on not providing `in_axis_resources` to pjit under `config.jax_array`.

* Made all auto sharding tests parameterized to test both gda and array.

PiperOrigin-RevId: 459776152
2022-07-08 09:45:23 -07:00
Peter Hawkins
5a7bedca37 Increase shard_count for sparse_test_gpu to 20.
1918d39765 updated the wrong test!

This test is close to the timeout in the GPU CI and flakes sometimes.

PiperOrigin-RevId: 459762867
2022-07-08 08:30:26 -07:00
jax authors
55dcbec5b5 Merge pull request #11407 from hawkinsp:minver
PiperOrigin-RevId: 459740984
2022-07-08 06:04:47 -07:00
jax authors
7ffedb5815 Merge pull request #11400 from jakevdp:deprecate-treeutil
PiperOrigin-RevId: 459681801
2022-07-07 23:05:35 -07:00
Peter Hawkins
1918d39765 Increase number of shards for GPU sparse_test to 20. 2022-07-07 21:14:25 -04:00
jax authors
a3e8ae4b1a Merge pull request #11388 from jakevdp:fix-bool-weak
PiperOrigin-RevId: 459641851
2022-07-07 17:43:17 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Jake VanderPlas
adcf30ef6b [sparse] remove deprecated bcoo_add_batch_dim utility 2022-07-07 16:57:36 -07:00
jax authors
44bd311ae7 Merge pull request #11403 from jakevdp:sparse-unary
PiperOrigin-RevId: 459634024
2022-07-07 16:57:23 -07:00
Jake VanderPlas
56d61d3f2d BUG: ensure that boolean scalars are never marked weak 2022-07-07 15:41:23 -07:00
Yash Katariya
7da733f94b Change the internals of with_sharding_constraint to use the sharding instances.
PiperOrigin-RevId: 459600050
2022-07-07 14:22:10 -07:00
Jake VanderPlas
2b4f72b6f4 [sparse] fix unary operations in presence of duplicate indices 2022-07-07 13:49:50 -07:00
jax authors
fe1bbd59dd Merge pull request #11399 from mattjj:lower-abstracted-axes
PiperOrigin-RevId: 459585916
2022-07-07 13:20:39 -07:00
Matthew Johnson
12a56c3064 [dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...) 2022-07-07 12:48:29 -07:00
Jake VanderPlas
a10f0377db Avoid top-level aliases of jax.tree_util.* 2022-07-07 11:41:02 -07:00
Yash Katariya
57ed5dc3f7 Add a util to fetch value of a GDA to host when its single-controller. Error out in McJAX
PiperOrigin-RevId: 459555907
2022-07-07 11:09:13 -07:00
Yash Katariya
2314951669 Convert everything in pjit to the Sharding interface. The following contains the things that have changed in this CL:
* All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs.

* `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances.
  * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled.
  * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used.

* Checking of sharding with `aval` has a handler system to deal with sharding instances.
  * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding.

* `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us.

* _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL.
  * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too.
  * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface.

* `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`.
  * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998

* `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach.
  * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done.

PiperOrigin-RevId: 459548974
2022-07-07 10:41:52 -07:00
jax authors
2b8fbe9fe4 Merge pull request #11367 from apaszke:xmap-tracer-leak
PiperOrigin-RevId: 459456785
2022-07-07 02:01:51 -07:00
jax authors
5270cb1c1f Merge pull request #11387 from mattjj:djax-bint
PiperOrigin-RevId: 459430960
2022-07-06 23:00:59 -07:00
Matthew Johnson
98e71fe31d [dynamic-shapes] revive basic bounded int machinery, add tests 2022-07-06 22:31:26 -07:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Matthew Johnson
6bb90fde9e [dynamic shapes] revive iree 2022-07-06 15:01:16 -07:00