12232 Commits

Author SHA1 Message Date
Yash Katariya
7da733f94b Change the internals of with_sharding_constraint to use the sharding instances.
PiperOrigin-RevId: 459600050
2022-07-07 14:22:10 -07:00
Jake VanderPlas
2b4f72b6f4 [sparse] fix unary operations in presence of duplicate indices 2022-07-07 13:49:50 -07:00
jax authors
fe1bbd59dd Merge pull request #11399 from mattjj:lower-abstracted-axes
PiperOrigin-RevId: 459585916
2022-07-07 13:20:39 -07:00
Matthew Johnson
12a56c3064 [dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...) 2022-07-07 12:48:29 -07:00
Marc van Zee
9d18f43a01 Do not normalize FFT by a constant "1" if no normalization is provided (i.e., norm is None).
Without this, the compiled graph will still contain a node multipying a complex number with a constant 1+0j (1 is cast to complex because the other term is complex as well). This is problematic when converting to TFLite using jax2tf, because multiplying complex numbers is not supported in TFLite. With this change, the multiplication is removed from the graph all together.

PiperOrigin-RevId: 459566727
2022-07-07 11:54:39 -07:00
Jake VanderPlas
ce08a9fc5c Deprecate top-level aliases of jax.tree_util functions 2022-07-07 11:41:46 -07:00
Jake VanderPlas
a10f0377db Avoid top-level aliases of jax.tree_util.* 2022-07-07 11:41:02 -07:00
Yash Katariya
57ed5dc3f7 Add a util to fetch value of a GDA to host when its single-controller. Error out in McJAX
PiperOrigin-RevId: 459555907
2022-07-07 11:09:13 -07:00
Yash Katariya
2314951669 Convert everything in pjit to the Sharding interface. The following contains the things that have changed in this CL:
* All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs.

* `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances.
  * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled.
  * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used.

* Checking of sharding with `aval` has a handler system to deal with sharding instances.
  * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding.

* `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us.

* _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL.
  * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too.
  * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface.

* `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`.
  * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998

