3274 Commits

Author SHA1 Message Date
Yash Katariya
2fc64bee13 Change the axis_resources argument of with_sharding_constraint to shardings to match pjit and jit.
PiperOrigin-RevId: 509275107
2023-02-13 10:53:57 -08:00
Jake VanderPlas
58323d5b40 jax.numpy reductions: better validation of initial value 2023-02-13 08:43:25 -08:00
Yash Katariya
6caaffc20c Add in_shardings and out_shardings argument to pjit and jit to start deprecating in_axis_resources and out_axis_resources.
PiperOrigin-RevId: 508934327
2023-02-11 15:30:14 -08:00
Peter Hawkins
612a940160 Minimize the set of names exported from jax.experimental.pjit.
PiperOrigin-RevId: 508889911
2023-02-11 07:37:32 -08:00
Yash Katariya
9316188b3a [Rollback] Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508827908
2023-02-10 21:36:56 -08:00
Skye Wanderman-Milne
e54858522c Add back loading TPU plugin for older jaxlib versions.
This was removed in 668b82d529.

PiperOrigin-RevId: 508777939
2023-02-10 16:16:20 -08:00
jax authors
fc507f2ebe Merge pull request #14418 from mattjj:vmap-spmd-axis-name-tuples
PiperOrigin-RevId: 508777043
2023-02-10 16:08:32 -08:00
Yash Katariya
0d07372995 Point to the exact primitive name nested under jit/pjit instead of mentioning all possible ones.
PiperOrigin-RevId: 508770290
2023-02-10 15:40:25 -08:00
Parker Schuh
568a93bcd1 Convert _arrays to return PyArray instead of PyBuffer.
PiperOrigin-RevId: 508769390
2023-02-10 15:32:57 -08:00
Matthew Johnson
9538bc3e73 generalize vmap spmd_axis_name to accept tuples of axis names
This brings the argument more in line with what can appear as positional
arguments to the PartitionSpec constructor.
2023-02-10 15:25:23 -08:00
jax authors
dc6bf9b725 Merge pull request #14408 from lucashofer:scipy_spence
PiperOrigin-RevId: 508756972
2023-02-10 14:36:15 -08:00
Yash Katariya
1526c3e20c Improve the error message which is raised from _get_and_check_device_assignment.
Before:

```
ValueError: Devices of all `Array` inputs and outputs should be the same. Got array device ids [0] on platform CPU and another array's device ids [0, 1, 2, 3] on platform CPU
```

After:

```
ValueError: Received incompatible devices for jitted computation. Got argument inp of ArrayPjitTest.test_jit_with_sharding_constraint_committed_inp_error.<locals>.sharded_inp with bfloat16[8,2] and device ids [0] on platform CPU and with_sharding_constraint or nested pjit or shard_map with device ids [0, 1, 2, 3] on platform CPU at jax/tests/pjit_test.py:2509 (sharded_inp)
```
PiperOrigin-RevId: 508746961
2023-02-10 13:54:15 -08:00
Lucas Hofer
4636276214 added scipy special spence
added dtype to arrays in the _spence_poly function
2023-02-10 20:33:47 +00:00
jax authors
57900d7ef2 Merge pull request #14364 from jakevdp:fix-tril-indices
PiperOrigin-RevId: 508723970
2023-02-10 12:25:06 -08:00
Jake VanderPlas
60256df668 [typing] define additional methods & properties on jax.Array
These are the methods that are only valid for actual materialized arrays (i.e. not Tracers)
In order to simplify the experience for users, we want to maintain only a single jax.Array
type, so we define all methods here and raise explicit errors on Tracer instances.
2023-02-10 09:42:32 -08:00
Peter Hawkins
54ff78dbde Deprecate jax.interpreters.xla.Device and jax.interpreters.xla.DeviceArray.
PiperOrigin-RevId: 508502470
2023-02-09 16:11:48 -08:00
Roy Frostig
1c84e4a753 migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
2023-02-09 15:11:57 -08:00
Jieying Luo
668b82d529 [PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
Loading TPU PJRT plugin is moved to make_tpu_client.

This change is based on https://github.com/google/jax/pull/14011.

PiperOrigin-RevId: 508477737
2023-02-09 14:33:46 -08:00
Peter Hawkins
0c14e9ab49 Change jax.ad, jax.xla, jax.pxla to point to the shims instead of the internal modules.
Don't hide _deprecations in shim modules, since it's handy for users to override deprecations locally, e.g., to verify there are no remaining users.

Fix some overly-strict type annotations.

PiperOrigin-RevId: 508461199
2023-02-09 13:31:40 -08:00
Matthew Johnson
a964dc3b9a simpler pretty-print for pjit, tweak custom pp rule signature 2023-02-09 12:45:51 -08:00
Peter Hawkins
8268cd562d Add infrastructure for managing deprecations.
Use it to deprecate jax.experimental.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.interpreters.pxla.Mesh.

PiperOrigin-RevId: 508349776
2023-02-09 05:48:40 -08:00
jax authors
ccb974a150 Merge pull request #14370 from jakevdp:argpartition-impl
PiperOrigin-RevId: 508194466
2023-02-08 15:10:50 -08:00
Peter Hawkins
a28b01243b Move contents of jax.monitoring to jax._src.monitoring.
PiperOrigin-RevId: 508191560
2023-02-08 15:03:22 -08:00
Yash Katariya
7350f00acd Remove jax_experimental_subjaxpr_lowering_cache since it was only for jit and was False by default. Now that jit/pjit are merged, this cache is not needed since pjit does the caching and we get it for free.
PiperOrigin-RevId: 508191408
2023-02-08 14:55:56 -08:00
Jake VanderPlas
4fbaee5920 Implement jax.numpy.argpartition 2023-02-08 14:41:39 -08:00
Peter Hawkins
cc8d7fae32 Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
2023-02-08 14:39:01 -08:00
jax authors
3e349c7bed Merge pull request #14361 from jakevdp:doc-topk
PiperOrigin-RevId: 508181335
2023-02-08 14:19:01 -08:00
Yash Katariya
6ec9082cf5 Default jax_jit_pjit_api_merge to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh context manager.
This changes the internals of JAX without affecting any public API.

Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).

Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.

