255 Commits

Author SHA1 Message Date
Yash Katariya
1526c3e20c Improve the error message which is raised from _get_and_check_device_assignment.
Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Parker Schuh
c3e6d5cb2a Remove some differences between jit and pjit.
- MaybeCollectGarbage
- The recursive check.
- DevicePut for np arrays and scalars when device_count == 1.

PiperOrigin-RevId: 507972281
2023-02-07 21:33:07 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
Yash Katariya
c252162821 Make pjit's cache global just like jit's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.)) is executed twice.
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.

PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
Peter Hawkins
3d9ae6b467 Add a .cost_analysis() on lowered but uncompiled computations.
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.

PiperOrigin-RevId: 507560058
2023-02-06 12:57:57 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
Yash Katariya
be67db33d8 Skip testAutodiffCache test if xla_extension_version < 123
PiperOrigin-RevId: 507292333
2023-02-05 09:39:36 -08:00
Yash Katariya
f445c84ba4 Add support for a list of allow_spmd_sharding_propagation_to_output. This gives us more flexibility to tell SPMD which shardings to override.
PiperOrigin-RevId: 507035958
2023-02-03 17:59:10 -08:00
jax authors
0affb3bb18 Merge pull request #14283 from pschuh:static_argnums_custom_partitioning
PiperOrigin-RevId: 507005561
2023-02-03 15:14:08 -08:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Parker Schuh
7526d0ea1f Add static_argnums to custom_partitioning.
Arguments specified by static_argnums cannot contain
any jax tracers because they will be passed into the XLA compiler
where the lowering information for these tracers is already lost.
2023-02-03 11:41:17 -08:00
jax authors
0f289ab0e3 Merge pull request #14174 from google:pjrt_test
PiperOrigin-RevId: 506751529
2023-02-02 16:23:26 -08:00
Yash Katariya
8a4de1f86a Remove the usage of _arrays from tests
PiperOrigin-RevId: 505871063
2023-01-30 20:02:37 -08:00
Jake VanderPlas
43e57db77a Begin deprecation of public jax.ShapedArray 2023-01-30 11:27:58 -08:00
Skye Wanderman-Milne
93cd07efb8 Add PJRT C API to Cloud TPU test matrix
Also shortens the job names so the full name is visible from the
github UI (this was driving me crazy), and marks a new test that can't
be run on the PJRT C API yet.

Example run: https://github.com/google/jax/actions/runs/4019968334
2023-01-27 01:06:21 +00:00
Skye Wanderman-Milne
49e751b4ad Add warning filter to ArrayPjitTest.test_pmap_pjit_axis_index 2023-01-26 21:30:28 +00:00
Yash Katariya
0846aebf63 Add axis_substitution_rules rule for pmap so that pjit(pmap) with an axis_index works properly
PiperOrigin-RevId: 504837464
2023-01-26 07:33:15 -08:00
Yash Katariya
18eca1a479 Add disable_jit support to pjit.cc
PiperOrigin-RevId: 504067752
2023-01-23 13:31:39 -08:00
Yash Katariya
864d640ee1 Set committed=True for nested pjits/with_sharding_constraint if any jaxpr_sharding is not UNSPECIFIED.
PiperOrigin-RevId: 503833657
2023-01-22 14:07:03 -08:00
Matthew Johnson
358775f901 update pjit test 2023-01-20 11:40:22 -08:00
Yash Katariya
5714616dd6 Set no_kwargs to False because pjit supports kwargs
PiperOrigin-RevId: 503019556
2023-01-18 17:14:24 -08:00
Parker Schuh
b58dd3cbe1 Add support for __signature__ to PjitFunction.
PiperOrigin-RevId: 502731453
2023-01-17 17:28:14 -08:00
jax authors
86ba62d05e Merge pull request #13991 from skye:pjrt_c_api_marker
PiperOrigin-RevId: 501909180
2023-01-13 12:15:00 -08:00
Yash Katariya
c8ad89e358 Make jit a thin wrapper around pjit which ignores the mesh context manager (just like how it is today)
Pass `None` as the resource_env via `jit` because `jit(pjit)` will ignore the outer mesh because `jit` will set the resource env to empty mesh.

This does not make `jit` and `pjit` the same API but it shares all the code between both the APIs (cpp and python) while preserving the current semantics of both `jit` and `pjit`.

PiperOrigin-RevId: 501707496
2023-01-12 17:24:32 -08:00
Skye Wanderman-Milne
c0577f70f9 Migrate pytestmark usage to new @jtu.pytest_mark_if_available decorator.
See discussion in https://github.com/google/jax/pull/13977. Marking
entire modules is magical and verbose, plus less precise than marking
individual classes or tests.