* `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach.
  * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done.

PiperOrigin-RevId: 459548974
2022-07-07 10:41:52 -07:00
Peter Hawkins
88c1e7dce2 Flip after_neurips flag to True.
PiperOrigin-RevId: 459541278
2022-07-07 10:12:15 -07:00
jax authors
fb7e39b13e Merge pull request #11390 from hawkinsp:distributed_init
PiperOrigin-RevId: 459518348
2022-07-07 08:23:26 -07:00
jax authors
2b8fbe9fe4 Merge pull request #11367 from apaszke:xmap-tracer-leak
PiperOrigin-RevId: 459456785
2022-07-07 02:01:51 -07:00
jax authors
5270cb1c1f Merge pull request #11387 from mattjj:djax-bint
PiperOrigin-RevId: 459430960
2022-07-06 23:00:59 -07:00
Matthew Johnson
98e71fe31d [dynamic-shapes] revive basic bounded int machinery, add tests 2022-07-06 22:31:26 -07:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Anish Tondwalkar
5d379bba9e mhlo.rng op with distribution attr
Aligns with the XLA kRng which takes a distribution as an attribute
instead of having separate ops for each distribution.

PiperOrigin-RevId: 459389874
2022-07-06 18:03:02 -07:00
Peter Hawkins
bdbdecd458 Refactor distributed GPU device initialization.
Avoid reregistering backend factories; instead simply have the usual
factory function support distributed GPU.
2022-07-07 00:45:19 +00:00
jax authors
89a6766964 Merge pull request #11313 from mattjj:djax-revive-iree
PiperOrigin-RevId: 459360223
2022-07-06 15:34:05 -07:00
Matthew Johnson
6bb90fde9e [dynamic shapes] revive iree 2022-07-06 15:01:16 -07:00
jax authors
638e4353e6 Merge pull request #11381 from bartvm:main
PiperOrigin-RevId: 459346579
2022-07-06 14:40:42 -07:00
Peter Hawkins
95e79332c0 Add JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS environment variables to assist with skipping tests under Bazel.
Add "multiaccelerator" test tags to mark tests that would meaningfully run with more than one accelerator (e.g., GPU).

PiperOrigin-RevId: 459320212
2022-07-06 12:51:43 -07:00
jax authors
354c684873 [jax2tf] Update docs for supported convolution types.
PiperOrigin-RevId: 459316769
2022-07-06 12:36:29 -07:00
Bart van Merriënboer
de08344cb7 Avoid casting input to _fft_helper. 2022-07-06 14:29:54 -04:00
Robert Suderman
4ed8255d46 Fix iree.py python integration for backend changes
CPU / VMVX runtime is now called local-task. Updated to
separate compiler, runtime, and backend naming for single
specified configuration.

PiperOrigin-RevId: 459298179
2022-07-06 11:17:44 -07:00
jax authors
da5385f235 Merge pull request #11379 from hawkinsp:parallel
PiperOrigin-RevId: 459276185
2022-07-06 09:53:32 -07:00
Peter Hawkins
4443705e0f Add script for parallel accelerator testing under Bazel. 2022-07-06 10:58:04 -04:00
jax authors
b5e6145a42 Merge pull request #11359 from hawkinsp:bazel
PiperOrigin-RevId: 459234031
2022-07-06 06:13:20 -07:00
Peter Hawkins
1c75eee1ff Document how to run tests using Bazel.
* Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib.
* Add a bazel flag that make the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib.
2022-07-06 08:30:35 -04:00
Adam Paszke
5777c1eac2 Add support for post_process of xmap in BatchTrace
PiperOrigin-RevId: 459108183
2022-07-05 12:07:26 -07:00
jax authors
0719f986aa Merge pull request #11368 from gnecula:shape_poly_test_refactor
PiperOrigin-RevId: 459057179
2022-07-05 05:19:05 -07:00
George Necula
dc3d776311 [shape_poly] Refactor tests to separate the vmap tests
Introduce ShapePolyVmapPrimitivesTest to contain all the tests
that vmap results in batch polymprphic code.

Also fix some warnings about eig, eigh, and qr taking only kwarg
arguments.
2022-07-05 14:01:19 +02:00
Adam Paszke
7439e1b1f8 Properly count sublevels when tracing xmap body
Otherwise it can lead to tracer leak errors. I'm not a 100% sure how
this works out, because the sublevel counting has changed since I read
it previously. This replicates the changes applied to
DynamicJaxprTrace.process_map since I last looked at it.
2022-07-05 11:43:26 +00:00
jax authors
b2d70058c7 Merge pull request #9426 from gnecula:iree_poly
PiperOrigin-RevId: 459042352
2022-07-05 03:28:21 -07:00
George Necula
b6c90693c6 Fix mypy annotations 2022-07-05 12:49:37 +03:00
George Necula
5983d385da [dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota.
Also add more tests.
2022-07-05 12:14:15 +03:00
jax authors
5d6f81cda8 Merge pull request #11361 from hawkinsp:tri
PiperOrigin-RevId: 458800381
2022-07-03 15:32:47 -07:00
Peter Hawkins
56202647bc Add missing dtype canonicalization to tridiagonal solve lowering.
This meant that the tridiagonal solve test failed when X64 mode was disabled on GPU.
2022-07-03 16:08:54 -04:00
jax authors
a4798c32bd Merge pull request #11358 from nalzok:patch-1
PiperOrigin-RevId: 458786037
2022-07-03 12:38:14 -07:00
jax authors
8f16270dca Merge pull request #11360 from hawkinsp:tol
PiperOrigin-RevId: 458786025
2022-07-03 12:32:37 -07:00
Peter Hawkins
62a392a7e2 Relax test tolerances.
These tests current fail on M1 Mac.
2022-07-03 15:13:28 -04:00
Qingyao Sun
2d063d3f85
Fix typos in omnistaging.md 2022-07-03 19:02:30 +00:00
Haoyu Zhang
118db407f2 Only check preemption sync point if distributed.global_state.client is initialized.
PiperOrigin-RevId: 458670810
2022-07-02 11:49:42 -07:00
Roy Frostig
1e875dddc2 handle unimplemented hlo_modules() on XLA executables
PiperOrigin-RevId: 458609175
2022-07-01 23:40:52 -07:00
jax authors
fe2edb537c Merge pull request #11344 from sharadmv:for-loop
PiperOrigin-RevId: 458592285
2022-07-01 20:23:35 -07:00
Yash Katariya
8a23605462 Add a limiter for in-flight bytes. Read a shard from TensorStore if there are enough bytes are available. This only works for deserialization right now.
PiperOrigin-RevId: 458586521
2022-07-01 19:26:59 -07:00
Roy Frostig
f12af93258 refactor stages types, adding methods for text and for cost/memory analyses
Re-organizing things this way in order to:

* Clarify internally what a lowering and executable should do, rather than what current XLA-backed versions happen to provide.

* Document that some features (e.g. cost analysis) are best-effort and intended mainly for debugging purposes. They may be unimplemented on some backends and what they return is intentionally undefined.

For an example of the latter item, this change adds a `cost_analysis()` method on `jax.stages.Compiled`. However, the expression `jit(f).lower(*args).compile().cost_analysis()` may return `None` depending on backend. Otherwise, guarantees about its output and return type are very limited -- these can differ across invocations and across JAX/jaxlib versions.

Some specifics:
* Introduce `cost_analysis` and `memory_analysis` methods on `Compiled` that do as their name suggests.
* Introduce `as_text` methods on `Lowered` and `Compiled` that do as the name suggests.
* Rename `_src.stages.Computation` protocol to `_src.stages.Lowering`.
* Fix a handful of type annotations, add various docstrings and comments explaining the above.

PiperOrigin-RevId: 458574166
2022-07-01 17:35:53 -07:00
Sharad Vikram
a82047dd4a Add partial_eval rule for for
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-07-01 17:20:14 -07:00
Peter Hawkins
1fc9afd03a Add support for running JAX tests under Bazel.
This is an alternative method for running the tests that some users may prefer: pytest is and will remain fully supported.

To use this, one creates a .bazelrc by running the existing `build.py` script, and then one can run the tests by running:
```
bazel test -c opt //tests/...
```

Issue #7323

PiperOrigin-RevId: 458551208
2022-07-01 15:07:22 -07:00
Peter Hawkins
270f73e346 Internal-only change.
PiperOrigin-RevId: 458538018
2022-07-01 13:54:40 -07:00
Haoyu Zhang
3fc24ceb35 Save an on-demand checkpoint when any worker receives a preemption signal.
PiperOrigin-RevId: 458525108
2022-07-01 12:45:30 -07:00