PiperOrigin-RevId: 508143501
2023-02-08 11:55:48 -08:00
Jake VanderPlas
794557d349 tril_indices/triu_indices: fix call signature & add type annotations 2023-02-08 11:19:06 -08:00
Jake VanderPlas
3c6183498a lax.top_k: improve documentation and errors on invalid values 2023-02-08 11:07:56 -08:00
jax authors
4358b803e9 Merge pull request #14355 from jakevdp:tril-indices
PiperOrigin-RevId: 508119785
2023-02-08 10:35:10 -08:00
Yash Katariya
7b1128fdc4 Use jnp.arange to break the pjit cache (when jit and pjit are merged) because pytest runs tests non-hermetically.
PiperOrigin-RevId: 508114498
2023-02-08 10:17:37 -08:00
Jake VanderPlas
a76a024548 tril/triu_indices: compute arrays at runtime 2023-02-08 09:52:41 -08:00
Roy Frostig
55c2b6dad6 move jax.interpreters.batching to jax._src.interpreters.batching
Re-export roughly all of the same symbols via `jax.interpreters.batching` for now.

PiperOrigin-RevId: 508107044
2023-02-08 09:51:00 -08:00
Skye Wanderman-Milne
eb13c053e9 Add option to run tests with persistent compilation cache enabled.
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.

PiperOrigin-RevId: 507898535
2023-02-07 15:15:31 -08:00
Peter Hawkins
6860cb8d2a Move jax.interpreters.xla to jax._src.interpreters.xla.
Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 507895040
2023-02-07 15:01:32 -08:00
Peter Hawkins
98b75cf27b Prune accidental exports from jax.interpreters.pxla.
These imports do not appear to have users outside JAX itself.

PiperOrigin-RevId: 507835295
2023-02-07 11:16:42 -08:00
jax authors
d4f422f608 Merge pull request #14303 from carlosgmartin:rankdata
PiperOrigin-RevId: 507805953
2023-02-07 09:37:04 -08:00
carlosgmartin
8251957025 Added scipy.stats.rankdata 2023-02-07 12:07:00 -05:00
Roy Frostig
219723c738 migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
2023-02-06 22:52:36 -08:00
Yash Katariya
c252162821 Make pjit's cache global just like jit's cache. This will allow cache hits in C++ when pjit(f)(jnp.arange(3.)) is executed twice.
Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit.

PiperOrigin-RevId: 507662634
2023-02-06 20:35:26 -08:00
Skye Wanderman-Milne
6cef0873e8 Don't write executables with host callbacks to persistent compilation cache.
The persistent compilation cache can't de/serialize the callback functions (yet?).

PiperOrigin-RevId: 507628297
2023-02-06 17:37:32 -08:00
Peter Hawkins
08ff7f4ea9 Prune accidentally exported names from jax.interpreters.ad.
PiperOrigin-RevId: 507584433
2023-02-06 14:36:44 -08:00
Peter Hawkins
38a59a313b Move jax.interpreters.pxla to jax._src.interpreters.pxla.
Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time.

PiperOrigin-RevId: 507584264
2023-02-06 14:29:10 -08:00
Peter Hawkins
3d9ae6b467 Add a .cost_analysis() on lowered but uncompiled computations.
Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it.

PiperOrigin-RevId: 507560058
2023-02-06 12:57:57 -08:00
Yash Katariya
8a69444ff9 Bump minimum jaxlib_version to 0.4.2 i.e xla_extension_version == 119 and mlir_api_version == 43
PiperOrigin-RevId: 507520956
2023-02-06 10:37:33 -08:00
jax authors
63e0e0fdb6 Merge pull request #14291 from sharadmv:fix-checkify-caching
PiperOrigin-RevId: 507504176
2023-02-06 09:39:07 -08:00
Yash Katariya
a30ba83db2 Fix the latest jax jaxlib on pypi failure
PiperOrigin-RevId: 507208172
2023-02-04 20:16:33 -08:00
Yash Katariya
973bdb203b Copy the jit docs and paste it inside the new jit fork.
PiperOrigin-RevId: 507161252
2023-02-04 12:34:35 -08:00
Sharad Vikram
c231171fb6 Fix checkify caching with nested call primitives 2023-02-03 23:28:37 -08:00