I wasn't super careful on which classes to mark, and erred on the side
of marking too many classes (in line with the previous behavior). It's
possible some test classes don't actually benefit from multiple
accelerators.
2023-01-12 22:44:39 +00:00
Skye Wanderman-Milne
f90b5eed52 Add pjrt_c_api_unimplemented pytest marker to skip unsupported tests.
Also adds `test_util.pytest_mark_if_available` helper function.
2023-01-12 22:17:23 +00:00
Yash Katariya
44b97ae3f6 Fix pjit's initial style usage of consts.
Instead of smuggling them via the jaxpr, pull it out and pass them with args. This is because consts can be tracers and that fails down the stack when lowering to mlir.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 500544141
2023-01-08 10:38:08 -08:00
Yash Katariya
5afebba285 Remove _global_avals from infer_params because everything is global in pjit after jax.Array was enabled.
PiperOrigin-RevId: 500012042
2023-01-06 00:08:16 -08:00
Yash Katariya
711c3da195 Reshard pmap unconditionally if arguments with PmapSharding are passed to pjit. This is to support all the jit use cases with pjit to merge their API.
PiperOrigin-RevId: 499338100
2023-01-03 16:09:05 -08:00
Parker Schuh
9674b063c3 Add static_argnames to the _cpp_pjit path.
PiperOrigin-RevId: 499311688
2023-01-03 14:05:52 -08:00
Eugene Burmako
a1480c454e Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
2022-12-27 08:53:20 -08:00
Peter Hawkins
2c6c30d458 Bump the minimum jaxlib version to 0.4.1.
Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39.
2022-12-19 17:49:24 +00:00
Yash Katariya
4b587fa1f0 Move pjit.py to jax/_src in preparation for merging the jit and pjit frontend APIs
PiperOrigin-RevId: 495944279
2022-12-16 13:07:15 -08:00
Yash Katariya
3b35e9811d [Roll forward with fixes attempt 2] Add keep_unused to pjit's API as a step to merge jit and pjit frontend API.
PiperOrigin-RevId: 495886977
2022-12-16 09:03:24 -08:00
Eugene Burmako
b8ae8e3fa1 (NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
2022-12-15 21:00:07 -08:00
Yash Katariya
ecaa215043 [Rollback 2] Add keep_unused to pjit's API as a step to merge jit and pjit frontend API.
PiperOrigin-RevId: 495756613
2022-12-15 19:26:25 -08:00
Yash Katariya
3b9088f9a3 Add support for inline to pjit. This is to merge the jit and pjit frontend API.
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 495726005
2022-12-15 16:26:27 -08:00
Yash Katariya
5d69b7194a Fix the test failing in the continuous builds on GPU. There are 4 devices so the index of the last device is 3.
PiperOrigin-RevId: 495461803
2022-12-14 17:28:11 -08:00
Yash Katariya
7d4ef891af Add device and backend API to pjit but resolve them away in infer_params. This is to merge jit and pjit frontend API.
The semantics of mentioning `device` or `backend` on `pjit` is the same as doing a `device_put` i.e. no matter which device the arg is on, reshard it to the device mentioned.

PiperOrigin-RevId: 495437165
2022-12-14 15:41:59 -08:00
Yash Katariya
64b6efc680 [Roll-forward with fixes] Add keep_unused to pjit's API as a step to merge jit and pjit frontend API.
PiperOrigin-RevId: 495414822
2022-12-14 14:18:11 -08:00
Yash Katariya
82ca823956 [Rollback] Add keep_unused to pjit's API as a step to merge jit and pjit frontend API.
PiperOrigin-RevId: 495179106
2022-12-13 18:31:09 -08:00
Yash Katariya
08b6b0dd43 Add keep_unused to pjit's API as a step to merge jit and pjit frontend API.
PiperOrigin-RevId: 495137581
2022-12-13 15:08:25 -08:00
Yash Katariya
048d133590 Support static_argnames in pjit as a first step to merge jit and pjit.
Also add support for `kwargs` only if `in_axis_resources` is unspecified.

PiperOrigin-RevId: 495117879
2022-12-13 13:57:30 -08:00
Yash Katariya
13c34f9dc5 Move with_sharding_constraint out of experimental into jax.lax namespace.
PiperOrigin-RevId: 494635809
2022-12-11 22:55:21 -08:00
Yash Katariya
e7e9687161 Allow pjit's C++ dispatch path to operate on uncommitted array only if it belongs on a single device. This will bring pjit's dispatch performance in line with jit to prepare for jit/pjit frontend merge.
PiperOrigin-RevId: 493164446
2022-12-05 18:09:59 -08:00
Hyeontaek Lim
02fab525a7 Add tests to check if pjit handles deleted array inputs gracefully and consistently
pjit dispatch paths should check deleted array inputs when attempting to use
them. These new tests ensure that various pjit dispatch paths detect and handle
them gracefully and consistently.

Add a check to the PyArray argument handling to make the tests pass.

PiperOrigin-RevId: 492605524
2022-12-02 18:41:31 -08:00
Yash Katariya
934bc4e1b3 Move PartitionSpec and Mesh out of experimental and into the sharding namespace. The new API endpoint is jax.sharding.PartitionSpec and jax.sharding.Mesh.
PiperOrigin-RevId: 492358238
2022-12-01 19:28:32 -08:00
Yash Katariya
621322858d Fix vmap(jvp(pjit(f))) when pjit doesn't have any axis_resources
PiperOrigin-RevId: 492238366
2022-12-01 10:42:26 -08:00
Johannes Reifferscheid
cc1d2aaaed Disable more {cost,memory}_analysis tests when MLIR lowering is enabled.
PiperOrigin-RevId: 490898616
2022-11-25 06:25:56 -08:00
jax authors
dd902fde21 Merge pull request #13317 from google:xdist_tpu
PiperOrigin-RevId: 490366370
2022-11-22 16:40:00 -08:00